While building code to train a TensorFlow deep model, I am using tf.map_fn with tf.py_function as a wrapper to apply a scipy function as a loss, mapping it over each pair of rows from a batch of probability vectors p and q of shape [batch_size, num_classes]. With the KL divergence over this batch of vector pairs (p, q), training works fine and there is no shape incompatibility:

tf.reduce_sum(p * (tf.log(p + 1e-16) - tf.log(q + 1e-16)), axis=1)  # row-wise KL divergence
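
(For reference, this is the row-wise KL divergence KL(p || q) = sum_i p_i * (log p_i - log q_i); the small epsilon 1e-16 is only there for numerical stability.)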

However, when I try to use the Wasserstein distance or the energy distance functions from scipy, I get an error about incompatible shapes [] and [5000], where 5000 is the number of classes (p and q have shape [batch_size, 5000]):

import tensorflow as tf

def compute_kld(p_logit, q_logit, divergence_type):
    p = tf.nn.softmax(p_logit)
    q = tf.nn.softmax(q_logit)
    if divergence_type == "KL_divergence":
        return tf.reduce_sum(p * (tf.log(p + 1e-16) - tf.log(q + 1e-16)), axis=1)
    elif divergence_type == "Wasserstein_distance":
        def wasserstein_distance(x, y):
            from scipy import stats
            return stats.wasserstein_distance(x, y)
        @tf.function
        def func(p, q):
            # map over the batch: one (p_row, q_row) pair per call, one scalar out
            return tf.map_fn(lambda x: tf.py_function(func=wasserstein_distance, inp=[x[0], x[1]], Tout=tf.float32), (p, q), dtype=tf.float32)  # , parallel_iterations=10
        return func(p, q)
    elif divergence_type == "energy_distance":  # a.k.a. the Cramer distance
        def energy_distance(x, y):
            from scipy import stats
            return stats.energy_distance(x, y)
        @tf.function
        def func(p, q):
            return tf.map_fn(lambda x: tf.py_function(func=energy_distance, inp=[x[0], x[1]], Tout=tf.float32), (p, q), dtype=tf.float32)  # , parallel_iterations=10
        return func(p, q)

This is the code to test the loss functions with a batch of 5 and 3 classes; each of them works fine when run individually:

import tensorflow as tf
p = tf.constant([[1, 2, 3], [1, 2, 3], [14, 50, 61], [71, 83, 79], [110,171,12]])
q = tf.constant([[1, 2, 3], [1.2, 2.3, 3.2], [4.2, 5.3, 6.4], [7.5, 8.6, 9.4], [11.2,10.1,13]])

p = tf.reshape(p, [-1,3])
q = tf.reshape(q, [-1,3])

p = tf.cast(p, tf.float32)
q = tf.cast(q, tf.float32)
with tf.Session() as sess:
    divergence_type = "KL_divergence"
    res = compute_kld(p, q, divergence_type = divergence_type)
    
    divergence_type = "Wasserstein_distance"
    res2 = compute_kld(p, q, divergence_type = divergence_type)
    
    divergence_type = "energy_distance"
    res3 = compute_kld(p, q, divergence_type = divergence_type)
    print("############################## p")   
    print(sess.run(tf.print(p)))
    print("##")
    print(sess.run(tf.print(tf.shape(p))))
    print("############################## KL_divergence")   
    print(sess.run(tf.print(res)))
    print("##")
    print(sess.run(tf.print(tf.shape(res))))
    print("############################## Wasserstein_distance")   
    print(sess.run(tf.print(res2)))
    print("##")
    print(sess.run(tf.print(tf.shape(res2))))
    print("############################## energy_distance")   
    print(sess.run(tf.print(res3)))
    print("##")
    print(sess.run(tf.print(tf.shape(res3))))

This is the output:

############################## p
[[1 2 3]
 [1 2 3]
 [14 50 61]
 [71 83 79]
 [110 171 12]]
None
##
[5 3]
None
############################## KL_divergence
[0 0.000939823687 0.367009342 1.1647588 3.09911442]
None
##
[5]
None
############################## Wasserstein_distance
[0 0.0126344115 0.204870835 0.237718046 0.120362818]
None
##
[5]
None
############################## energy_distance
[0 0.0917765796 0.41313991 0.438246906 0.316672504]
None
##
[5]
None

However, when using the Wasserstein distance or the energy distance inside my training code, I get an incompatible shape error:

tensorflow.python.framework.errors_impl.InvalidArgumentError: Tried to set a tensor with incompatible shape at a list index. Item element shape: [] list shape: [5000]
         [[{{node gradients/TensorArrayV2Read/TensorListGetItem_grad/TensorListSetItem}}]]

I am wondering whether the dtype I pass to tf.map_fn or tf.py_function is wrong, or whether I have to specify/impose the shape somewhere.
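
For example, is something along these lines what is needed? (An untested sketch on my side, assuming a TF version recent enough to support fn_output_signature; wasserstein_distance is the wrapper defined above.)

@tf.function
def func(p, q):
    # Declare the per-pair result as a float32 scalar explicitly,
    # instead of letting map_fn infer it from dtype alone.
    return tf.map_fn(
        lambda x: tf.py_function(func=wasserstein_distance, inp=[x[0], x[1]], Tout=tf.float32),
        (p, q),
        fn_output_signature=tf.TensorSpec(shape=[], dtype=tf.float32))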

Here is a link to the whole code, where I tried to replace the KL divergence with the Wasserstein distance in the method "compute_kld": https://github.com/shenyuanyuan/IMSAT/blob/master/imsat_cluster.py

Thank you in advance for your kind help!

== UPDATE ==

I inspected all the provided batches, and the shapes of p and q seem correct:

shape(p)
(?, 5000)
shape(q)
(?, 5000)

However, the shape information of func's returned object seems to be lost. Thus, I have tried to reshape it with:

return tf.reshape(func(p, q), [p.shape[0]])

However, this doesn't seem to change anything; the error stays the same. After the first batch is processed, the code crashes before starting on the second batch.

Othmane
  • Check the shapes of your tensors: are they the same for all tensors, or is there a mismatch in the list? This error indicates that your input shapes are not the same for all tensors. Can you please share your input shapes and sizes? – Mohammad Ahmed Feb 02 '23 at 09:48
  • Normally the shape of the tensors provided in each batch is supposed to be [batch_size, 5000] for both p and q, except maybe for the last batch, which could be [] according to the error message. But TensorFlow is normally supposed to handle that automatically, I think. I will keep printing the shapes of all batches – Othmane Feb 02 '23 at 17:24
  • @MohammadAhmed I have inspected all the shapes by printing them, and they are all shape(p) (?, 5000) and shape(q) (?, 5000), which is good. However, after the first batch the code crashes – Othmane Feb 04 '23 at 11:35
  • Why don't you just write these functions in TensorFlow? By default TensorFlow uses graph execution, so it fails to convert other libraries' functions to a graph, and since you are explicitly converting these functions to a graph it is hard to catch the actual error. Or just enable `eager execution` and then check what the error is by enabling `debugging_mode()`. – Mohammad Ahmed Feb 04 '23 at 13:00

1 Answer


Without seeing your training code, the best I can do is quote the docs and try to shed some light.

map_fn Transforms elems by applying fn to each element unstacked on axis 0.

If elems is a tuple (or nested structure) of tensors, then those tensors must all have the same outer-dimension size (num_elems); and fn is used to transform each tuple (or structure) of corresponding slices from elems. E.g., if elems is a tuple (t1, t2, t3), then fn is used to transform each tuple of slices (t1[i], t2[i], t3[i]) (where 0 <= i < num_elems).
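
For instance, a minimal sketch (assuming TF 2.x, values arbitrary) of mapping a scalar-valued fn over a tuple of two tensors that share the same outer dimension:

import tensorflow as tf

t1 = tf.constant([[1., 2.], [3., 4.]])
t2 = tf.constant([[10., 20.], [30., 40.]])

# fn receives one (t1[i], t2[i]) pair per step and returns a scalar,
# so the mapped result has shape [num_elems] = [2].
res = tf.map_fn(
    lambda x: tf.reduce_sum(x[0] * x[1]),
    (t1, t2),
    fn_output_signature=tf.TensorSpec(shape=[], dtype=tf.float32))
print(res)  # tf.Tensor([ 50. 250.], shape=(2,), dtype=float32)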

energy_distance Computes the energy distance between two 1D distributions.

wasserstein_distance Computes the first Wasserstein distance between two 1D distributions.
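
Note that both scipy functions take two 1D arrays and return a plain Python float, i.e. a scalar (shape []) once it comes back through tf.py_function. A quick check, with arbitrary values:

import numpy as np
from scipy import stats

x = np.array([0.2, 0.3, 0.5])
y = np.array([0.1, 0.4, 0.5])

# Each call returns a single float, not a vector.
print(stats.wasserstein_distance(x, y))
print(stats.energy_distance(x, y))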

To begin, you should make sure you are passing only 2D p_logit and q_logit to compute_kld.
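
If it helps, here is a minimal sketch of that idea (assuming TF 2.x with fn_output_signature, and reusing the wasserstein_distance wrapper from your question): pin the input rank and the per-element output shape, so a scalar result cannot be confused with a [5000]-shaped list element.

import tensorflow as tf

def compute_distance(p_logit, q_logit):
    # Fail fast if the logits are not [batch_size, 5000].
    p_logit = tf.ensure_shape(p_logit, [None, 5000])
    q_logit = tf.ensure_shape(q_logit, [None, 5000])
    p = tf.nn.softmax(p_logit)
    q = tf.nn.softmax(q_logit)
    # Each mapped call yields one float32 scalar per (p_row, q_row) pair.
    return tf.map_fn(
        lambda x: tf.py_function(func=wasserstein_distance, inp=[x[0], x[1]], Tout=tf.float32),
        (p, q),
        fn_output_signature=tf.TensorSpec(shape=[], dtype=tf.float32))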

Bob
  • Dear Bob, thanks for trying to help. Indeed, p_logit and q_logit are 2D matrices. Here is the whole code where I try to replace the KL divergence with the Wasserstein distance: https://github.com/shenyuanyuan/IMSAT/blob/master/imsat_cluster.py – Othmane Jan 26 '23 at 23:33
  • Just to remind you, the code for the Wasserstein distance and the energy distance works when tested separately, as provided above. However, when I plug the same code into the IMSAT code, I get the above error about incompatible shapes. Please have a look at the link I provided for the whole code to see if something is wrong with the shapes or if I need to specify them somewhere else. The whole code is for clustering. – Othmane Jan 26 '23 at 23:36
  • I updated my question with the link to the whole clustering code @Bob – Othmane Jan 26 '23 at 23:47
  • I have inspected all the shapes by printing them, and they are all shape(p) (?, 5000) and shape(q) (?, 5000), which are good 2D tensors. However, after the first batch the code crashes before starting to process the second one; the error seems to be triggered at the end of the first batch, right before the second batch begins. – Othmane Feb 04 '23 at 11:44