For example, if we have:
a = tf.constant(np.eye(5))
a
<tf.Tensor 'Const:0' shape=(5, 5) dtype=float64>
a[0,:]
<tf.Tensor 'strided_slice:0' shape=(5,) dtype=float64>
Slicing tensor a this way reduces its number of dimensions from 2 to 1.
How can I get the slice directly, with its rank unchanged, like this?
a[0,:]
<tf.Tensor 'strided_slice:0' shape=(1,5) dtype=float64>
(tf.expand_dims(a[0,:], axis=0) would work, but is there a more direct and easier way?)
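A minimal sketch, assuming TensorFlow 2.x with eager execution (the output above looks like TF 1.x graph mode, but the indexing semantics are the same). An integer index such as `0` removes that axis, while a length-1 range such as `0:1` keeps it. Inserting `tf.newaxis` in the subscript and calling `tf.slice` with an explicit `begin`/`size` are two other ways to get the same (1, 5) result.

```python
import numpy as np
import tensorflow as tf

a = tf.constant(np.eye(5))

# An integer index drops the axis: rank 2 -> rank 1, shape (5,).
row_dropped = a[0, :]

# A length-1 range keeps the axis: rank stays 2, shape (1, 5).
row_kept = a[0:1, :]

# tf.newaxis adds a leading axis in the same subscript, shape (1, 5).
row_newaxis = a[tf.newaxis, 0, :]

# tf.slice takes an explicit start and size; size -1 means "to the end".
row_slice = tf.slice(a, begin=[0, 0], size=[1, -1])

print(row_dropped.shape, row_kept.shape, row_newaxis.shape, row_slice.shape)
```

The `0:1` form is probably the most direct, since it is ordinary Python slicing syntax and needs no extra op.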