import tensorflow as tf
# Three 1-D integer input tensors (lengths 3, 4 and 3) for the demo below.
a, b, c = (tf.constant(values) for values in ([1, 2, 3], [4, 5, 6, 7], [8, 9, 10]))
def cart_prod(a, b, c):
    """Return the Cartesian product of three 1-D tensors as rows of a matrix.

    Args:
        a, b, c: 1-D tensors that must share one dtype, because their
            elements are combined into a single output tensor.

    Returns:
        A tensor of shape ``[len(a) * len(b) * len(c), 3]`` whose rows are
        ``(a[i], b[j], c[k])``. Rows are ordered with ``b`` varying fastest,
        then ``a``, and ``c`` slowest.
    """
    # With 'ij' indexing, output axis 0 follows c, axis 1 follows a and
    # axis 2 follows b, so a row-major flatten gives c-slowest / b-fastest.
    grid_c, grid_a, grid_b = tf.meshgrid(c, a, b, indexing="ij")
    # Put the columns in (a, b, c) order, then flatten the grid into rows.
    return tf.reshape(tf.stack([grid_a, grid_b, grid_c], axis=-1), [-1, 3])
# Evaluate the product graph in the context-managed session so it is closed
# afterwards. tf.Session is the TF 1.x graph-mode API (tf.compat.v1 in TF 2).
with tf.Session() as sess:
    cart = sess.run(cart_prod(a, b, c))
print(cart)