Keras fit_generator() has a parameter pickle_safe, which defaults to False.
Can training run faster if my setup is pickle-safe and I set the flag to True accordingly?

According to the Keras docs:

pickle_safe: If True, use process based threading. Note that because this implementation relies on multiprocessing, you should not pass non picklable arguments to the generator as they can't be passed easily to children processes.

I don't understand exactly what this is saying.
How can I determine whether my arguments are pickle-safe?

If it's relevant:
- I'm passing in a custom generator
- the generator function takes the arguments X_train, y_train, batch_size, p_keep
  (they are of type np.array, int, and float)
- I'm not using a GPU
- Also, I'm using Keras 1.2.1, though I believe this argument behaves the same in Keras 2

SherylHohman

1 Answer

I have no familiarity with Keras, but from a glance at the documentation, pickle_safe just means that the tuples produced by your generator must be "picklable".

pickle is a standard Python module used to serialize and deserialize objects. The standard multiprocessing implementation uses the pickle mechanism to share objects between processes: because two processes don't share an address space, they cannot directly see the same Python objects. So, to send an object from process A to process B, it is pickled in A (which produces a sequence of bytes in a specific, well-known format), the bytes are sent to B via an interprocess-communication mechanism, and they are unpickled in B, producing a copy of A's original object in B's address space.
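As a rough illustration of that round trip, here is a minimal sketch using multiprocessing.Pipe, whose send() pickles an object and whose recv() unpickles it. Both ends are kept in a single process only to keep the example short, and the batch dictionary is a made-up stand-in for whatever your generator yields; this is not how Keras itself is implemented.

```python
import multiprocessing

# Both ends live in this one process purely to keep the sketch short;
# normally the receiving end would belong to a child process.
sender, receiver = multiprocessing.Pipe()

# Hypothetical stand-in for one batch produced by a generator.
batch = {"x": [1.0, 2.0, 3.0], "y": [0, 1, 0]}

sender.send(batch)          # pickles `batch` and writes the bytes to the pipe
received = receiver.recv()  # reads the bytes back and unpickles them

print(received == batch)  # True: same contents
print(received is batch)  # False: a distinct copy, not the original object

sender.close()
receiver.close()
```

The receiving side always ends up with a copy, which is why anything that crosses the process boundary has to survive pickling.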

So, to discover if your objects are picklable, just invoke, say, pickle.dumps on them.

>>> import pickle
>>> class MyObject:
...    def __init__(self, a, b, c):
...      self.a = a
...      self.b = b
...      self.c = c
...
>>> foo = MyObject(1, 2, 3)
>>> pickle.dumps(foo)
b'\x80\x03c__main__\nMyObject\nq\x00)\x81q\x01}q\x02(X\x01\x00\x00\x00cq\x03K\x03X\x01\x00\x00\x00aq\x04K\x01X\x01\x00\x00\x00bq\x05K\x02ub.'
>>>

dumps produces a byte string. We can now reconstitute the foo object from the byte string as bar using loads:

>>> foo_pick = pickle.dumps(foo)
>>> bar = pickle.loads(foo_pick)
>>> bar
<__main__.MyObject object at 0x7f5e262ece48>
>>> bar.a, bar.b, bar.c
(1, 2, 3)

If something is not picklable, you'll get an exception. For example, lambdas can't be pickled:

>>> class MyOther:
...   def __init__(self, a, b, c):
...     self.a = a
...     self.b = b
...     self.c = c
...     self.printer = lambda: print(self.a, self.b, self.c)
...
>>> other = MyOther(1, 2, 3)
>>> other_pick = pickle.dumps(other)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
AttributeError: Can't pickle local object 'MyOther.__init__.<locals>.<lambda>'
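To apply the same check to several generator arguments at once, something like the sketch below could work. The helper name find_unpicklable and the sample values are hypothetical; plain lists stand in for the arrays in the question (NumPy arrays are normally picklable, but you can pass your real objects to confirm), and the lambda is included only to show what a failure looks like.

```python
import pickle


def find_unpicklable(**named_args):
    """Return {name: error message} for each argument that pickle rejects."""
    problems = {}
    for name, value in named_args.items():
        try:
            pickle.dumps(value)
        except Exception as exc:  # pickle can raise several exception types
            problems[name] = "{}: {}".format(type(exc).__name__, exc)
    return problems


# Hypothetical stand-ins for the generator arguments in the question.
problems = find_unpicklable(
    X_train=[[0.1, 0.2], [0.3, 0.4]],
    y_train=[0, 1],
    batch_size=32,
    p_keep=0.5,
    augment=lambda x: x,  # deliberately unpicklable, for illustration
)
print(sorted(problems))  # ['augment']
```

An empty dictionary means every argument was pickled without error.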

See the documentation for more info: https://docs.python.org/3.5/library/pickle.html?highlight=pickle#what-can-be-pickled-and-unpickled

Gil Hamilton
  • Thank you for a method to check whether it is picklable. I'd read the docs, but processes, streams, threads, and the like are things I just don't understand yet. Your suggestion is a great way to test! – SherylHohman May 05 '17 at 00:45
  • It turns out that simply setting `pickle_safe=True` will also throw an error (at least in my case) if the object is in fact not pickle-safe. I had presumed there would be some sort of silent failure. – SherylHohman May 05 '17 at 00:48