
I'm modifying an existing model using rjags. I'd like to run chains in parallel and periodically check the Gelman-Rubin convergence diagnostic to decide whether I need to keep running. The problem is that when I resume running based on the diagnostic value, the recompiled chains restart from the original initial values rather than from the point in parameter space where each chain stopped. If I do not recompile the model, rjags complains. Is there a way to store the positions of the chains when they stop, so I can re-initialize from where I left off? Here is a very simplified example.

example1.bug:

model {
  for (i in 1:N) {
      x[i] ~ dnorm(mu,tau)
  }
  mu ~ dnorm(0,0.0001)
  tau <- pow(sigma,-2)
  sigma ~ dunif(0,100)
}

parallel_test.R:

#Make some fake data
N <- 1000
x <- rnorm(N,0,5)
write.table(x,
            file='example1.data',
            row.names=FALSE,
            col.names=FALSE)

library('rjags')
library('doParallel')
library('random')

nchains <- 4
c1 <- makeCluster(nchains)
registerDoParallel(c1)

jags=list()
for (i in 1:getDoParWorkers()){
  jags[[i]] <- jags.model('example1.bug',
                          data=list('x'=x,'N'=N))
}

# Function to combine multiple mcmc lists into a single one
mcmc.combine <- function( ... ){
  return( as.mcmc.list( sapply( list( ... ),mcmc ) ) )
}

#Start with some burn-in
jags.parsamples <- foreach( i=1:getDoParWorkers(),
                           .inorder=FALSE,
                           .packages=c('rjags','random'),
                           .combine='mcmc.combine',
                           .multicombine=TRUE) %dopar%
{
  jags[[i]]$recompile()

  update(jags[[i]],100)
  jags.samples <- coda.samples(jags[[i]],c('mu','tau'),100)

  return(jags.samples)
}   

#Check the diagnostic output
print(gelman.diag(jags.parsamples[,'mu']))

counter <- 0

#my model doesn't converge so quickly, so let's simulate doing
#this updating 5 times:
#while(gelman.diag(jags.parsamples[,'mu'])[[1]][[2]] > 1.04)
while(counter < 5)
{
  counter <- counter + 1
  jags.parsamples <- foreach(i=1:getDoParWorkers(),
                             .inorder=FALSE,
                             .packages=c('rjags','random'),
                             .combine='mcmc.combine',
                             .multicombine=TRUE) %dopar%
  {
    #Here I lose the progress I've made
    jags[[i]]$recompile()
    jags.samples <- coda.samples(jags[[i]],c('mu','tau'),100)
    return(jags.samples)
  }
}

print(gelman.diag(jags.parsamples[,'mu']))
print(summary(jags.parsamples))
stopCluster(c1)

In the output, I see:

Iterations = 1001:2000

where I know there should be > 5000 iterations. (cross-posted to stats.stackexchange.com, which may be the more appropriate venue)
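For reference, gelman.diag() reports the potential scale reduction factor (PSRF). In the commented-out loop condition above, [[1]][[2]] selects element 2 of the $psrf matrix, which is the upper confidence limit for 'mu'. Below is a minimal base-R sketch of the basic, uncorrected statistic. rhat_basic is a name introduced here and is not part of coda. coda's gelman.diag() also corrects for sampling variability and, by default, discards the first half of each chain, so its numbers will differ somewhat. The idea is the same: values near 1 mean the chains agree, and values well above 1 mean they do not.

```r
# Basic (uncorrected) Gelman-Rubin potential scale reduction factor.
# chains: a numeric matrix with one column per chain, one row per iteration.
rhat_basic <- function(chains) {
  n <- nrow(chains)                   # iterations per chain
  W <- mean(apply(chains, 2, var))    # mean within-chain variance
  B <- n * var(colMeans(chains))      # between-chain variance
  var_hat <- (n - 1) / n * W + B / n  # pooled estimate of the posterior variance
  sqrt(var_hat / W)
}

set.seed(42)
mixed <- matrix(rnorm(4000), ncol = 4)  # 4 chains sampling the same target
stuck <- mixed
stuck[, 4] <- stuck[, 4] + 3            # one chain stuck in a different region

rhat_basic(mixed)  # close to 1
rhat_basic(stuck)  # well above 1.04
```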

sjc
  • I'm finding `R2jags::jags.parallel` works great for the parallelization part. Are the other parts of this (checking for convergence, picking up where the chain left off if it is not met) still the same, or are there new tools to do this? – Michael Roswell Jul 17 '19 at 02:44



Every time your JAGS model runs on a worker node, the coda samples are returned, but the state of the model is lost. The worker only updates its own copy of the model, and that copy is discarded when the task finishes. So the next time the model is recompiled, it restarts from the beginning, as you are seeing. To get around this, get the state of the model inside your function (on the worker nodes) and return it along with the samples, like so:

 endstate <- jags[[i]]$state(internal=TRUE)

On the next pass, send this state back to the worker node. Then re-create the model within the worker function using jags.model() with inits=endstate (for the appropriate chain), instead of recompiling the original model object.
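A minimal sketch of that approach follows. It assumes JAGS plus the rjags, coda and doParallel packages are installed. The model from example1.bug is inlined as a string (model_string is a name introduced here), and each worker runs one chain. Each worker returns that chain's samples together with its end state from $state(internal=TRUE), and the next pass feeds that state back to jags.model() as inits. The per-chain seeds, the pass limit max.passes and the 1.04 threshold are illustrative choices, not part of the original code.

```r
library(rjags)       # attaches coda as well
library(doParallel)  # attaches foreach and parallel as well

# The model from example1.bug, inlined so the sketch is self-contained
model_string <- "
model {
  for (i in 1:N) {
    x[i] ~ dnorm(mu, tau)
  }
  mu ~ dnorm(0, 0.0001)
  tau <- pow(sigma, -2)
  sigma ~ dunif(0, 100)
}"

set.seed(1)
N <- 1000
x <- rnorm(N, 0, 5)

nchains <- 4
cl <- makeCluster(nchains)
registerDoParallel(cl)

# First pass: only RNG settings, with a different seed for each chain
states <- lapply(seq_len(nchains), function(i)
  list(.RNG.name = "base::Mersenne-Twister", .RNG.seed = i))
draws <- vector("list", nchains)  # accumulated samples, one matrix per chain

max.passes <- 6  # illustrative cap on the number of passes
for (pass in seq_len(max.passes)) {
  res <- foreach(i = seq_len(nchains), .packages = "rjags") %dopar% {
    # Re-create the model on the worker, starting from chain i's saved state
    m <- jags.model(textConnection(model_string),
                    data = list(x = x, N = N),
                    inits = states[[i]], n.chains = 1, quiet = TRUE)
    s <- coda.samples(m, c("mu", "tau"), n.iter = 100)
    # Return the samples AND the end state (parameter values plus RNG state)
    list(samples = as.matrix(s[[1]]),
         state = m$state(internal = TRUE)[[1]])
  }
  states <- lapply(res, function(r) r$state)
  for (i in seq_len(nchains)) {
    draws[[i]] <- rbind(draws[[i]], res[[i]]$samples)
  }

  chains <- do.call(mcmc.list, lapply(draws, mcmc))
  psrf.upper <- gelman.diag(chains[, "mu"])$psrf[1, 2]  # upper C.I. for mu
  if (psrf.upper < 1.04) break
}
stopCluster(cl)
print(summary(chains))
```

Each call to jags.model() runs an adaptation phase before sampling (n.adapt defaults to 1000 iterations). During that phase the chain moves forward from the saved state before the next 100 samples are drawn. The samples from each pass are simply appended per chain.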

I would actually recommend looking at the runjags package that does all this for you. For example:

library('runjags')
parsamples <- run.jags('example1.bug', data=list('x'=x,'N'=N), monitor=c('mu','tau'), sample=100, method='rjparallel')
summary(parsamples)
newparsamples <- extend.jags(parsamples, sample=100)
summary(newparsamples)
# etc
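The "# etc" can be turned into a loop that keeps extending the run until the diagnostic looks acceptable. The sketch below assumes JAGS plus the runjags, rjags and coda packages are installed. It passes the model to run.jags() as a string rather than a file name, and supplies a different RNG seed per chain as initial values. The cap of 10 extensions and the 1.04 threshold are illustrative values. extend.jags() has a combine argument that controls whether new samples are appended to the old ones, so check its default in your version. autorun.jags() automates a similar loop.

```r
library(runjags)
library(coda)

# The model from example1.bug as a string (run.jags also accepts a file name)
model_string <- "
model {
  for (i in 1:N) {
    x[i] ~ dnorm(mu, tau)
  }
  mu ~ dnorm(0, 0.0001)
  tau <- pow(sigma, -2)
  sigma ~ dunif(0, 100)
}"

set.seed(1)
N <- 1000
x <- rnorm(N, 0, 5)

# Illustrative initial values: only a different RNG seed for each chain
inits <- lapply(1:4, function(i)
  list(.RNG.name = "base::Mersenne-Twister", .RNG.seed = i))

parsamples <- run.jags(model_string, data = list(x = x, N = N),
                       monitor = c("mu", "tau"), n.chains = 4, inits = inits,
                       sample = 100, method = "rjparallel")

for (k in 1:10) {
  # Upper C.I. of the PSRF for mu, computed by coda on the current samples
  psrf.upper <- gelman.diag(as.mcmc.list(parsamples)[, "mu"])$psrf[1, 2]
  if (psrf.upper < 1.04) break
  parsamples <- extend.jags(parsamples, sample = 100)
}
summary(parsamples)
```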

Or even:

parsamples <- autorun.jags('example1.bug', data=list('x'=x,'N'=N), monitor=c('mu','tau'), method='rjparallel')

Version 2 of runjags will hopefully be uploaded to CRAN soon, but for now you can download binaries from: https://sourceforge.net/projects/runjags/files/runjags/

Matt

Matt Denwood
  • I couldn't quite get it to work passing init=endstate[[i]] on subsequent passes (I wound up seeing identical traces each time), but my simple example did work very well when I ran with runjags. It also appears that it will fit quite painlessly into my much larger and more complicated model. Is it acceptable to trust runjags v1.2.1 for now? – sjc Apr 07 '15 at 18:06
  • I only remember fixing one bug (which isn't relevant to your model) relating to the rjparallel method with version 2, so it should be fine. The main advantage to upgrading is improved plotting and summary facilities which are additional/improved features rather than bug fixes. – Matt Denwood Apr 07 '15 at 20:42
  • Getting an error on `extend.jags` after successfully running a model with `run.jags`: `Error: unused argument(s) 'samples' (no unambiguous match in the 'extend.jags' or 'add.summary' functions)` – colin Jun 22 '16 at 11:50
  • Are you sure it's not something simple like 'samples'->'sample' ? – sjc Jun 27 '16 at 18:13
  • Yes the argument should be 'sample' - I fixed the typo in the answer and left a comment several days ago, but it seems the comment didn't post for some reason - sorry! – Matt Denwood Jun 28 '16 at 05:10