Scenario:

I am having trouble with a Flask app that loads a TensorFlow model from disk into memory (a global variable generator) inside load_model() when the app first starts.

When a user visits the /test endpoint, the preloaded model generator will be used to generate some data.

Problem:

generator.run() works properly when it is called in load_model() at app startup. However, the exact same call to generator.run() throws an error when it is made inside test().

I narrowed the problem down to the variable sess, which holds the tf.compat.v1.Session that generator.run() obtains from tf.get_default_session().

When called from load_model(), sess is <tensorflow.python.client.session.Session object at 0x12d317fd0>. But when called from test(), sess is None.

Does anyone know how to solve this problem? I would strongly prefer to load the huge model into memory just once (~10s load time), rather than reload it every time the endpoint is queried.

Thank you!

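import pickle

import numpy as np
from flask import Flask

import dnnlib.tflib as tflib  # assumed import path, based on the dnnlib/tflib path in the traceback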

app = Flask(__name__)
generator = None

# EVERYTHING WORKS WELL HERE
def load_model():
    tflib.init_tf()

    # Load model from disk
    with open(model_path, "rb") as f:
        _G, _D, Gs = pickle.load(f, encoding='latin1')

    # Update global variable generator
    global generator
    generator = Gs

    # Run model
    latent = np.random.randn(1, generator.input_shape[1])
    img = generator.run(latent)[0]              # `sess` is <tensorflow.python.client.session.Session object at 0x12d317fd0>
    print(img.shape)                            # prints: (512, 512, 3)


# PROBLEM OCCURS HERE
@app.route('/test')
def test():
    print('generator: ', generator)             # prints: <dnnlib.tflib.network.Network object at 0x13a80e5d0>
    print('generator.run: ', generator.run)     # prints: <bound method Network.run of <dnnlib.tflib.network.Network object at 0x13a80e5d0>>

    # Run model
    latent = np.random.randn(1, generator.input_shape[1])
    img = generator.run(latent)[0]              # `sess` is None


if __name__ == '__main__':
    load_model()                                # Loads the model from disk
    app.run()

Error after visiting /test

[2019-09-18 23:26:33,083] ERROR in app: Exception on /test [GET]
Traceback (most recent call last):
  File "/anaconda3/envs/ml/lib/python3.7/site-packages/flask/app.py", line 2446, in wsgi_app
    response = self.full_dispatch_request()
  File "/anaconda3/envs/ml/lib/python3.7/site-packages/flask/app.py", line 1951, in full_dispatch_request
    rv = self.handle_user_exception(e)
  File "/anaconda3/envs/ml/lib/python3.7/site-packages/flask/app.py", line 1820, in handle_user_exception
    reraise(exc_type, exc_value, tb)
  File "/anaconda3/envs/ml/lib/python3.7/site-packages/flask/_compat.py", line 39, in reraise
    raise value
  File "/anaconda3/envs/ml/lib/python3.7/site-packages/flask/app.py", line 1949, in full_dispatch_request
    rv = self.dispatch_request()
  File "/anaconda3/envs/ml/lib/python3.7/site-packages/flask/app.py", line 1935, in dispatch_request
    return self.view_functions[rule.endpoint](**req.view_args)
  File "app.py", line 71, in test
    img = generator.run(latent)[0]
  File "/Users/x/foo/dnnlib/tflib/network.py", line 460, in run
    mb_out = sess.run(out_expr, dict(zip(in_expr, mb_in)))
AttributeError: 'NoneType' object has no attribute 'run'

Code around sess.run() [Full file network.py]

...
        for mb_begin in range(0, num_items, minibatch_size):
            if print_progress:
                print("\r%d / %d" % (mb_begin, num_items), end="")

            mb_end = min(mb_begin + minibatch_size, num_items)
            mb_num = mb_end - mb_begin
            mb_in = [src[mb_begin : mb_end] if src is not None else np.zeros([mb_num] + shape[1:]) for src, shape in zip(in_arrays, self.input_shapes)]
            # Run
            # [<tf.Tensor 'Gs/_Run/concat:0' shape=(?, 3, 512, 512) dtype=float32>]

            # <tf.Tensor 'Gs/_Run/labels_in:0' shape=<unknown> dtype=float32>: array([], shape=(1, 0), dtype=float64)
            # <tf.Tensor 'Gs/_Run/latents_in:0' shape=<unknown> dtype=float32>: latents.shape (1, 512)
            #mb_out = tf.get_default_session().run(out_expr, dict(zip(in_expr, mb_in)))
            sess = tf.get_default_session()
            #writer = tf.summary.FileWriter("logs/", sess.graph)
            mb_out = sess.run(out_expr, dict(zip(in_expr, mb_in)))

            for dst, src in zip(out_arrays, mb_out):
                dst[mb_begin: mb_end] = src
...
Nyxynyx
  • I think that your issue is that Flask tasks can run in different Threads/Processes. Instead of using `global` variables try some other kind of storage like shared memory. – carobnodrvo Sep 19 '19 at 07:58
  • Check out [this](https://stackoverflow.com/questions/32815451/are-global-variables-thread-safe-in-flask-how-do-i-share-data-between-requests) answer. – carobnodrvo Sep 19 '19 at 08:04
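
Following up on the comments: the TensorFlow 1.x documentation describes the default session as a property of the current thread, so a session made default during startup would not be visible from a request handled in another thread. Below is a minimal, stdlib-only sketch of that behaviour, assuming the /test handler runs in a different thread from load_model(). FakeSession, as_default, get_default_session and in_worker are hypothetical stand-ins written for this illustration, not TensorFlow or Flask APIs.

```python
import threading
from contextlib import contextmanager

# Hypothetical stand-in for a per-thread "default session" slot.
_local = threading.local()


class FakeSession:
    """Hypothetical session object; not a TensorFlow class."""

    def run(self, x):
        return x * 2


def get_default_session():
    # Returns None in any thread that never installed a default.
    return getattr(_local, "session", None)


@contextmanager
def as_default(session):
    # Install `session` as the default for the current thread only.
    previous = get_default_session()
    _local.session = session
    try:
        yield session
    finally:
        _local.session = previous


def in_worker(fn):
    # Run fn in a new thread (like a threaded request handler) and return its result.
    box = {}
    thread = threading.Thread(target=lambda: box.update(result=fn()))
    thread.start()
    thread.join()
    return box["result"]


# Startup: the session becomes the default in the main thread only.
sess = FakeSession()
_local.session = sess


def handler_without_scope():
    # Mirrors test() in the question: relies on the implicit default.
    return get_default_session()


def handler_with_scope():
    # Re-installs the saved session as the default inside the worker thread.
    with as_default(sess):
        return get_default_session().run(21)


main_default = get_default_session()
worker_default = in_worker(handler_without_scope)
worker_result = in_worker(handler_with_scope)
print(main_default is sess, worker_default, worker_result)  # True None 42
```

If this is indeed the cause, the analogous change in the app would be to keep a reference to the session at load time (for example sess = tf.get_default_session() inside load_model()) and wrap the generator.run() call in test() with `with sess.as_default():`, possibly together with `sess.graph.as_default()`. This has not been verified against this particular model.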
