
Situation: I train for a while, then want to save exactly the current training state to disk and exit. Later I want to continue training, and it should behave exactly as if I had never exited.

For simplicity, let's say I use SGD, although storing the optimizer state (Adam etc.) is also not a problem.
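For reference, a minimal sketch of how I save everything else (TF 1.x; the toy model, variable names and checkpoint path are just placeholders). This part works fine, including the optimizer slot variables:

    import tensorflow as tf

    # Toy model; the actual model does not matter here.
    x = tf.placeholder(tf.float32, shape=[None, 10])
    w = tf.get_variable("w", shape=[10, 1])
    loss = tf.reduce_mean(tf.square(tf.matmul(x, w)))

    global_step = tf.train.get_or_create_global_step()
    opt = tf.train.GradientDescentOptimizer(0.1)  # Adam etc. would just add slot variables
    train_op = opt.minimize(loss, global_step=global_step)

    # Saver picks up all variables, including optimizer slots and global_step.
    saver = tf.train.Saver()

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        # ... train for a while ...
        saver.save(sess, "/tmp/model.ckpt", global_step=global_step)

    # Later, in a fresh process: rebuild the same graph and call
    # saver.restore(sess, tf.train.latest_checkpoint("/tmp")).
    # Everything continues exactly, except for the random state.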

However, I don't know how to read and store the random state. So when I recreate the graph and a new session next time, it will not continue the random sequence: either I have made it deterministic, in which case it just starts over from the beginning of the sequence, or it will be random.

So, how can I read the random state? Or get a seed such that, if I initialize with it later, the generator continues the same sequence?
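To illustrate the problem, a minimal sketch (the seed value 42 is arbitrary): with a fixed graph-level seed, every fresh process replays the sequence from the start instead of continuing where the previous run stopped:

    import tensorflow as tf

    tf.set_random_seed(42)  # graph-level seed, makes the run reproducible
    noise = tf.random_normal([2])

    with tf.Session() as sess:
        print(sess.run(noise))  # first draw
        print(sess.run(noise))  # second draw, different values

    # Restarting the program and running the same code prints the *first*
    # two draws again. As far as I can see there is no op or API to say
    # "skip the N draws that already happened before the checkpoint".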


If that is not possible, maybe there are other random generators I can use instead? I found tf.contrib.stateless, which seems to provide such generators. E.g. there I could use something like:

tf.contrib.stateless.stateless_random_normal(..., seed=global_step * some_number)
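A minimal sketch of that idea (the shape and the constant 12345 are arbitrary; note that the stateless ops take a shape-[2] integer seed tensor, so the scalar global_step has to be packed into one):

    import tensorflow as tf

    global_step = tf.train.get_or_create_global_step()

    # All randomness is derived from the seed tensor, so making the seed a
    # function of global_step gives a different but reproducible draw per
    # step. After restoring global_step from a checkpoint, the sequence
    # continues exactly where it left off.
    seed = tf.stack([global_step, tf.constant(12345, dtype=tf.int64)])
    noise = tf.contrib.stateless.stateless_random_normal(shape=[128], seed=seed)

If there are several random ops per step, each one could get its own constant in the second seed component so they do not all produce the same values.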
Albert
  • did you read https://www.tensorflow.org/api_docs/python/tf/set_random_seed ? It seems to be pretty well documented there. – avigil Mar 02 '18 at 15:12
    Posting as a comment rather than an answer because I'm not 100% sure (and maybe it changes at some point), but most likely you won't be able to do that. As far as I can see, RNG state is ultimately encapsulated in the C++ class [PhiloxRandom](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/lib/random/philox_random.h), and it does not seem to provide a way to read its current state. You can seed a graph and an op, but not (I think) resume the RNG from a given point. – jdehesa Mar 02 '18 at 17:46
  • @avigil Yes but that does not answer my question. – Albert Mar 05 '18 at 11:20
  • @jdehesa Interesting. Maybe by using the `Skip` function there? But I don't know a reliable way to get the exact number to skip. – Albert Mar 05 '18 at 11:24
  • @Albert Yeah that's the thing, member var `counter_` holds the number of skips but it doesn't look like you can get it out (also you'd have to get the value all the way from there to Python or some state-saving functionality, ofc). – jdehesa Mar 05 '18 at 11:39
  • set the graph and operation level seeds with your own generated random numbers, which you can save however you'd like – avigil Mar 05 '18 at 14:48
  • @avigil: Please read my question. I have trained for a while now; what is the seed of the current state now? – Albert Mar 05 '18 at 15:48
  • You can't retrieve that. All you can do is set the seed to some known value. – avigil Mar 05 '18 at 16:05
  • @avigil Well, that was what I'm asking for. So then that would be the answer. Although I'm quite sure that with some effort, maybe ugly hacks, I could get the hidden state from it and also set that. Or, I could use other random generators as alternative. I just added one approach into the question. – Albert Mar 05 '18 at 16:39
  • I'm sure you could find it somewhere in the C++ code, it's just not exposed by the public Python TensorFlow API. – avigil Mar 05 '18 at 16:55
  • @avigil: Well, that's my question (or just any way how to do what I'm asking for). I could also then write a custom TF op kernel which accesses this. If you could provide proof-of-concept code to do so, that would be great. – Albert Mar 06 '18 at 10:13
  • I posted an answer to [this similar question](https://stackoverflow.com/q/52333517/1782792) that may be relevant to anyone stumbling upon this one. – jdehesa Sep 14 '18 at 15:50

0 Answers