
I'm using NumPyro to create a simple linear regression model with two variables. The aim is to obtain a graph similar to the third plot in https://num.pyro.ai/en/latest/tutorials/bayesian_regression.html.

I have used NumPyro's NUTS sampler to draw 2000 posterior samples, and all of the code below runs as expected:

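from jax import random
import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS

def model(data=None):
    # Intercept-only model: every observation shares one mean mu and one scale sigma
    mu = numpyro.sample("mu", dist.Normal(0.0, 0.2))
    sigma = numpyro.sample("sigma", dist.Exponential(1.0))
    numpyro.sample("obs", dist.Normal(mu, sigma), obs=data)

rng_key = random.PRNGKey(0)
rng_key, rng_key_ = random.split(rng_key)

# Run NUTS: 1000 warmup steps, then 2000 posterior draws
kernel = NUTS(model)
num_samples = 2000
mcmc = MCMC(kernel, num_warmup=1000, num_samples=num_samples)
mcmc.run(
    rng_key_, data=data.AgeUncScaled.values
)
mcmc.print_summary()
samples_1 = mcmc.get_samples()  # dict with "mu" and "sigma", each of shape (2000,)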

Now, when moving on to the final step (plotting the mean line with its HPDI band):

import jax.numpy as jnp
import matplotlib.pyplot as plt
from numpyro.diagnostics import hpdi

def plot_regression(x, y_mean, y_hpdi):
    # Sort values for plotting by x axis
    idx = jnp.argsort(x)
    age = x[idx]
    mean = y_mean[idx]
    hpdi = y_hpdi[:, idx]
    age_unc = dataset.AgeUncScaled.values[idx]

    # Plot
    fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(6, 6))
    ax.plot(age, mean)
    ax.plot(age, age_unc, "o")
    ax.fill_between(age, hpdi[0], hpdi[1], alpha=0.3, interpolate=True)
    return ax

# One value of mu per posterior draw: shape (num_samples, 1), no dependence on x
posterior_mu = jnp.expand_dims(samples_1["mu"], -1)

mean_mu = jnp.mean(posterior_mu, axis=0)  # shape (1,)
hpdi_mu = hpdi(posterior_mu, 0.9)         # shape (2, 1)
ax = plot_regression(dataset.AgeScaled.values, mean_mu, hpdi_mu)
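The shapes in the last few lines appear to be the source of the IndexError: `mean_mu` has shape `(1,)` and `hpdi_mu` has shape `(2, 1)`, but `plot_regression` indexes them with `idx`, which holds one index per row of the dataset. The sketch below reproduces this with plain NumPy, using hypothetical random draws in place of the real posterior and an equal-tailed quantile interval as a stand-in for `hpdi` (only the shapes matter here). One caveat: JAX clamps out-of-bounds indices on retrieval instead of raising, so `mean_mu[idx]` on a JAX array may silently repeat one value rather than fail.

```python
import numpy as np

rng = np.random.default_rng(0)
num_samples, n_obs = 2000, 50

# Hypothetical stand-ins for posterior draws and a scaled predictor x
mu_draws = rng.normal(0.0, 0.2, size=num_samples)  # intercept-only model
a_draws = rng.normal(0.0, 0.2, size=num_samples)   # model with a slope
b_draws = rng.normal(0.5, 0.1, size=num_samples)
x = rng.normal(size=n_obs)

# What the question computes: one value per draw, independent of x
posterior_mu = np.expand_dims(mu_draws, -1)                # (2000, 1)
band = np.quantile(posterior_mu, [0.05, 0.95], axis=0)     # (2, 1), like hpdi_mu

idx = np.argsort(x)  # n_obs indices in 0..n_obs-1
caught = None
try:
    band[:, idx]     # axis 1 has size 1, so any index above 0 is out of bounds
except IndexError as err:
    caught = err
    print("IndexError:", err)

# With a slope term, every draw gives a line evaluated at each x
posterior_line = np.expand_dims(a_draws, -1) + np.expand_dims(b_draws, -1) * x  # (2000, 50)
mean_line = posterior_line.mean(axis=0)                                      # (50,)
band_line = np.quantile(posterior_line, [0.05, 0.95], axis=0)                # (2, 50)
print(mean_line[idx].shape, band_line[:, idx].shape)
```

Once the summaries have one entry per observation, the sorting and `fill_between` calls receive arrays of matching length.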

I run into various errors (including TypeError and IndexError) and I'm not sure what is going on. I would be grateful for any help :)
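A variant of `plot_regression` that avoids two fragile points in the code above may also help narrow the errors down: it takes the observed y values as an argument instead of reading the global `dataset`, and it checks the shapes up front so a mismatch raises a clear `ValueError` rather than an IndexError (or silent clamping) deep inside the plotting code. This is a minimal sketch assuming NumPy-convertible inputs; the argument names and the synthetic usage data are hypothetical.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs headless
import matplotlib.pyplot as plt
import numpy as np

def plot_regression(x, y_obs, y_mean, y_band):
    """Plot observed points, a posterior mean line and a credible band.

    Expects x, y_obs, y_mean of shape (N,) and y_band of shape (2, N).
    """
    x, y_obs = np.asarray(x), np.asarray(y_obs)
    y_mean, y_band = np.asarray(y_mean), np.asarray(y_band)
    n = x.shape[0]
    if y_obs.shape != (n,) or y_mean.shape != (n,) or y_band.shape != (2, n):
        raise ValueError(
            f"expected shapes ({n},), ({n},), (2, {n}); "
            f"got {y_obs.shape}, {y_mean.shape}, {y_band.shape}"
        )
    idx = np.argsort(x)  # sort by x so the line and band are drawn left to right
    fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(6, 6))
    ax.plot(x[idx], y_mean[idx])
    ax.plot(x[idx], y_obs[idx], "o")
    ax.fill_between(x[idx], y_band[0, idx], y_band[1, idx], alpha=0.3, interpolate=True)
    return ax

# Usage with synthetic arrays (hypothetical; substitute your posterior summaries)
x_demo = np.linspace(-2.0, 2.0, 20)
band_demo = np.stack([x_demo - 0.5, x_demo + 0.5])  # shape (2, 20)
ax = plot_regression(x_demo, x_demo + 0.1, x_demo, band_demo)
```

Passing the question's `mean_mu` (shape `(1,)`) and `hpdi_mu` (shape `(2, 1)`) to this version would fail at the shape check with a message naming the offending shapes.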

  • Welcome to Stackoverflow! Without seeing the errors or being able to execute the code, it's going to be hard for anyone to answer this question. I'd suggest editing your question to add a [Minimal Reproducible Example](https://stackoverflow.com/help/minimal-reproducible-example), including all imports and definitions required to run your code and see the outputs that you're seeing. – jakevdp Jul 06 '23 at 13:09
