I am trying to understand why I am getting a ValueError when I try to replace SVGP with VGP in the heteroscedastic regression example (https://gpflow.readthedocs.io/en/develop/notebooks/advanced/heteroskedastic.html) in GPflow.
Here are the changes I made:
- model = gpf.models.VGP(...)
- loss_fn = model.training_loss_closure() instead of loss_fn = model.training_loss_closure(data)
The kernel and likelihood are the same as the example.
data = (X, Y)
model = gpf.models.VGP(
data = data,
kernel=kernel,
likelihood=likelihood,
#inducing_variable=inducing_variable,
num_latent_gps=likelihood.latent_dim,
)
loss_fn = model.training_loss_closure()
gpf.utilities.set_trainable(model.q_mu, False)
gpf.utilities.set_trainable(model.q_sqrt, False)
variational_vars = [(model.q_mu, model.q_sqrt)]
natgrad_opt = gpf.optimizers.NaturalGradient(gamma=0.1)
adam_vars = model.trainable_variables
adam_opt = tf.optimizers.Adam(0.01)
@tf.function
def optimisation_step():
natgrad_opt.minimize(loss_fn, variational_vars)
adam_opt.minimize(loss_fn, adam_vars)
epochs = 100
for epoch in range(0, epochs):
optimisation_step()
The optimization step gives me this error:
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
/tmp/ipykernel_406484/3662007586.py in <module>
3
4 for epoch in range(1, epochs + 1):
----> 5 optimisation_step()
6
7
~/miniconda3/envs/tensorflow/lib/python3.9/site-packages/tensorflow/python/eager/def_function.py in __call__(self, *args, **kwds)
887
888 with OptionalXlaContext(self._jit_compile):
--> 889 result = self._call(*args, **kwds)
890
891 new_tracing_count = self.experimental_get_tracing_count()
~/miniconda3/envs/tensorflow/lib/python3.9/site-packages/tensorflow/python/eager/def_function.py in _call(self, *args, **kwds)
931 # This is the first call of __call__, so we have to initialize.
932 initializers = []
--> 933 self._initialize(args, kwds, add_initializers_to=initializers)
934 finally:
935 # At this point we know that the initialization is complete (or less
~/miniconda3/envs/tensorflow/lib/python3.9/site-packages/tensorflow/python/eager/def_function.py in _initialize(self, args, kwds, add_initializers_to)
761 self._graph_deleter = FunctionDeleter(self._lifted_initializer_graph)
762 self._concrete_stateful_fn = (
--> 763 self._stateful_fn._get_concrete_function_internal_garbage_collected( # pylint: disable=protected-access
764 *args, **kwds))
765
~/miniconda3/envs/tensorflow/lib/python3.9/site-packages/tensorflow/python/eager/function.py in _get_concrete_function_internal_garbage_collected(self, *args, **kwargs)
3048 args, kwargs = None, None
3049 with self._lock:
-> 3050 graph_function, _ = self._maybe_define_function(args, kwargs)
3051 return graph_function
3052
~/miniconda3/envs/tensorflow/lib/python3.9/site-packages/tensorflow/python/eager/function.py in _maybe_define_function(self, args, kwargs)
3442
3443 self._function_cache.missed.add(call_context_key)
-> 3444 graph_function = self._create_graph_function(args, kwargs)
3445 self._function_cache.primary[cache_key] = graph_function
3446
~/miniconda3/envs/tensorflow/lib/python3.9/site-packages/tensorflow/python/eager/function.py in _create_graph_function(self, args, kwargs, override_flat_arg_shapes)
3277 arg_names = base_arg_names + missing_arg_names
3278 graph_function = ConcreteFunction(
-> 3279 func_graph_module.func_graph_from_py_func(
3280 self._name,
3281 self._python_function,
~/miniconda3/envs/tensorflow/lib/python3.9/site-packages/tensorflow/python/framework/func_graph.py in func_graph_from_py_func(name, python_func, args, kwargs, signature, func_graph, autograph, autograph_options, add_control_dependencies, arg_names, op_return_value, collections, capture_by_value, override_flat_arg_shapes)
997 _, original_func = tf_decorator.unwrap(python_func)
998
--> 999 func_outputs = python_func(*func_args, **func_kwargs)
1000
1001 # invariant: `func_outputs` contains only Tensors, CompositeTensors,
~/miniconda3/envs/tensorflow/lib/python3.9/site-packages/tensorflow/python/eager/def_function.py in wrapped_fn(*args, **kwds)
670 # the function a weak reference to itself to avoid a reference cycle.
671 with OptionalXlaContext(compile_with_xla):
--> 672 out = weak_wrapped_fn().__wrapped__(*args, **kwds)
673 return out
674
~/miniconda3/envs/tensorflow/lib/python3.9/site-packages/tensorflow/python/framework/func_graph.py in wrapper(*args, **kwargs)
984 except Exception as e: # pylint:disable=broad-except
985 if hasattr(e, "ag_error_metadata"):
--> 986 raise e.ag_error_metadata.to_exception(e)
987 else:
988 raise
ValueError: in user code:
ValueError: Dimensions must be equal, but are 2 and 1001 for '{{node add_2}} = AddV2[T=DT_DOUBLE](diag, mul_1)' with input shapes: [1001,2,2], [1001,1001].
Is this a bug or are the likelihood and model incompatible or am I missing something? A workaround is to make the inducing variables = the training data and use SVGP but that makes the training incredibly slow...