I tried to port a simple survival model from here (the first one in the introduction) from PyMC 2 to PyMC 3. However, I couldn't find any equivalent to the "observed" decorator, and my attempt to write a new distribution failed. Could someone provide an example of how this is done in PyMC 3?

Marigold

1 Answer

This is a tricky port, and requires three new concepts:

  1. Use of the theano tensor
  2. Use of the DensityDist
  3. Passing a dict as observed

This code implements a model equivalent to the PyMC 2 version you linked to above:

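import pymc3 as pm
import theano.tensor as t
# The melanoma data set ships with PyMC 2, so that package must be installed too.
from pymc.examples import melanoma_data as data

times = data.t  # observed times; not to be confused with the theano tensor module t!
failure = (data.censored == 0).astype(int)  # 1 = failure observed, 0 = censored

with pm.Model() as model:

    # Vague normal priors on the regression coefficients
    beta0 = pm.Normal('beta0', mu=0.0, tau=0.0001)
    beta1 = pm.Normal('beta1', mu=0.0, tau=0.0001)
    # Exponential hazard rate for each subject
    lam = t.exp(beta0 + beta1 * data.treat)

    # Log-likelihood of exponential survival times with censoring
    def survival_like(failure, value):
        return t.sum(failure * t.log(lam) - lam * value)

    # The keys of the observed dict are passed to survival_like by name
    survive = pm.DensityDist('survive', survival_like,
                             observed={'failure': failure, 'value': times})

with model:

    start = pm.find_MAP()
    step = pm.NUTS(scaling=start)
    trace = pm.sample(10000, step=step, start=start)

pm.traceplot(trace);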

Output as follows:

[Image: trace plot produced by pm.traceplot(trace)]
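
As a sanity check on `survival_like`, here is a minimal sketch in plain Python (no PyMC or Theano) that compares the closed-form expression with the per-observation definition of the likelihood: an observed failure contributes the log density `log(lam) - lam * t`, while a censored time contributes only the log survival function `-lam * t`. The function names and the toy numbers are made up for illustration and are not taken from the melanoma data set.

```python
import math

def loglik_closed_form(lam, failure, times):
    # Same expression as survival_like: sum(failure * log(lam) - lam * value)
    return sum(d * math.log(l) - l * t for l, d, t in zip(lam, failure, times))

def loglik_per_observation(lam, failure, times):
    # Build the log-likelihood one subject at a time
    total = 0.0
    for l, d, t in zip(lam, failure, times):
        if d == 1:
            total += math.log(l) - l * t  # log density of an observed failure
        else:
            total += -l * t               # log survival of a censored time
    return total

# Toy values (assumed for illustration only)
lam = [0.5, 0.2, 1.5, 0.8]
failure = [1, 0, 1, 0]
times = [2.0, 5.0, 0.3, 1.2]

closed = loglik_closed_form(lam, failure, times)
per_obs = loglik_per_observation(lam, failure, times)
print(closed, per_obs)
```

The two values agree, which is why the `lam * value` term sits outside the `failure` factor: censored subjects still contribute to the likelihood through their survival time.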

inversion
  • When passing `observed={...}`, how are the arguments passed to `survival_like`? Do the arguments have to be in alphabetical order? Thanks! – Stefan Novak Oct 06 '15 at 01:18
  • I don't believe it matters. But you could do a simple test to confirm. – inversion Oct 06 '15 at 13:08
  • Thanks! I was able to dive into the PyMC codebase to see that `logp` is called with `**data`, so the values from the dict will be passed to the correct argument, regardless of order. https://github.com/pymc-devs/pymc3/blob/master/pymc3/model.py#L535 – Stefan Novak Oct 06 '15 at 16:47
  • I think the line: `return t.sum(failure * t.log(lam) - lam * value)` should be: `return t.sum(failure * (t.log(lam) - lam * value))` – Yetti May 21 '19 at 00:57
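
On the argument-order question in the comments: when a dict is unpacked with `**`, Python matches its entries to parameters by key name, not by position or alphabetical order. A minimal sketch with a hypothetical stand-in function (plain Python, not the PyMC internals):

```python
def survival_like(failure, value):
    # Return the arguments in signature order so the binding is visible
    return (failure, value)

# The same entries, written in two different orders
observed_a = {'failure': [1, 0], 'value': [2.0, 5.0]}
observed_b = {'value': [2.0, 5.0], 'failure': [1, 0]}

result_a = survival_like(**observed_a)
result_b = survival_like(**observed_b)
print(result_a == result_b)  # True
```

The only requirement is that the dict keys match the parameter names of the likelihood function.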