Scenario:
I am encountering problems with a Flask app that loads a TensorFlow model from disk into memory (the global variable generator) inside load_model() when the app first starts.
When a user visits the /test endpoint, the preloaded model generator is used to generate some data.
Problem:
generator.run() works properly when it is called in load_model() at startup. However, the exact same generator.run() call throws an error when it is called inside test().
I narrowed the problem down to the tf.compat.v1.Session() variable sess.
When generator.run() is called from load_model(), sess is <tensorflow.python.client.session.Session object at 0x12d317fd0>. But when it is called from test(), sess is None.
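One plausible explanation (an assumption, not confirmed in this post): TensorFlow documents the default session as a property of the current thread, and since Flask 1.0 the development server started by app.run() handles each request in its own thread by default. If so, tflib.init_tf() makes its session the default only for the thread that called load_model(). The pure-Python sketch below models that behaviour with threading.local; _DefaultStack, init_session and handle_request are hypothetical stand-ins, not TensorFlow or dnnlib APIs.

```python
import threading


class _DefaultStack(threading.local):
    """Per-thread stack; a simplified model of how TF1 tracks the default session.

    Because this subclasses threading.local, __init__ runs separately in every
    thread, so each thread starts with its own empty stack.
    """

    def __init__(self):
        super().__init__()
        self.stack = []

    def get_default(self):
        return self.stack[-1] if self.stack else None


_default_session_stack = _DefaultStack()


def get_default_session():
    return _default_session_stack.get_default()


def init_session():
    # Hypothetical stand-in for tflib.init_tf(): create a "session" and make it
    # the default for the *calling* thread only.
    session = object()
    _default_session_stack.stack.append(session)
    return session


results = {}


def handle_request():
    # Hypothetical stand-in for the Flask view running in a request thread.
    results["request_thread"] = get_default_session()


main_session = init_session()
results["main_thread"] = get_default_session()

worker = threading.Thread(target=handle_request)
worker.start()
worker.join()

print(results["main_thread"] is main_session)  # True
print(results["request_thread"])               # None
```

The startup thread sees its session, while the worker thread sees an empty stack and gets None, which mirrors the sess is None symptom in test().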
Does anyone know how to solve this problem? I strongly prefer to load the huge model into memory just once (~10 s load time) and then reuse it every time the endpoint is queried.
Thank you!
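If thread-local state is the cause, a commonly used workaround is to keep references to the session and graph created at startup and re-enter both inside the request handler. The sketch below assumes TensorFlow with the tf.compat.v1 API; the tiny placeholder graph (x, y) is a hypothetical stand-in for the real generator, and a plain threading.Thread stands in for a Flask request thread.

```python
import threading

import tensorflow as tf

tf1 = tf.compat.v1

# Tiny graph standing in for the loaded generator (hypothetical).
graph = tf1.Graph()
with graph.as_default():
    x = tf1.placeholder(tf.float32, shape=[None], name="x")
    y = x * 2.0

sess = tf1.Session(graph=graph)

results = {}


def handle_request():
    # Runs in a different thread, like a Flask request handler.
    # Nothing has been made default in this thread yet.
    results["before"] = tf1.get_default_session()
    # Re-enter the graph and session captured at startup.
    with graph.as_default(), sess.as_default():
        results["after"] = tf1.get_default_session().run(y, feed_dict={x: [1.0, 2.0]})


with graph.as_default(), sess.as_default():
    # Startup thread: the session is the default here, as after tflib.init_tf().
    results["startup"] = tf1.get_default_session()
    worker = threading.Thread(target=handle_request)
    worker.start()
    worker.join()

sess.close()
print(results["startup"] is sess)  # True
print(results["before"])           # None
print(results["after"].tolist())   # [2.0, 4.0]
```

Applied to the app code in this post, this would mean storing sess = tf.compat.v1.get_default_session() and graph = tf.compat.v1.get_default_graph() as globals at the end of load_model(), then wrapping the call in test() as with graph.as_default(), sess.as_default(): img = generator.run(latent)[0]. This assumes dnnlib's Network.run() simply uses whatever session is the default, as the network.py excerpt suggests. Alternatively, app.run(threaded=False) should make the development server handle requests on the same thread that ran load_model(), at the cost of serving one request at a time.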
import pickle

import numpy as np
from flask import Flask

import dnnlib.tflib as tflib  # dnnlib's TensorFlow utilities

app = Flask(__name__)
model_path = ...  # path to the pickled model (defined elsewhere)
generator = None

# EVERYTHING WORKS WELL HERE
def load_model():
    tflib.init_tf()
    # Load model from disk
    with open(model_path, "rb") as f:
        _G, _D, Gs = pickle.load(f, encoding='latin1')
    # Update global variable generator
    global generator
    generator = Gs
    # Run model
    latent = np.random.randn(1, generator.input_shape[1])
    img = generator.run(latent)[0]  # `sess` is <tensorflow.python.client.session.Session object at 0x12d317fd0>
    print(img.shape)  # prints: (512, 512, 3)

# PROBLEM OCCURS HERE
@app.route('/test')
def test():
    print('generator: ', generator)  # prints: <dnnlib.tflib.network.Network object at 0x13a80e5d0>
    print('generator.run: ', generator.run)  # prints: <bound method Network.run of <dnnlib.tflib.network.Network object at 0x13a80e5d0>>
    # Run model
    latent = np.random.randn(1, generator.input_shape[1])
    img = generator.run(latent)[0]  # `sess` is None
    return str(img.shape)  # a Flask view must return a response

if __name__ == '__main__':
    load_model()  # Loads the model from disk
    app.run()
Error after visiting /test:
[2019-09-18 23:26:33,083] ERROR in app: Exception on /test [GET]
Traceback (most recent call last):
  File "/anaconda3/envs/ml/lib/python3.7/site-packages/flask/app.py", line 2446, in wsgi_app
    response = self.full_dispatch_request()
  File "/anaconda3/envs/ml/lib/python3.7/site-packages/flask/app.py", line 1951, in full_dispatch_request
    rv = self.handle_user_exception(e)
  File "/anaconda3/envs/ml/lib/python3.7/site-packages/flask/app.py", line 1820, in handle_user_exception
    reraise(exc_type, exc_value, tb)
  File "/anaconda3/envs/ml/lib/python3.7/site-packages/flask/_compat.py", line 39, in reraise
    raise value
  File "/anaconda3/envs/ml/lib/python3.7/site-packages/flask/app.py", line 1949, in full_dispatch_request
    rv = self.dispatch_request()
  File "/anaconda3/envs/ml/lib/python3.7/site-packages/flask/app.py", line 1935, in dispatch_request
    return self.view_functions[rule.endpoint](**req.view_args)
  File "app.py", line 71, in test
    img = generator.run(latent)[0]
  File "/Users/x/foo/dnnlib/tflib/network.py", line 460, in run
    mb_out = sess.run(out_expr, dict(zip(in_expr, mb_in)))
AttributeError: 'NoneType' object has no attribute 'run'
Code around sess.run() (full file: network.py):
...
for mb_begin in range(0, num_items, minibatch_size):
    if print_progress:
        print("\r%d / %d" % (mb_begin, num_items), end="")
    mb_end = min(mb_begin + minibatch_size, num_items)
    mb_num = mb_end - mb_begin
    mb_in = [src[mb_begin : mb_end] if src is not None else np.zeros([mb_num] + shape[1:]) for src, shape in zip(in_arrays, self.input_shapes)]
    # Run
    # [<tf.Tensor 'Gs/_Run/concat:0' shape=(?, 3, 512, 512) dtype=float32>]
    # <tf.Tensor 'Gs/_Run/labels_in:0' shape=<unknown> dtype=float32>: array([], shape=(1, 0), dtype=float64)
    # <tf.Tensor 'Gs/_Run/latents_in:0' shape=<unknown> dtype=float32>: latents.shape (1, 512)
    #mb_out = tf.get_default_session().run(out_expr, dict(zip(in_expr, mb_in)))
    # tf.get_default_session() returns None if no session is the default in the current thread
    sess = tf.get_default_session()
    #writer = tf.summary.FileWriter("logs/", sess.graph)
    mb_out = sess.run(out_expr, dict(zip(in_expr, mb_in)))
    for dst, src in zip(out_arrays, mb_out):
        dst[mb_begin: mb_end] = src
...