
Look at this example (TensorFlow 1.x API):

import tensorflow as tf

tf.reset_default_graph()
LENGTH = 25
M_list = []
for i in range(LENGTH):
    M_list.append(tf.get_variable('M'+str(i), shape=[1], initializer=tf.constant_initializer(i)))

choose_mat = tf.placeholder(tf.int32, shape=[LENGTH])
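# Each lambda reads i only when tf.case calls it, after the comprehension has
# finished, so every branch ends up returning M_list[24].
case_set = [(tf.equal(choose_mat[i], 1), lambda: M_list[i]) for i in range(LENGTH)]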
M = tf.case(case_set)


with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    CM1 = [0] * LENGTH
    CM1[0] = 1
    CM2 = [0] * LENGTH
    CM2[1] = 1

    m1 = sess.run(M, feed_dict={choose_mat: CM1})
    m2 = sess.run(M, feed_dict={choose_mat: CM2})
    print(m1) # [24.]
    print(m2) # [24.]

    m1_ = sess.run(M_list[0])
    m2_ = sess.run(M_list[1])
    print(m1_) # [0.]
    print(m2_) # [1.]

I expected m1 and m2 to be 0 and 1, but both are 24. Evaluating the entries of M_list directly (m1_ and m2_) gives the correct values, which is strange.
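To check that TensorFlow itself is not involved, here is a minimal sketch of the same pattern in plain Python. `toy_case` and `run` are hypothetical helpers, not TensorFlow APIs: `toy_case` only loosely imitates `tf.case` by returning the result of the first branch whose predicate is true (the real `tf.case` builds every branch into the graph). A list of floats stands in for the 25 variables.

```python
# Plain-Python reproduction of the bug; no TensorFlow needed.
LENGTH = 25
M_list = [float(i) for i in range(LENGTH)]  # stands in for the 25 tf variables


def toy_case(pairs):
    """Hypothetical stand-in for tf.case: call the fn of the first true predicate."""
    for pred, fn in pairs:
        if pred:
            return fn()
    raise ValueError("no predicate was true")


def run(choose):
    # choose[i] == 1 is evaluated right away with the correct i, but each
    # lambda looks up i only when toy_case calls it, i.e. after the
    # comprehension has finished and i == LENGTH - 1.
    pairs = [(choose[i] == 1, lambda: M_list[i]) for i in range(LENGTH)]
    return toy_case(pairs)


CM1 = [0] * LENGTH
CM1[0] = 1
CM2 = [0] * LENGTH
CM2[1] = 1
print(run(CM1))  # 24.0
print(run(CM2))  # 24.0
```

The predicates select the right branch; it is the branch body that reads the wrong `i`, which matches the `[24.]` results above.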

Although I have fixed this bug (see my answer), I still have a question: I don't understand why this code creates a closure. case_set is not inside any function, so does anyone know why this is a closure?

0811张庆昊


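In fact, this bug is not caused by TensorFlow; the real reason is that Python closures bind variables late. See this link. Each `lambda: M_list[i]` looks up `i` only when it is called, and `tf.case` calls the branch functions while `M = tf.case(case_set)` builds the graph, which happens after the list comprehension has finished and `i` is 24. The predicates `tf.equal(choose_mat[i], 1)` are evaluated inside the comprehension, so they use the correct `i`; only the branches are wrong. Binding `i` as a default argument of the lambda fixes this, so this code gets the expected result: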

import tensorflow as tf

tf.reset_default_graph()
LENGTH = 25
M_list = []
for i in range(LENGTH):
    M_list.append(tf.get_variable('M'+str(i), shape=[1], initializer=tf.constant_initializer(i)))

choose_mat = tf.placeholder(tf.int32, shape=[LENGTH])
# i=i stores the current value of i as each lambda's default argument.
case_set = [(tf.equal(choose_mat[i], 1), lambda i=i: M_list[i]) for i in range(LENGTH)]
M = tf.case(case_set)


with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    CM1 = [0] * LENGTH
    CM1[0] = 1
    CM2 = [0] * LENGTH
    CM2[1] = 1

    m1 = sess.run(M, feed_dict={choose_mat: CM1})
    m2 = sess.run(M, feed_dict={choose_mat: CM2})
    print(m1) # [0.]
    print(m2) # [1.]
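Besides the `i=i` default argument, two other common ways to capture the current value are `functools.partial` and a factory function that creates a new scope per call. A minimal plain-Python sketch, where `get_item` and `make_getter` are hypothetical names and a list of floats stands in for `M_list`; all three produce zero-argument callables, which is what `tf.case` expects for its branch functions.

```python
import functools

LENGTH = 25
M_list = [float(i) for i in range(LENGTH)]  # stands in for the tf variables

# 1. Default argument (the fix above): the value of i is stored when each
#    lambda is created.
fns_default = [lambda i=i: M_list[i] for i in range(LENGTH)]


# 2. functools.partial stores the argument inside the partial object.
def get_item(index):
    return M_list[index]


fns_partial = [functools.partial(get_item, i) for i in range(LENGTH)]


# 3. Factory function: each call gets its own `index`, so each lambda
#    closes over a different variable.
def make_getter(index):
    return lambda: M_list[index]


fns_factory = [make_getter(i) for i in range(LENGTH)]

for fns in (fns_default, fns_partial, fns_factory):
    print([f() for f in fns][:3])  # [0.0, 1.0, 2.0]
```

One design difference: the default-argument version can be silently overridden if a caller passes an argument, whereas the factory version rejects any argument.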

Although I have fixed this bug, I still don't understand why this code creates a closure, since case_set is not inside any function. Does anyone know why this is a closure?

0811张庆昊
"I still don't know why this code will cause closure": a closure occurs whenever your function uses a variable that comes from an enclosing scope. In this case, the function is `lambda: M_list[i]`, and the enclosing scope is the list comprehension. (In Python 2, `i` would leak into the global scope, so later code that assigns to a global `i` would also affect the lambda.) – Karl Knechtel Aug 18 '22 at 02:25
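The comment's point can be checked by inspecting the lambdas. A minimal sketch, assuming Python 3 and that the snippet runs as a top-level script (so the plain `for` loop binds a global `i`): the lambdas built in the comprehension carry a closure cell for `i`, while the loop-built lambdas read the global `i`, roughly like the leaked comprehension variable in Python 2.

```python
# Python 3: the comprehension has its own scope, so `i` inside the lambda is
# a free variable held in a closure cell, not a global.
funcs = [lambda: i for i in range(25)]

f = funcs[0]
print(f.__code__.co_freevars)          # ('i',)
print(f.__closure__[0].cell_contents)  # 24
print(all(g() == 24 for g in funcs))   # True

# A plain for loop at module level binds a global i instead. These lambdas
# are not closures, and rebinding the global i changes what they return.
loop_funcs = []
for i in range(25):
    loop_funcs.append(lambda: i)

print(loop_funcs[0].__closure__)  # None
print(loop_funcs[0]())            # 24
i = 100
print(loop_funcs[0]())            # 100
print(funcs[0]())                 # 24: the comprehension's cell is unaffected
```

Either way the lookup happens at call time; the comprehension only decides whether `i` lives in a closure cell or in the global namespace.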