Questions tagged [numpyro]

Repo: https://github.com/pyro-ppl/numpyro

Docs: https://num.pyro.ai/en/latest/index.html

15 questions
11 votes, 1 answer

NumPyro vs Pyro: Why is the former 100x faster, and when should I use the latter?

From the website of Pyro (the PyTorch-based library): "We’re excited to announce the release of NumPyro, a NumPy-backed Pyro using JAX for automatic differentiation and JIT compilation, with over 100x speedup for HMC and NUTS!" My questions: Where is the performance gain…
3 votes, 1 answer

numpyro.render_model cannot be found

I am working with Bayesian models in NumPyro, a relatively new library. I tried to visualize my model by following the NumPyro manual: http://num.pyro.ai/en/latest/tutorials/model_rendering.html My code returns: "module 'numpyro' has no…
2 votes, 2 answers

JAX with JIT and custom differentiation

I am working with JAX through NumPyro. Specifically, I want to use a B-spline function (e.g., as implemented in scipy.interpolate.BSpline) to transform different points into a spline where the input depends on some of the model parameters. Thus, I…
1 vote, 1 answer

Google Colab can't change Python version

Recently, all my notebooks were updated to Python 3.10, and I can no longer run my old code. I was using jax 0.2.17 and numpyro 0.7.1, but these no longer work with this Python version, so I tried switching to Python 3.9. I used the following code: !sudo apt-get…
1 vote, 0 answers

Numpyro: Error when using MCMC with a model that uses scan

I am trying to get the following model to work: def model_dynamic(self, hemp_size_t, values_t): # Unpack the values at time t t, actions_performed = values_t # Check if harvesting are performed at time step t harvest =…
Theophile Champion
1 vote, 1 answer

GARCH models in Numpyro

I was curious how best to implement GARCH models in numpyro. I tried reading https://num.pyro.ai/en/stable/tutorials/time_series_forecasting.html but found it generally unclear (the model notation and variable names are not easy to map, the model…
Shffl
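As a rough starting point for the GARCH question above, here is a minimal sketch of the GARCH(1,1) conditional-variance recursion written with jax.lax.scan, the same primitive NumPyro models commonly use for sequential structure. The helper name garch11_variance and its argument names are hypothetical, and treating the initial variance sigma2_0 as a given input is an assumption. In a NumPyro model, omega, alpha, and beta would typically be drawn with numpyro.sample, and the returns would then be observed under dist.Normal(0, jnp.sqrt(sigma2)).

```python
import jax
import jax.numpy as jnp


def garch11_variance(returns, omega, alpha, beta, sigma2_0):
    """Conditional variances of a GARCH(1,1) process (hypothetical helper).

    sigma2[t] = omega + alpha * returns[t - 1] ** 2 + beta * sigma2[t - 1]
    with sigma2[0] = sigma2_0 supplied by the caller (an assumption).
    Stationarity constraints such as alpha + beta < 1 are NOT enforced here.
    """
    returns = jnp.asarray(returns, dtype=jnp.float32)
    # Give the initial carry a concrete dtype so scan's carry types match.
    sigma2_0 = jnp.asarray(sigma2_0, dtype=returns.dtype)

    def step(sigma2_prev, r_prev):
        # One GARCH(1,1) update: previous squared shock plus decayed variance.
        sigma2_t = omega + alpha * r_prev ** 2 + beta * sigma2_prev
        return sigma2_t, sigma2_t  # (new carry, value stored for this step)

    # Each step consumes r[t-1] to produce sigma2[t], so scan over returns[:-1].
    _, sigma2_rest = jax.lax.scan(step, sigma2_0, returns[:-1])
    return jnp.concatenate([sigma2_0[None], sigma2_rest])


if __name__ == "__main__":
    r = jnp.array([1.0, 2.0, 0.0])
    print(garch11_variance(r, omega=0.1, alpha=0.2, beta=0.5, sigma2_0=1.0))
```

Using scan rather than a Python loop keeps the recursion traceable and compilable, which matters when the sampler evaluates the model many times.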
1 vote, 0 answers

Pyro for differential equations

Is it possible to infer differential equation parameters using Pyro? I found an example in NumPyro and was wondering whether this is also possible with Pyro.
q than a
1 vote, 1 answer

A JAX custom VJP function for multiple input variables does not work for NumPyro/HMC-NUTS

I am trying to use a custom VJP (vector-Jacobian product) function as a model for HMC-NUTS in NumPyro. I was able to make a single-variable function that works with HMC-NUTS as follows: import jax.numpy as jnp from jax import…
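For the custom VJP question above, a minimal sketch of a jax.custom_vjp function with two inputs may help. The function f and its formula are hypothetical stand-ins, not the asker's code. The point it illustrates, assuming the failure lies in the backward rule, is that the bwd function must return a tuple with one cotangent per primal input; HMC/NUTS needs gradients of the log density with respect to every latent site, so each input that depends on a sampled parameter needs its own cotangent.

```python
import jax
import jax.numpy as jnp


@jax.custom_vjp
def f(x, y):
    # Hypothetical two-input function standing in for the asker's model.
    return jnp.sin(x) * y


def f_fwd(x, y):
    # Forward pass: return the primal output plus residuals for the backward pass.
    return f(x, y), (x, y)


def f_bwd(residuals, g):
    x, y = residuals
    # Backward pass: a tuple with one cotangent per primal input,
    # in the same order as f's arguments (x, y).
    return (g * jnp.cos(x) * y, g * jnp.sin(x))


f.defvjp(f_fwd, f_bwd)

if __name__ == "__main__":
    # Gradients with respect to both inputs, as a gradient-based sampler
    # would need when both x and y depend on latent sample sites.
    grad_f = jax.jit(jax.grad(f, argnums=(0, 1)))
    print(grad_f(1.0, 2.0))  # approx (2*cos(1), sin(1))
```

The same pattern extends to more inputs: the residuals tuple can carry whatever the backward rule needs, and the returned cotangent tuple simply grows to match the argument list.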
0 votes, 0 answers

How do I set a different prior distribution for each NumPyro component?

I'm building a mixture model in NumPyro and would like to give each component a different prior. So far I have the following, where n is the number of components in the mixture: def model(data, n): weights = numpyro.sample('weights',…
caceves
0 votes, 0 answers

Error associated with using NumPyro to create a linear regression model

I'm using NumPyro to create a simple linear regression model with two variables; the aim is to obtain a graph similar to the one at https://num.pyro.ai/en/latest/tutorials/bayesian_regression.html (3rd graph). I have used NumPyro to generate 2000…
0 votes, 0 answers

NumPyro ValueError: Normal distribution got invalid loc parameter

I'm trying to code an MCMC for the Lorenz system (based on the following example of the predator-prey model: https://num.pyro.ai/en/stable/examples/ode.html). I used the structure of the example program and simply replaced the model; however, I am…
0 votes, 0 answers

How to write an arbitrary number of models (functions) in NumPyro?

I have 5 NumPyro models of the form: def patient_1(): disease = numpyro.sample("disease",dist.Bernoulli(0.4)) treatment = numpyro.sample("treatment",dist.Bernoulli(0.7)) list_variables=[disease,treatment] …
0 votes, 0 answers

Pyro - Extracting Individual Kernels from GPRegression

I'm trying to understand how to extract the individual kernels from a GPRegression object, similar to how the Birthday problem is approached in BDA3. Using the example from Pyro's documentation: X = torch.linspace(-5, 5, 100) y = torch.sin(X * 8) + 2…
JBlaz
0 votes, 1 answer

How to predict time series with limited data

I have a dataset with four columns: date, category, product, rate (%). I would like to be able to forecast the rate for every product in my data. The major issue I'm having is that because products constantly come in and out of production, certain…
0 votes, 0 answers

Gaussian process kernels for time series: one kernel for t<0 and one for t>0, is it possible?

I have small time series, (t_i, Y_i)_{i0, and from one series to the other the index "i" that cuts the two processes time zone,…
Jean-Eric