
I'm trying to implement a case branch within the training loop for an ensemble of autoencoders: depending on a certain condition, only one specific autoencoder should be updated. I have been trying to implement this with tf.case(), but it doesn't work as I expected.

def f(k_win):
    update_BW = tf.train.AdamOptimizer(learning_rate=learningrate).minimize(Cost_List[k_win])
    return update_MSE_winner(k_win) + [update_BW, update_n_List(k_win), update_n_alpha_List(k_win)]

winner_index = tf.argmin(Cost_Alpha_List, 0)

Case_List = []
for k in range(N_Class):
    Case = (tf.equal(winner_index, k), lambda: f(k))
    Case_List.append(Case)

Execution_List = tf.case(Case_List, lambda: f(0))

winner_index: index of the autoencoder to update

f(k_win): returns all the update operations for the autoencoder with index k_win

Case_List: contains (boolean tensor, parametrized function) pairs

Execution_List: what I pass to sess.run() in the training loop

The loop variable k is supposed to parametrize each entry of Case_List, specifically 'lambda: f(k)', but after the list is built, every 'lambda: f(k)' seems to use the last value, k = N_Class - 1. As a result, only the last autoencoder is updated, not the one at winner_index. Does anyone have an idea what's going on here?

Thanks.

meridius

1 Answer


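The problem is that the lambdas you are defining refer to the variable k itself, not to the value k had when each lambda was created. Python looks k up only when a lambda is called, and tf.case calls the branch functions after your loop has finished, by which time k holds the last value it took in the loop (N_Class - 1).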

A simpler example:

lst = []
for k in range(10):
    lst.append(lambda: k * k)
print([lst_i() for lst_i in lst])

Gives:

[81, 81, 81, 81, 81, 81, 81, 81, 81, 81]

Instead of:

[0, 1, 4, 9, 16, 25, 36, 49, 64, 81]
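To confirm that the lambdas read k at call time rather than capturing its value, a small variation of the example rebinds k after the loop: every lambda then reports the new value instead of any value k had during the loop.

```python
# Each lambda closes over the variable k itself, not the value k had
# when that lambda was created.
lst = []
for k in range(10):
    lst.append(lambda: k * k)

before = [lst_i() for lst_i in lst]  # k is 9 here, so every call gives 81

# Rebinding k after the loop changes what every lambda returns,
# which shows the lookup happens when the lambda is called.
k = 3
after = [lst_i() for lst_i in lst]
print(before)
print(after)
```

The first list is all 81s and the second is all 9s, even though none of the lambdas was redefined.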

This answer explains the issue better and points out a couple of methods to overcome this. In your case, you can do something like this:

def f(k_win):
    update_BW = tf.train.AdamOptimizer(learning_rate=learningrate).minimize(Cost_List[k_win])
    return update_MSE_winner(k_win) + [update_BW, update_n_List(k_win), update_n_alpha_List(k_win)]

winner_index = tf.argmin(Cost_Alpha_List, 0)

Case_List = []
for k in range(N_Class):
    # The outer lambda is called immediately with the current k, so each
    # inner lambda gets its own kk bound to this iteration's value.
    Case = (tf.equal(winner_index, k), (lambda kk: lambda: f(kk))(k))
    Case_List.append(Case)

Execution_List = tf.case(Case_List, lambda: f(0))
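The fix above can be compared with two other common idioms (a default argument and functools.partial) without TensorFlow. This is a hypothetical pure-Python mock: pick_case only imitates the "first true predicate wins" dispatch of tf.case with exclusive=False, f is a stand-in that returns a string instead of building update ops, and winner_index is a plain int rather than the result of tf.argmin. Unlike tf.case, which calls every branch function while building the graph, the mock calls only the selected one; the late-binding behavior is the same either way.

```python
import functools


def f(k_win):
    # Hypothetical stand-in for the question's f(): reports which
    # autoencoder's update ops would be built instead of building them.
    return "update AE %d" % k_win


def pick_case(pred_fn_pairs, default):
    # Hypothetical mock of tf.case's dispatch with exclusive=False:
    # call the function paired with the first true predicate.
    for pred, fn in pred_fn_pairs:
        if pred:
            return fn()
    return default()


N_Class = 4
winner_index = 1  # plain int standing in for tf.argmin(...)

# Buggy: every lambda reads k when called, i.e. after the loop ended (k == 3).
buggy = [(winner_index == k, lambda: f(k)) for k in range(N_Class)]
# Fix 1: immediately invoked outer lambda, as in the answer's code.
wrapped = [(winner_index == k, (lambda kk: lambda: f(kk))(k)) for k in range(N_Class)]
# Fix 2: default argument, evaluated once when each lambda is created.
default_arg = [(winner_index == k, lambda k=k: f(k)) for k in range(N_Class)]
# Fix 3: functools.partial stores the current k inside the partial object.
partial_fn = [(winner_index == k, functools.partial(f, k)) for k in range(N_Class)]

for name, pairs in [("buggy", buggy), ("wrapped", wrapped),
                    ("default_arg", default_arg), ("partial_fn", partial_fn)]:
    print(name, "->", pick_case(pairs, lambda: f(0)))
```

The buggy list selects AE 3 even though winner_index is 1, which matches the symptom in the question; the three fixes all select AE 1. functools.partial is arguably the most explicit of the three, while the default-argument form can be silently overridden if a caller ever passes an argument.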
jdehesa