
I've read some answers on this question here and here; however, I'm still a bit puzzled about whether a tf.Variable is or is not a tf.Tensor. The linked answers deal with the mutability of tf.Variable and mention that tf.Variables maintain their state (when instantiated with the default parameter trainable=True).

What still confuses me is a test case I came across when writing simple unit tests using tf.test.TestCase.

Consider the following code snippet. We have a simple class called Foo which has only one property, a tf.Variable initialized to w:

import tensorflow as tf
import numpy as np

class Foo:
    def __init__(self, w):
        self.w = tf.Variable(w)

Now, let's say you want to test that an instance of Foo has w initialized with a tensor of the same shape as the array passed in via w. The simplest test case could be written as follows:

import tensorflow as tf
import numpy as np
from foo import Foo

class TestFoo(tf.test.TestCase):
    def test_init(self):
        w = np.random.rand(3,2)
        foo = Foo(w)
        init = tf.global_variables_initializer()
        with self.test_session() as sess:
            sess.run(init)
            self.assertShapeEqual(w, foo.w)

if __name__ == '__main__':
    tf.test.main()

Now when you run the test you'll get the following error:

======================================================================
ERROR: test_init (__main__.TestFoo)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "test_foo.py", line 12, in test_init
    self.assertShapeEqual(w, foo.w)
  File "/usr/local/lib/python3.6/site-packages/tensorflow/python/framework/test_util.py", line 1100, in assertShapeEqual
    raise TypeError("tf_tensor must be a Tensor")
TypeError: tf_tensor must be a Tensor

----------------------------------------------------------------------
Ran 2 tests in 0.027s

FAILED (errors=1)

You can "get around" this unit test error by doing something like this (note that assertShapeEqual is replaced with assertEqual):

self.assertEqual(list(w.shape), foo.w.get_shape().as_list())
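The same workaround could be wrapped in a small helper. This is only a sketch: static_shape is a hypothetical name, not part of TensorFlow, and it assumes that any object exposing get_shape() returns a shape with an as_list() method, as both tf.Variable and tf.Tensor do.

```python
import numpy as np
import tensorflow as tf


def static_shape(x):
    """Return the static shape of a NumPy array, tf.Variable, or tf.Tensor as a list.

    Hypothetical helper for illustration; not a TensorFlow API.
    """
    if hasattr(x, "get_shape"):
        # tf.Variable and tf.Tensor both expose get_shape().
        return x.get_shape().as_list()
    # Fall back to NumPy for arrays and array-like values.
    return list(np.shape(x))


w = np.random.rand(3, 2)
var = tf.Variable(w)

# Compare shapes without caring whether the object is a Variable or a Tensor.
shapes_agree = static_shape(w) == static_shape(var)
```

Inside the test, this would read self.assertEqual(static_shape(w), static_shape(foo.w)). It only inspects static shapes, so no session or initializer is needed for the comparison itself.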

What I'm interested in, though, is the relationship between tf.Variable and tf.Tensor. The test error seems to suggest that foo.w is NOT a tf.Tensor, meaning you probably can't use the tf.Tensor API on it. Consider, however, the following interactive Python session:

$ python3
Python 3.6.3 (default, Oct  4 2017, 06:09:15)
[GCC 4.2.1 Compatible Apple LLVM 9.0.0 (clang-900.0.37)] on darwin
Type "help", "copyright", "credits" or "license" for more information.
>>> import tensorflow as tf
>>> import numpy as np
>>> w = np.random.rand(3,2)
>>> var = tf.Variable(w)
>>> var.get_shape().as_list()
[3, 2]
>>> list(w.shape)
[3, 2]
>>>

In the session above, we create a variable and call get_shape() on it to retrieve its shape. Now, get_shape() is a tf.Tensor API method, as you can see here.

So, to get back to my question: which parts of the tf.Tensor API does tf.Variable implement? If the answer is ALL of them, why does the above test case fail?

self.assertShapeEqual(w, foo.w)

with

raise TypeError("tf_tensor must be a Tensor")

I'm pretty sure I'm missing something fundamental here, or maybe it's a bug in assertShapeEqual? I would appreciate it if someone could shed some light on this.

I'm using the following version of TensorFlow on macOS with Python 3:

tensorflow (1.4.1)
milosgajdos

1 Answer

That testing utility function checks whether its argument is an instance of tf.Tensor, so the question is whether a tf.Variable is one:

>>> import tensorflow as tf
>>> v = tf.Variable('v')
>>> v
<tf.Variable 'Variable:0' shape=() dtype=string_ref>
>>> isinstance(v, tf.Tensor)
False

The answer appears to be 'no'.

Update:

According to the documentation that is correct:

https://www.tensorflow.org/programmers_guide/variables

Unlike tf.Tensor objects, a tf.Variable exists outside the context of a single session.run call.

Although:

A tf.Variable represents a tensor whose value can be changed by running ops on it.

(Not quite sure what 'represents a tensor' means - sounds like a design 'feature')
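If you would rather keep assertShapeEqual, one option is to convert the variable to a tensor first. The sketch below assumes the helper only performs an isinstance check against tf.Tensor, which is what the traceback suggests, and uses tf.convert_to_tensor, which accepts a variable. The variable names are illustrative only.

```python
import numpy as np
import tensorflow as tf

w = np.random.rand(3, 2)
var = tf.Variable(w)

# tf.convert_to_tensor accepts a variable and returns a tensor for its value.
as_tensor = tf.convert_to_tensor(var)

# The variable itself fails the isinstance check shown above,
# while the converted object is a tf.Tensor.
var_is_tensor = isinstance(var, tf.Tensor)
converted_is_tensor = isinstance(as_tensor, tf.Tensor)

# The static shape carries over from the variable.
converted_shape = as_tensor.get_shape().as_list()
```

In the test from the question, that would become self.assertShapeEqual(w, tf.convert_to_tensor(foo.w)). I have not checked this against every TensorFlow release, so treat it as a starting point.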

de1
  • Nice! I have to admit I haven't had a look at the actual implementation of `assertShapeEqual`. Still, the existence of the confusion is unsettling. Some things are tensors, some are sort of, some parts of the tensor API can be used and some can't. So the issue might really be with the actual unit testing function - clearly I'm just not understanding the documentation well and thus misusing it. – milosgajdos Jan 09 '18 at 17:12
  • I have to agree with you there. It is confusing. (I was sure it was a tensor) – de1 Jan 09 '18 at 17:15