
I am using TensorFlow 2.3.0.

I have a Python data generator:

import tensorflow as tf
import numpy as np

vocab = [1,2,3,4,5]

def create_generator():
    'generates a random number from 0 to len(vocab)-1'
    count = 0
    while count < 4:
        x = np.random.randint(0, len(vocab))
        yield x
        count +=1

I turn it into a tf.data.Dataset object:

gen = tf.data.Dataset.from_generator(create_generator,
                                     output_types=tf.int32,
                                     output_shapes=())

Now I want to sub-sample the items using the map method, so that the dataset never outputs an even number.

def subsample(x):
    'remove the item if it is an even number, e.g. 2 or 4'

    '''
    #TODO
    '''
    return x

gen = gen.map(subsample)

How can I achieve this using the map method?

n0obcoder

1 Answer


In short, no: you cannot filter data using map. map applies a transformation to every element of the dataset and produces exactly one output element for each input element. What you want is to check every element against a predicate and keep only the elements that satisfy it.

The method that does this is filter().

So you can do:

gen = gen.filter(lambda x: x % 2 != 0)
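To make the difference between map and filter concrete, here is a small sketch. It assumes a deterministic stand-in for the question's random generator (it yields each index 0..4 once) so that the output is predictable; the names `ds`, `mapped` and `odd_only` are only for illustration.

```python
import tensorflow as tf

vocab = [1, 2, 3, 4, 5]

def create_generator():
    # Deterministic stand-in for the random generator in the question:
    # yields the indices 0..len(vocab)-1 once each.
    for x in range(len(vocab)):
        yield x

ds = tf.data.Dataset.from_generator(create_generator,
                                    output_types=tf.int32,
                                    output_shapes=())

# map is one-to-one: every input element produces exactly one output element.
mapped = ds.map(lambda x: x * 10)

# filter keeps only the elements for which the predicate evaluates to True.
odd_only = ds.filter(lambda x: x % 2 != 0)

print([int(v) for v in mapped.as_numpy_iterator()])    # [0, 10, 20, 30, 40]
print([int(v) for v in odd_only.as_numpy_iterator()])  # [1, 3]
```

The mapped dataset still has five elements, while the filtered one has only the odd values, which is why filter (not map) is the right tool for dropping items.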

Update:

If you want to use a custom function instead of a lambda, you can do something like:

def filter_func(x):
    # x is a scalar tensor; the comparison returns a scalar tf.bool,
    # which is what filter() expects from its predicate
    return x**2 < 500

gen = gen.filter(filter_func)

If this function is passed to filter, only the elements whose square is less than 500 are kept.
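As a rough sketch of both options, the example below applies `filter_func` to a simple `tf.data.Dataset.range(30)` and, separately, wraps an arbitrary Python rule with `tf.py_function` so it can serve as a predicate. The rule `python_rule` and its set of values are hypothetical, and `tf.py_function` is a swapped-in technique (not part of the answer above) that runs the Python code eagerly, which is typically slower than using TensorFlow ops directly.

```python
import tensorflow as tf

def filter_func(x):
    # Built from TensorFlow ops, so it works on the symbolic tensor
    # that filter() passes in; returns a scalar tf.bool.
    return x ** 2 < 500

def python_rule(value):
    # Hypothetical arbitrary Python rule operating on a plain int.
    return value in {3, 7, 21}

def py_predicate(x):
    # tf.py_function runs python_rule eagerly, where .numpy() is available.
    keep = tf.py_function(lambda t: python_rule(int(t.numpy())), [x], tf.bool)
    keep.set_shape([])  # filter() expects a scalar boolean predicate
    return keep

ds = tf.data.Dataset.range(30)  # int64 values 0..29
small = [int(v) for v in ds.filter(filter_func).as_numpy_iterator()]
picked = [int(v) for v in ds.filter(py_predicate).as_numpy_iterator()]

print(small)   # 0..22, since 22**2 = 484 < 500 and 23**2 = 529
print(picked)  # [3, 7, 21]
```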

  • Wow! This was helpful. Can I use a custom Python function to filter out the elements of the dataset based on some custom rules too? – n0obcoder Nov 17 '20 at 05:41
  • The real thing that I want to do is sub-sample items from an item list based on their frequency of occurrence. An item list can be like [1,3,4,6,1,3,9]; now let's say that after applying the subsampling, the item list is reduced to [1,3,6,1]. Next, I want to discard all the sequences with length less than 2. If the length is >= 2, I want to keep it. How can I pull this off? – n0obcoder Nov 17 '20 at 05:49
    Yes, as long as your custom function returns a boolean value. I will update the answer – Muslimbek Abduganiev Nov 17 '20 at 05:55
  • But the boolean only decides whether that particular dataset element is filtered or not. I am not sure how I can transform the dataset element using a custom Python function, and then filter out the sequence (or dataset element) if the sequence length is < 2. Can you please help me with this? Can I even use the filter method to do so? – n0obcoder Nov 17 '20 at 06:00
  • for your particular problem, `lambda x: gen.count(x) > 2`. But your `gen` must be a variable in the outer scope. – Muslimbek Abduganiev Nov 17 '20 at 06:07
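For the follow-up about transforming each sequence and then dropping short ones, one possible approach is to chain map (to transform) and filter (to drop). The sketch below assumes hypothetical variable-length item lists and uses "drop even items" as a stand-in for the frequency-based rule, which the thread does not define; a real frequency-based rule would build a boolean mask of the same length as `seq` in the same way.

```python
import tensorflow as tf

# Hypothetical variable-length item lists standing in for the real data.
sequences = [[1, 3, 4, 6, 1, 3, 9], [2, 4, 5], [7, 8, 9, 11], [6]]

def gen():
    for seq in sequences:
        yield seq

ds = tf.data.Dataset.from_generator(gen,
                                    output_types=tf.int32,
                                    output_shapes=(None,))

def subsample(seq):
    # Stand-in rule: keep only odd items. Replace the mask with your own
    # (e.g. frequency-based) boolean mask of the same length.
    keep = tf.not_equal(tf.math.floormod(seq, 2), 0)
    return tf.boolean_mask(seq, keep)

# map transforms each sequence; filter then drops sequences shorter than 2.
ds = ds.map(subsample).filter(lambda seq: tf.size(seq) >= 2)

result = [s.tolist() for s in ds.as_numpy_iterator()]
print(result)  # [[1, 3, 1, 3, 9], [7, 9, 11]]
```

Here `[2, 4, 5]` shrinks to `[5]` and `[6]` to an empty sequence, so both are removed by the length check.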