#include "AppHdr.h"

#include "random-var.h"

#include "random.h"

// Degenerate variable: the half-open range [c, c+1) holds the single
// value c, which is given weight 1.
random_var::random_var(int c)
    : start(c),
      end(c + 1)
{
    weights.assign(1, 1);
    this->init();
}

// Variable over the half-open range [s, e).  The weights are filled in
// by init_weights(): taken from w when it is non-null, otherwise every
// value gets weight 1.
random_var::random_var(int s, int e, weight_func w)
    : start(s),
      end(e)
{
    this->init_weights(w);
    this->init();
}

// Variable over the half-open range [s, e) with an explicit weight
// table; ws[i] is the weight of the value s + i.
random_var::random_var(int s, int e, std::vector<int> ws)
    : start(s),
      end(e),
      weights(ws)
{
    // Exactly one weight per value in the range is required.
    ASSERT(weights.size() == static_cast<unsigned int>(end - start));
    this->init();
}

// Member-wise copy of the range bounds, the cached total and the weight
// table.  Returns a const reference to this object.
const random_var& random_var::operator=(const random_var& other)
{
    total = other.total;
    end = other.end;
    start = other.start;
    weights = other.weights;
    return *this;
}

// Weight of the value val; values outside [start, end) have weight 0.
int random_var::weight(int val) const
{
    const bool in_range = (start <= val && val < end);
    return in_range ? weights[val - start] : 0;
}

// Fill the (still empty) weight table with one entry per value in
// [start, end): w(v) when a weight function is given, otherwise 1.
void random_var::init_weights(weight_func w)
{
    ASSERT(weights.empty());
    int v = start;
    while (v < end)
    {
        if (w)
            weights.push_back((*w)(v));
        else
            weights.push_back(1);
        ++v;
    }
}

// Recompute the cached sum of all weights and check the invariants:
// the total must be positive, and both the first and the last value of
// the range must have positive weight.
void random_var::init()
{
    total = 0;
    int v = start;
    while (v < end)
    {
        total += weight(v);
        ++v;
    }
    ASSERT(total > 0);
    ASSERT(weight(start) > 0 && weight(end - 1) > 0);
}

// Map a roll r in [0, total) to a value: walk the values upward from
// start, accumulating weights, and return the first value whose
// cumulative weight exceeds r.
int random_var::roll2val(int r) const
{
    ASSERT(0 <= r && r < total);
    int val = start;
    for (int cumulative = weight(val); r >= cumulative;
         cumulative += weight(val))
    {
        ++val;
    }
    return val;
}

// Draw one value from the variable.
// NOTE(review): relies on random2(total) yielding a number in
// [0, total), as roll2val() asserts -- confirm against random.h.
int random_var::roll() const
{
    const int r = random2(total);
    return roll2val(r);
}

// Largest value in the range; end itself is exclusive.
int random_var::max() const
{
    const int last_value = end - 1;
    return last_value;
}

// Smallest value in the range; start is inclusive.
int random_var::min() const
{
    const int first_value = start;
    return first_value;
}

// Expected value of the variable: the sum over all values i in
// [start, end) of i * weight(i) / total.
double random_var::expected() const
{
    double ev = 0;
    for (int i = start; i < end; ++i)
    {
        // Multiply in double: the int product i * weight(i) can overflow
        // (undefined behaviour) for large values or weights.  Whenever
        // the int product fits, the double product is exactly the same.
        ev += static_cast<double>(i) * weight(i) / total;
    }
    return (ev);
}


//////////////////////////////////

// A variable that always takes the value n.
random_var constant(int n)
{
    random_var fixed(n);
    return fixed;
}

// Sum of two variables: each pair of values (vx, vy) adds the product
// of the two weights to the bucket for vx + vy.
random_var operator+(const random_var& x, const random_var& y)
{
    const int x_lo = x.min();
    const int x_hi = x.max();
    const int y_lo = y.min();
    const int y_hi = y.max();

    const int lo = x_lo + y_lo;
    const int hi_excl = x_hi + y_hi + 1;
    std::vector<int> sum_weights(hi_excl - lo, 0);

    for (int vx = x_lo; vx <= x_hi; ++vx)
    {
        const int wx = x.weight(vx);
        for (int vy = y_lo; vy <= y_hi; ++vy)
            sum_weights[vx + vy - lo] += wx * y.weight(vy);
    }

    return random_var(lo, hi_excl, sum_weights);
}

// The variable -x: the range is mirrored around zero and each value -v
// keeps the weight that v had in x.
random_var negate(const random_var& x)
{
    const int x_lo = x.min();
    const int x_hi = x.max();

    const int lo = -x_hi;
    const int hi_excl = -x_lo + 1;
    std::vector<int> mirrored(hi_excl - lo, 0);

    int v = x_lo;
    while (v <= x_hi)
    {
        mirrored[-v - lo] = x.weight(v);
        ++v;
    }

    return random_var(lo, hi_excl, mirrored);
}

// Difference of two variables, computed as x + (-y).
random_var operator-(const random_var& x, const random_var& y)
{
    const random_var minus_y = negate(y);
    return x + minus_y;
}

// Replace x by the sum x + y and return a reference to the updated x.
const random_var& operator+=(random_var& x, const random_var& y)
{
    const random_var sum = x + y;
    x = sum;
    return x;
}

// Replace x by the difference x - y and return a reference to the
// updated x.
const random_var& operator-=(random_var& x, const random_var& y)
{
    const random_var difference = x - y;
    x = difference;
    return x;
}

// Integer division of a variable by the constant d, using the built-in
// (truncating) int division on every value.  Values that map to the
// same quotient have their weights added together.
//
// d must be non-zero.  A negative d reverses the order of the
// quotients, so the bounds of the result are taken from whichever end
// of x yields the smaller/larger quotient rather than assuming
// x.min() / d <= x.max() / d.
random_var operator/(const random_var& x, int d)
{
    const int q_at_min = x.min() / d;
    const int q_at_max = x.max() / d;
    const int start = (q_at_min < q_at_max) ? q_at_min : q_at_max;
    const int end = ((q_at_min < q_at_max) ? q_at_max : q_at_min) + 1;
    std::vector<int> weights(end - start, 0);

    for (int v = x.min(); v <= x.max(); ++v)
        weights[v / d - start] += x.weight(v);

    return (random_var(start, end, weights));
}

// Division of a variable by the constant d with randomised rounding:
// a value v = q*d + rem (0 <= rem < d) becomes q with d - rem shares of
// its weight and q + 1 with rem shares, so every value contributes
// exactly d shares in total.
//
// d must be positive.  Quotient and remainder use floor semantics so
// that negative values are handled too: C++ int division truncates
// toward zero and gives a negative remainder for negative v, which
// would otherwise over-weight the lower bucket and put the lower bound
// of the range too high.
random_var div_rand_round(const random_var& x, int d)
{
    const int x_min = x.min();
    const int x_max = x.max();

    // Lower bound: floor(x_min / d).
    int start = x_min / d;
    if (x_min % d < 0)
        --start;

    // Upper bound (inclusive): ceil(x_max / d).
    int last = x_max / d;
    if (x_max % d > 0)
        ++last;
    const int end = last + 1;

    std::vector<int> weights(end - start, 0);

    for (int v = x_min; v <= x_max; ++v)
    {
        int quot = v / d;
        int rem = v % d;
        if (rem < 0)
        {
            // Convert the truncated result to floor division.
            rem += d;
            --quot;
        }
        weights[quot - start] += x.weight(v) * (d - rem);
        if (rem > 0)
            weights[quot + 1 - start] += x.weight(v) * rem;
    }

    return (random_var(start, end, weights));
}

// Maximum of two variables: each pair of values (vx, vy) adds the
// product of the two weights to the bucket for the larger of the two.
random_var rv::max(const random_var& x, const random_var& y)
{
    const int x_lo = x.min();
    const int x_hi = x.max();
    const int y_lo = y.min();
    const int y_hi = y.max();

    const int lo = std::max(x_lo, y_lo);
    const int hi_excl = std::max(x_hi, y_hi) + 1;
    std::vector<int> max_weights(hi_excl - lo, 0);

    for (int vx = x_lo; vx <= x_hi; ++vx)
    {
        const int wx = x.weight(vx);
        for (int vy = y_lo; vy <= y_hi; ++vy)
        {
            const int larger = std::max(vx, vy);
            max_weights[larger - lo] += wx * y.weight(vy);
        }
    }

    return random_var(lo, hi_excl, max_weights);
}

// Minimum of two variables: each pair of values (vx, vy) adds the
// product of the two weights to the bucket for the smaller of the two.
random_var rv::min(const random_var& x, const random_var& y)
{
    const int x_lo = x.min();
    const int x_hi = x.max();
    const int y_lo = y.min();
    const int y_hi = y.max();

    const int lo = std::min(x_lo, y_lo);
    const int hi_excl = std::min(x_hi, y_hi) + 1;
    std::vector<int> min_weights(hi_excl - lo, 0);

    for (int vx = x_lo; vx <= x_hi; ++vx)
    {
        const int wx = x.weight(vx);
        for (int vy = y_lo; vy <= y_hi; ++vy)
        {
            const int smaller = std::min(vx, vy);
            min_weights[smaller - lo] += wx * y.weight(vy);
        }
    }

    return random_var(lo, hi_excl, min_weights);
}

// Sum of d variables, each built as random_var(1, n+1), i.e. covering
// the values 1..n.  Yields the constant 0 when n <= 0, and also when
// d <= 0 because nothing is added.
random_var rv::roll_dice(int d, int n)
{
    random_var sum = constant(0);
    if (n <= 0)
        return sum;

    for (int remaining = d; remaining > 0; --remaining)
        sum += random_var(1, n + 1);

    return sum;
}

// Variable over the values 0 .. n-1; for n < 1 the upper bound is
// clamped to 1, so the result is the constant 0.
random_var rv::random2(int n)
{
    const int upper = std::max(n, 1);
    return random_var(0, upper);
}
