There is indeed a similar debugging tool in TensorFlow: tf.debugging.check_numerics.
It can be used to track down the tensors that produce inf or nan values during training. As soon as such a value is found, TensorFlow raises an InvalidArgumentError.
tf.debugging.check_numerics(LayerN, "LayerN is producing nans!")
If the tensor LayerN contains NaNs, you get an error like this:
Traceback (most recent call last):
  File "trainer.py", line 506, in <module>
    worker.train_model()
  File "trainer.py", line 211, in train_model
    l, tmae = train_step(*batch)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/def_function.py", line 828, in __call__
    result = self._call(*args, **kwds)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/def_function.py", line 855, in _call
    return self._stateless_fn(*args, **kwds) # pylint: disable=not-callable
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/function.py", line 2943, in __call__
    filtered_flat_args, captured_inputs=graph_function.captured_inputs) # pylint: disable=protected-access
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/function.py", line 1919, in _call_flat
    ctx, args, cancellation_manager=cancellation_manager))
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/function.py", line 560, in call
    ctx=ctx)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/execute.py", line 60, in quick_execute
    inputs, attrs, num_outputs)
tensorflow.python.framework.errors_impl.InvalidArgumentError: LayerN is producing nans! : Tensor had NaN values
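For a self-contained illustration, here is a minimal sketch of how the check might be placed right after a suspicious operation. The function name checked_log and the input values are hypothetical and chosen only to force a NaN; in a real model you would wrap whichever layer output you suspect.

```python
import tensorflow as tf


def checked_log(x):
    # Hypothetical "layer": the log of a negative number is nan.
    y = tf.math.log(x)
    # check_numerics returns y unchanged when every element is finite
    # and raises InvalidArgumentError when it finds inf or nan.
    return tf.debugging.check_numerics(y, "checked_log is producing nans!")


# Finite input: the tensor passes through untouched.
good = checked_log(tf.constant([1.0, 2.0]))

# Negative input: the log yields nan, so the check fires.
try:
    checked_log(tf.constant([1.0, -1.0]))
    caught = False
except tf.errors.InvalidArgumentError as err:
    caught = True
    print("caught:", err.message)
```

Because check_numerics returns its input tensor, it can be dropped into the middle of a forward pass without changing the computation, and the message you pass tells you which check fired.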