
How to check if a tf.Tensor is mutable?

I want to assert the arguments of a function have the correct types.

A tf.Tensor can be mutable:

```python
import tensorflow as tf
import numpy as np

# TF1.x graph-mode API
x = tf.get_variable('x', shape=(2,), dtype=np.float32)
print(x[1])           # x[1] is a tf.Tensor...
tf.assign(x[1], 1.0)  # ...but it can still be assigned to
```
R zu

2 Answers


This is not part of the public API, but looking at how tf.assign is implemented, I think you can just do:

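```python
import tensorflow as tf

def is_assignable(x):
    # TF1 reference variables have a *_ref dtype (e.g. float32_ref).
    # Slices of a variable such as x[1] are plain tf.Tensor objects, but TF's
    # slicing helper attaches an `assign` method so tf.assign can write back.
    return x.dtype._is_ref_dtype or (isinstance(x, tf.Tensor) and hasattr(x, 'assign'))
```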
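A usage sketch of `is_assignable`, assuming TensorFlow 1.13+ or 2.x (so that `tf.compat.v1` exists). It builds a TF1-style graph and passes `use_resource=False` so that `get_variable` creates a reference variable like the one in the question; the constant `c` is only there for contrast. Since the check relies on the private `_is_ref_dtype` attribute, behaviour may differ between releases.

```python
import tensorflow as tf

def is_assignable(x):
    # Same check as above: a *_ref dtype, or a tensor carrying an `assign` method.
    return x.dtype._is_ref_dtype or (isinstance(x, tf.Tensor) and hasattr(x, 'assign'))

graph = tf.Graph()
with graph.as_default():
    # use_resource=False asks for a TF1 reference variable (dtype float32_ref).
    x = tf.compat.v1.get_variable('x', shape=(2,), dtype=tf.float32,
                                  use_resource=False)
    x_slice = x[1]               # plain tf.Tensor with an attached `assign`
    c = tf.constant([1.0, 2.0])  # ordinary, immutable tensor

    results = {
        'variable': is_assignable(x),
        'slice': is_assignable(x_slice),
        'constant': is_assignable(c),
    }

print(results)
```

The constant has neither a `_ref` dtype nor an `assign` attribute, so it is reported as not assignable.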
jdehesa

You can check the tensor's `dtype` attribute, e.g. `assert my_tensor.dtype == tf.float32`.

Tensors are immutable outside of variables: they describe relationships between quantities. A tensor's data type does not change unless a cast operation (e.g. `tf.cast`) is added to the graph, which produces a new tensor rather than modifying the old one. If a value whose type differs from the expected type is passed in, e.g. when loading data into a pipeline, an error is raised. You can check this by assigning a tensor of the wrong type: you will get an error.

Try this code:

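```python
import tensorflow as tf

x = tf.get_variable('x', shape=(2,), dtype=tf.float32)
# x[1] is a float32 scalar; assigning an int32 value to it fails
# while the graph is being built, before anything is run.
tf.assign(x[1], tf.ones(shape=(), dtype=tf.int32))
```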

You should get an error to the effect of "TypeError: Input 'value' of 'StridedSliceAssign' Op has type int32 that does not match type float32 of argument 'ref'."
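To illustrate the point that a dtype only changes when a cast produces a new tensor, here is a minimal sketch that assumes only that TensorFlow is installed (it runs in TF1 graph mode or TF2 eager mode). The original tensor keeps its dtype; `tf.cast` returns a separate tensor.

```python
import tensorflow as tf

a = tf.constant([1, 2, 3], dtype=tf.int32)

# tf.cast does not change `a`; it returns a new tensor with the requested
# dtype (in graph mode, the output of a new Cast op).
b = tf.cast(a, tf.float32)

print(a.dtype.name, b.dtype.name)  # int32 float32
```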

Jeffrey Ede
  • In my example, `x.dtype` is `tf.float32_ref` and `x[1].dtype` is `tf.float32`. Indexing can change `dtype`. Both `x` and `x[1]` are mutable. – R zu Jul 18 '19 at 17:12
  • If `x.dtype` is `tf.float32_ref`, `x.dtype.base_dtype` is `tf.float32`. The `_ref` suffix just indicates that the node holds a mutable variable. Indexing does not change the base data type. – Jeffrey Ede Jul 19 '19 at 07:34
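Following the comments above, a type check that should accept both `x` and `x[1]` can compare `dtype.base_dtype` instead of `dtype`. A minimal sketch; the helper name `assert_base_dtype` is hypothetical, and `tf.as_dtype('float32_ref')` is used only to obtain a reference dtype without building a TF1 graph.

```python
import tensorflow as tf

def assert_base_dtype(t, expected):
    # Hypothetical helper: compare base dtypes so float32 and float32_ref match.
    actual = t.dtype.base_dtype
    if actual != expected.base_dtype:
        raise TypeError('expected %s, got %s' % (expected.base_dtype.name, actual.name))

# 'float32_ref' is the dtype TF1 reference variables report.
ref_dtype = tf.as_dtype('float32_ref')
print(ref_dtype.base_dtype == tf.float32)  # True

t = tf.constant([1.0, 2.0])
assert_base_dtype(t, tf.float32)  # passes silently

try:
    assert_base_dtype(tf.constant([1, 2]), tf.float32)  # int32 tensor
    message = None
except TypeError as err:
    message = str(err)

print(message)  # expected float32, got int32
```

Comparing base dtypes keeps the check about the element type and leaves mutability to a separate test such as `is_assignable`.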