
I am looking for the analogue of np.delete in TensorFlow. I have batches of tensors; each batch has shape (batch_size, variable_length), and I want to get a tensor of shape (batch_size, 2 * variable_length / 3). As seen, each batch has a different length, which is stored in and read from the tfrecord. I am a bit at a loss as to which API I should use for that. Related (for numpy):

where the solution would simply be np.delete(x, slice(2, None, 3)) (after a reshape to account for batch_size).
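For reference, a minimal NumPy sketch of that call on a batch. It assumes a 2-D array whose row length is a multiple of 3; the array `x` and its values are made up for illustration, and `axis=1` is used here in place of the reshape mentioned above.

```python
import numpy as np

# Hypothetical batch: 2 rows of 6 values each (two triples per row).
x = np.arange(1, 13).reshape(2, 6)

# Delete columns 2, 5, ... so that every third element of each row is dropped.
# axis=1 applies the deletion along each row, so no flattening is needed.
result = np.delete(x, slice(2, None, 3), axis=1)

print(result.shape)  # (batch_size, 2 * variable_length / 3)
print(result)
```

Each row keeps two of every three elements, which matches the target shape described in the question.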

As requested in the comments, here is the code for parsing a single example proto, although I am interested in deleting every nth (3rd) element of a tensor as a standalone question.

@classmethod
def parse_single_example(cls, example_proto):
    instance = cls()
    features_dict = cls._get_features_dict(example_proto)
    instance.path_length = features_dict['path_length']
    ...
    instance.coords = tf.decode_raw(features_dict['coords'], DATA_TYPE) # the tensor
    ...
    return instance.coords, ...

@classmethod
def _get_features_dict(cls, value):
    features_dict = tf.parse_single_example(value,
        features={'coords': tf.FixedLenFeature([], tf.string),
                  ...
                  'path_length': tf.FixedLenFeature([], tf.int64)})
    return features_dict
Mr_and_Mrs_D
  • What does your input pipeline look like? – GPhilo Nov 03 '17 at 22:02
  • @GPhilo I am reading the tensors from tfrecords - but I am interested in just that deleting - so imagine I just have a 1-d tensor (obtained somehow) and I want to delete every third element. – Mr_and_Mrs_D Nov 04 '17 at 11:56
  • *how* you read those records is important. Please post your input pipeline as code in your question – GPhilo Nov 04 '17 at 12:20
  • I don't see why it's important - however here is the code – Mr_and_Mrs_D Nov 04 '17 at 13:35
  • Thanks. It is important because, for example, the tf.data API has methods that let you do just what you need very easily. On the other hand, doing this on tensors is not as straightforward. – GPhilo Nov 04 '17 at 14:30
  • @GPhilo - yes, it seems there is no straightforward way of doing this for a tensor; however, a way using the tf.data API would also be interesting. The analogue of np.delete in TensorFlow would be the ideal solution. – Mr_and_Mrs_D Nov 05 '17 at 23:44

2 Answers


Here is a way that avoids tf.py_func:

import numpy as np
import tensorflow as tf

slices = ([[1, 2, 3, 4, 5, 6]], [2])
d = tf.contrib.data.Dataset.from_tensor_slices(slices)
d = d.map(lambda coords, _pl: tf.boolean_mask(coords, tf.tile(
  np.array([True, True, False]), tf.reshape(tf.cast(_pl, tf.int32), [1]))))

it = d.make_one_shot_iterator()

with tf.Session() as sess:
  print(sess.run(it.get_next()))
  # [1 2 4 5]

As with most things in TensorFlow, this was a bit hard to get right. Note the cast (tf.tile fails when its 'multiples' parameter is int64, which was the type of the length I read from the tfrecords) and the rather unintuitive reshape, which is needed because 'multiples' must be a 1-D tensor. Generalizing this example to accept variable-length arrays is left as an exercise.
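As a starting point for that exercise, here is a sketch of the same keep/keep/drop mask in plain NumPy rather than TensorFlow. The helper name `drop_every_third` and the sample rows are hypothetical, and unlike the code above it derives the number of repeats from the row length instead of reading it from the record.

```python
import numpy as np

def drop_every_third(row):
    """Return `row` without the elements at indices 2, 5, 8, ..."""
    row = np.asarray(row)
    # Repeat the [keep, keep, drop] pattern often enough to cover the row,
    # then trim the mask to the row's actual length.
    repeats = -(-len(row) // 3)  # ceiling division
    mask = np.tile([True, True, False], repeats)[:len(row)]
    return row[mask]

# Hypothetical rows of different lengths.
rows = [[1, 2, 3, 4, 5, 6], [7, 8, 9], [10, 11, 12, 13]]
trimmed = [drop_every_third(r).tolist() for r in rows]
print(trimmed)
```

In a graph, the same mask would presumably be built with tf.tile and applied with tf.boolean_mask as above; that translation is not tested here.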

I would be interested in a gather_nd version of this code.
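A rough sketch of what the index-based variant might look like, written in NumPy with np.take standing in for a gather op. The helper `kept_indices` is a hypothetical name. For a 1-D tensor, tf.gather with the same index vector would probably be enough without gather_nd, but that is an assumption and not tested here.

```python
import numpy as np

def kept_indices(length):
    """Indices 0, 1, 3, 4, 6, 7, ... below `length` (skips 2, 5, 8, ...)."""
    idx = np.arange(length)
    return idx[idx % 3 != 2]

coords = np.array([1, 2, 3, 4, 5, 6])
keep = kept_indices(len(coords))

# np.take picks the elements at the given positions, like a 1-D gather.
gathered = np.take(coords, keep)
print(keep, gathered)
```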

Mr_and_Mrs_D

Disclaimer: Since you do not provide a minimal, complete and verifiable example, my code cannot be fully tested. You'll need to try it and adapt it to your needs.

This is how you could do it using the tf.data API. Please note that since you're not showing the whole layout of your class, I have to make some assumptions on how and where your data is accessible.

First of all, I'm assuming that your class' constructor knows where the .tfrecord files are stored. Specifically, I'll assume that TFRECORD_FILENAMES is a list containing all the file paths to the files you want to extract the records from.

In your class constructor, you need to instantiate a TFRecordDataset and map() over it the functions that modify the data the dataset contains:

import numpy as np
import tensorflow as tf

class MyClass():
    def __init__(self):
        # more init stuff
        def parse_example(serialized_example):
            features_dict = tf.parse_single_example(serialized_example,
              features={'coords': tf.FixedLenFeature([], tf.string),
              ...
              'path_length': tf.FixedLenFeature([], tf.int64)})
            return features_dict

        def skip_every_third_pyfunc(coords):
            # you mention something about a reshape, I guess that goes here as well
            # drop indices 2, 5, 8, ... as in the question's np.delete call
            return np.delete(coords, slice(2, None, 3))

        self.dataset = (tf.data.TFRecordDataset(TFRECORD_FILENAMES)
                        .map(parse_example)
                        .map(lambda features_dict: {**features_dict, **{'coords': tf.py_func(skip_every_third_pyfunc, [features_dict['coords']], features_dict['coords'].dtype)}}))
        self.iterator = self.dataset.make_one_shot_iterator() # adapt this to your needs
        self.features_dict = self.iterator.get_next() # I'm putting this here because I don't know where you'll need it

Note that in skip_every_third_pyfunc you can use numpy functions because we're using tf.py_func to wrap a python function as a tensor operation (all caveats in the link apply).

The ugly lambda in the second .map() call is necessary because you're using a feature dict instead of returning a tuple of tensors. The function passed to py_func takes numpy arrays as input and returns numpy arrays. To keep the dict format, we use the ** unpacking operator, which works inside dict literals from Python 3.5 onwards. If you're using an older version of Python, you can define your own merge_two_dicts function and use it in the lambda instead, as per this answer.
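A minimal sketch of such a helper, assuming plain dicts and that keys in the second dict should override those in the first (the sample values are made up for illustration):

```python
def merge_two_dicts(first, second):
    """Return a new dict with the entries of `first`, overridden by `second`."""
    merged = first.copy()   # shallow copy, so `first` is left untouched
    merged.update(second)
    return merged

# Hypothetical feature dict standing in for the parsed example.
features = {'coords': 'raw', 'path_length': 2}
updated = merge_two_dicts(features, {'coords': 'trimmed'})
print(updated)
```

In the lambda, `{**a, **b}` would then become `merge_two_dicts(a, b)`.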

GPhilo
  • I _am_ returning a tuple of tensors from `parse_single_example`, so this merging of dicts is not needed - will edit OP (so edit out the dicts merging). I am testing this - still a bit disappointing that I have to use `tf.py_func` – Mr_and_Mrs_D Nov 07 '17 at 17:06
  • `TypeError: py_func() missing 2 required positional arguments: 'inp' and 'Tout' ` – Mr_and_Mrs_D Nov 07 '17 at 21:03
  • I think you should use `[features_dict['coords']]` plus get rid of those dicts as I commented above – Mr_and_Mrs_D Nov 07 '17 at 22:38
  • Also another possible approach would be to somehow gather (nd ?) needed elements - this would avoid the py_func hack – Mr_and_Mrs_D Nov 08 '17 at 17:52