I've read some answers on this question here and here, but I'm still a bit puzzled by `tf.Variable` being and/or not being a `tf.Tensor`.
The linked answers deal with the mutability of `tf.Variable` and mention that `tf.Variable`s maintain their state (when instantiated with the default parameter `trainable=True`).
What still confuses me is a test case I came across when writing simple unit tests using `tf.test.TestCase`.
Consider the following code snippet. We have a simple class called `Foo`, defined in `foo.py`, which has only one attribute, a `tf.Variable` initialized to `w`:
```python
import tensorflow as tf
import numpy as np


class Foo:
    def __init__(self, w):
        self.w = tf.Variable(w)
```
Now, let's say you want to test that the instance of `Foo` has `w` initialized with a tensor of the same shape as the array passed in via `w`. The simplest test case (`test_foo.py`) could be written as follows:
```python
import tensorflow as tf
import numpy as np
from foo import Foo

class TestFoo(tf.test.TestCase):
    def test_init(self):
        w = np.random.rand(3,2)
        foo = Foo(w)
        init = tf.global_variables_initializer()
        with self.test_session() as sess:
            sess.run(init)
            self.assertShapeEqual(w, foo.w)

if __name__ == '__main__':
    tf.test.main()
```
Now when you run the test you'll get the following error:
```
======================================================================
ERROR: test_init (__main__.TestFoo)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "test_foo.py", line 12, in test_init
    self.assertShapeEqual(w, foo.w)
  File "/usr/local/lib/python3.6/site-packages/tensorflow/python/framework/test_util.py", line 1100, in assertShapeEqual
    raise TypeError("tf_tensor must be a Tensor")
TypeError: tf_tensor must be a Tensor

----------------------------------------------------------------------
Ran 2 tests in 0.027s

FAILED (errors=1)
```
You can "get around" this error by replacing `assertShapeEqual` with `assertEqual` and comparing the shapes as plain lists:

```python
self.assertEqual(list(w.shape), foo.w.get_shape().as_list())
```
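If the goal is to keep using `assertShapeEqual`, another option is to hand it an actual `tf.Tensor` derived from the variable. The sketch below assumes that `tf.convert_to_tensor` accepts a `tf.Variable` (TensorFlow registers a tensor conversion for variables, which is also what lets them be passed to ops) and that `assertShapeEqual` compares only the static shape, so no session or initializer is used. The method name `test_init_shape` is hypothetical.

```python
import numpy as np
import tensorflow as tf


class Foo:
    def __init__(self, w):
        self.w = tf.Variable(w)


class TestFoo(tf.test.TestCase):
    def test_init_shape(self):
        w = np.random.rand(3, 2)
        foo = Foo(w)
        # foo.w is a tf.Variable, which fails assertShapeEqual's type check;
        # tf.convert_to_tensor returns a tf.Tensor holding the variable's value.
        w_tensor = tf.convert_to_tensor(foo.w)
        # Only the static shape is compared, so the variable is never evaluated.
        self.assertShapeEqual(w, w_tensor)


if __name__ == "__main__":
    tf.test.main()
```

Compared with the `assertEqual` workaround, this keeps the intent ("same shape") in the assertion name, at the cost of an explicit conversion step.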
What I'm interested in, though, is the `tf.Variable` vs `tf.Tensor` relationship.
What the test error seems to suggest is that `foo.w` is NOT a `tf.Tensor`, meaning you probably can't use the `tf.Tensor` API on it. Consider, however, the following interactive Python session:
```
$ python3
Python 3.6.3 (default, Oct 4 2017, 06:09:15)
[GCC 4.2.1 Compatible Apple LLVM 9.0.0 (clang-900.0.37)] on darwin
Type "help", "copyright", "credits" or "license" for more information.
>>> import tensorflow as tf
>>> import numpy as np
>>> w = np.random.rand(3,2)
>>> var = tf.Variable(w)
>>> var.get_shape().as_list()
[3, 2]
>>> list(w.shape)
[3, 2]
>>>
```
In the session above, we create a variable and call its `get_shape()` method to retrieve its shape. Now, `get_shape()` is a `tf.Tensor` API method, as you can see here.
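The distinction at play can be sketched without TensorFlow. The classes and the `assert_shape_equal` helper below are hypothetical stand-ins, and they assume, as the `TypeError` message suggests, that `assertShapeEqual` checks its argument's type with `isinstance` rather than just calling `get_shape()`. Under that assumption, an object can offer the same `get_shape()` method as a tensor (duck typing) and still be rejected by a nominal type check:

```python
class FakeShape:
    """Hypothetical stand-in for tf.TensorShape."""

    def __init__(self, dims):
        self._dims = list(dims)

    def as_list(self):
        return list(self._dims)


class FakeTensor:
    """Hypothetical stand-in for tf.Tensor."""

    def __init__(self, dims):
        self._shape = FakeShape(dims)

    def get_shape(self):
        return self._shape


class FakeVariable:
    """Hypothetical stand-in for tf.Variable: it wraps a tensor and forwards
    get_shape() to it, but does NOT subclass FakeTensor."""

    def __init__(self, dims):
        self._value = FakeTensor(dims)

    def get_shape(self):
        return self._value.get_shape()

    def value(self):
        return self._value


def assert_shape_equal(expected_dims, tensor):
    """Hypothetical checker: type check first, then compare shapes."""
    if not isinstance(tensor, FakeTensor):
        raise TypeError("tf_tensor must be a Tensor")
    if list(expected_dims) != tensor.get_shape().as_list():
        raise AssertionError("shape mismatch")


var = FakeVariable((3, 2))
print(var.get_shape().as_list())          # [3, 2]: the shared method works
try:
    assert_shape_equal((3, 2), var)       # rejected by the isinstance check
except TypeError as err:
    print("TypeError:", err)
assert_shape_equal((3, 2), var.value())   # passes: value() returns a FakeTensor
```

In this model, implementing the same methods is not enough; only an object whose class actually derives from `FakeTensor` passes the check.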
So, to get back to my question: which parts of the `tf.Tensor` API does `tf.Variable` implement? If the answer is ALL of them, why does `self.assertShapeEqual(w, foo.w)` in the test case above fail with `TypeError: tf_tensor must be a Tensor`?
I'm pretty sure I'm missing something fundamental here, or maybe it's a bug in `assertShapeEqual`? I would appreciate it if someone could shed some light on this.
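To probe the "which parts" question empirically on a given installation, one rough approach is to compare the public attribute names of the two classes. This is only a sketch: it compares names, not behaviour, and the result depends on the TensorFlow version installed, so no expected output is given.

```python
import tensorflow as tf


def public_names(cls):
    """Public (non-underscore) attribute names visible on a class."""
    return {name for name in dir(cls) if not name.startswith("_")}


tensor_api = public_names(tf.Tensor)
variable_api = public_names(tf.Variable)

shared = sorted(tensor_api & variable_api)
missing = sorted(tensor_api - variable_api)

print("tf.Tensor names also on tf.Variable:", shared)
print("tf.Tensor names missing from tf.Variable:", missing)

# Sharing method names is not the same as subclassing: isinstance() and
# issubclass() look at the class hierarchy, not at which methods exist.
print("tf.Variable subclasses tf.Tensor:", issubclass(tf.Variable, tf.Tensor))
```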
I'm using the following version of `tensorflow` on macOS with `python3`:
tensorflow (1.4.1)