0

I am trying to compare different priors, e.g. N(0, 1), N(0, 10), …, in a Bayesian neural network. Unfortunately, I cannot find a way to specify the parameters of the multivariate normal (MVN) prior.

Do you have any ideas?

# Number of training samples actually fitted on: 80% of y_train (kept as a
# float). Presumably matches validation_split=0.2 at fit time — confirm.
N = 0.8 * len(y_train)
m = 2  # number of outcomes
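# The weight/bias priors (N(0, 1) by default, via
# tfp.layers.default_multivariate_normal_fn) are set on the DenseFlipout layers
# below; normal_sp defines the output (likelihood) distribution, not a prior.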
def normal_sp(params):
    """Turn the network's raw outputs into a Normal predictive distribution.

    Column 0 of ``params`` is used directly as the mean. Column 1 is passed
    through a damped softplus to produce a strictly positive scale. Both
    columns come from the network, so both parameters are learnable.
    """
    loc = params[:, 0:1]
    raw_scale = params[:, 1:2]
    # softplus keeps the scale positive; the 1e-3 floor prevents a zero scale.
    scale = 1e-3 + tf.math.softplus(0.05 * raw_scale)
    return tfd.Normal(loc=loc, scale=scale)

def _make_normal_prior_fn(scale):
    """Return a tfp.layers prior factory for an independent zero-mean Normal.

    The returned callable has the ``(dtype, shape, name, trainable,
    add_variable_fn)`` signature that ``DenseFlipout``'s ``kernel_prior_fn`` /
    ``bias_prior_fn`` expect. It mirrors
    ``tfp.layers.default_multivariate_normal_fn`` except that the standard
    deviation is ``scale`` instead of 1.

    Args:
        scale: Standard deviation of the prior, so ``scale=10.0`` gives
            N(0, 10**2). Pass ``np.sqrt(10.0)`` for a prior with variance 10.
    """
    def prior_fn(dtype, shape, name, trainable, add_variable_fn):
        del name, trainable, add_variable_fn  # fixed prior: nothing to learn
        dist = tfd.Normal(loc=tf.zeros(shape, dtype),
                          scale=dtype.as_numpy_dtype(scale))
        # Treat every weight dimension as part of one event so the KL term
        # against the posterior reduces to a single scalar per layer.
        batch_ndims = tf.size(dist.batch_shape_tensor())
        return tfd.Independent(dist, reinterpreted_batch_ndims=batch_ndims)
    return prior_fn


def build_and_compile_bnn_model(norm, learning_rate, prior_scale=None):
    """Build and compile a Bayesian regression network with Flipout layers.

    The network has two ReLU ``DenseFlipout`` hidden layers and a 2-unit
    ``DenseFlipout`` output. The output is mapped to a ``tfd.Normal`` by
    ``normal_sp`` and trained with ``NLL`` using Adam.

    Args:
        norm: Layer applied to the raw inputs before the first Flipout layer
            (presumably a Keras normalization layer adapted to X_train).
        learning_rate: Learning rate for the Adam optimizer.
        prior_scale: Standard deviation of the zero-mean Normal prior placed
            on every kernel and bias. ``None`` (the default) keeps
            ``tfp.layers.default_multivariate_normal_fn``, i.e. N(0, 1).

    Returns:
        The compiled Keras ``Model`` whose output is a distribution.
    """
    # KL terms are divided by the number of examples actually fitted on
    # (80% of X_train, matching validation_split=0.2), so the per-layer KL
    # penalty is on the same per-example scale as the NLL loss.
    num_fit_examples = X_train.shape[0] * 0.8

    def divergence_fn(q, p, _):
        return tfp.distributions.kl_divergence(q, p) / num_fit_examples

    if prior_scale is None:
        prior_fn = tfp.layers.default_multivariate_normal_fn
    else:
        prior_fn = _make_normal_prior_fn(prior_scale)

    def flipout(units, activation=None):
        # All layers share the same posterior, prior, and KL configuration.
        return tfp.layers.DenseFlipout(
            units,
            bias_posterior_fn=tfp.layers.util.default_mean_field_normal_fn(),
            kernel_prior_fn=prior_fn,
            bias_prior_fn=prior_fn,
            kernel_divergence_fn=divergence_fn,
            bias_divergence_fn=divergence_fn,
            activation=activation,
        )

    # Hidden-layer widths come from a sizing heuristic over the sample count N
    # and the number of outcomes m (module-level globals).
    units_1 = int(np.round(np.sqrt((m + 2) * N) + 2 * np.sqrt(N / (m + 2)), 0))
    units_2 = int(np.round(m * np.sqrt(N / (m + 2))))

    inputs = Input(shape=(X_train.shape[1],))
    hidden = norm(inputs)
    hidden = flipout(units_1, activation="relu")(hidden)
    hidden = flipout(units_2, activation="relu")(hidden)
    params = flipout(2)(hidden)  # raw (loc, scale) inputs for normal_sp
    dist = tfp.layers.DistributionLambda(normal_sp)(params)
    model = Model(inputs=inputs, outputs=dist)
    model.compile(Adam(learning_rate=learning_rate), loss=NLL)
    return model
desertnaut
  • 57,590
  • 26
  • 140
  • 166
feviro
  • 9
  • 1

0 Answers0