from copy import deepcopy

from numpy import inf, log, ma, random, where
from numpy.random import random
from pymc import Stochastic, Deterministic, Node, StepMethod

class OneAtATimeRJ(StepMethod):
    """
    S = OneAtATimeRJ(stochs, indicator, p, rp, g, q, rq, inv_q, Jacobian)

    Reversible-jump step method for indicatored-array-valued stochs. The
    indicator records which elements are currently 'in the model': if
    stoch.value.indicator[index] = True, that index is currently being excluded.

    Indicatored-array-valued stochs and their children should understand how to
    cope with indicatored arrays when evaluating their log-probabilities.

    The prior for the indicatored-array-valued stoch may depend explicitly on the
    indicator.

    The callable arguments are, in notation similar to that of Waagepetersen et al.,

        def rp(indicator):
            Draws a new value for the indicator. It is expected to assign
            indicator.value on every call, so that indicator.last_value holds the
            pre-proposal value.

        def p(indicator):
            Returns the probability of jumping to indicator.value from
            indicator.last_value under rp.

        def rq(indicator):
            Draws a value for the auxiliary RV's u given indicator.value (proposed),
            indicator.last_value (current), and the value of the stochs.

        def g(indicator, u, **stochs):
            g generates a new value for the stochs given indicator.last_value (current) and
            indicator.value (proposed), the current value of the stoch, and the auxiliary RV's
            u which are generated by rq. The stochs are passed as keyword arguments keyed
            by their __name__.

        def q(indicator, u):
            q computes the density of the auxiliary RV's u given the last_values of the stochs,
            and the current and proposed indicator values.

        def inv_q(indicator):
            inv_q computes the density, under the distribution defined by rq, of the value of u
            which would have been required to propose last_value of stochs from current_value.

        def Jacobian(indicator, u, **stochs):
            Returns the log Jacobian of the jump transformation (u,z) -> (u',z'), where u'
            is the value of the auxiliary RV u which would have been required to propose the
            current value z from the proposed value z_p. May be None, in which case no
            Jacobian term is added to the acceptance test.

        (u' is never used directly by this class, just implicitly in inv_q and Jacobian
        in case the transformation (z',z) -> u' isn't invertible. I think it's OK as far
        as math goes to use inverse images.)

    The values returned by q, inv_q and Jacobian are added directly to a sum of
    log-probabilities in step(), so they are treated as being on the log scale.
    """
    def __init__(self, stochs, indicator, p, rp, g, q, rq, inv_q, Jacobian):

        StepMethod.__init__(self, nodes = stochs)

        self.g = g
        self.q = q
        self.rq = rq
        # NOTE(review): p is stored but never consulted by step(); confirm
        # whether the indicator jump probability belongs in the acceptance test.
        self.p = p
        self.rp = rp
        self.inv_q = inv_q
        self.Jacobian = Jacobian

        # Keyword-argument view of the stochs, as expected by g and Jacobian.
        self.stoch_dict = {}
        for stoch in stochs:
            self.stoch_dict[stoch.__name__] = stoch

        self.indicator = indicator


    def propose(self):
        """
        Sample a new indicator and value for the stoch.

        Calls rp to update the indicator, draws the auxiliary variable with rq
        (kept in self._u for the acceptance test), then lets g update the stochs.
        """
        self.rp(self.indicator)
        self._u = self.rq(self.indicator)
        self.g(self.indicator, self._u, **self.stoch_dict)


    def _reject(self):
        """
        Undo a proposal: restore the stochs and the indicator.
        """
        for stoch in self.stochs:
            stoch.revert()

        # rp is documented as drawing a new indicator.value, so a rejected jump
        # has to undo it as well; otherwise the indicator and the stochs end up
        # describing different models. Skip it when the caller already listed
        # the indicator among the stochs, so it is not reverted twice.
        # NOTE(review): assumes the indicator exposes revert() like the stochs.
        if not any(stoch is self.indicator for stoch in self.stochs):
            self.indicator.revert()


    def step(self):
        """
        Propose a reversible jump and accept or reject it.

        The log acceptance value is the change in log-probability plus
        log-likelihood, plus the inv_q, q and (if given) Jacobian terms. The
        proposal is kept when log(uniform) does not exceed that value and is
        reverted otherwise.
        """
        # Log-probability and log-likelihood at the current value:
        logp = sum(stoch.logp for stoch in self.stochs) + self.indicator.logp
        loglike = self.loglike

        # Sample a candidate value for the value and indicator of the stoch.
        self.propose()

        # Log-probability at the proposed value:
        logp_p = sum(stoch.logp for stoch in self.stochs) + self.indicator.logp

        # Skip the rest if a bad value is proposed
        if logp_p == -inf:
            self._reject()
            return

        loglike_p = self.loglike

        # test:
        test_val = logp_p + loglike_p - logp - loglike
        test_val += self.inv_q(self.indicator)
        # NOTE(review): a Metropolis-Hastings ratio normally subtracts the
        # forward proposal log-density; q is added here as in the original
        # code. Confirm the sign convention expected of q before changing it.
        test_val += self.q(self.indicator, self._u)

        if self.Jacobian is not None:
            test_val += self.Jacobian(self.indicator, self._u, **self.stoch_dict)

        if log(random()) > test_val:
            self._reject()


    def tune(self):
        """
        No-op: this step method has no tunable parameters.
        """
        pass



