I am trying to implement a massively parallel differential equation solver (30k DEs) in TensorFlow on CPU, but I run out of memory (the matrices are around 30 GB). So I implemented a batch-based solver (solve for a short time span and save the data -> set the new initial state -> solve again), but the problem persisted. I learned that TensorFlow does not release the memory until the Python interpreter exits. Based on information from GitHub issues, I tried a multiprocessing solution using a Pool, but I keep getting a "can't pickle _thread.RLock objects" error at the pooling step. Could someone please help?

    import numpy as np
    import tensorflow as tf
    from multiprocessing import Pool

    # n (number of equations) and t (array of time points) are defined elsewhere

    def dAdt(X, t):
        dX = ...  # vector of derivatives
        return dX

    global state_vector
    global state
    state_vector = [0]*n  # initial state

    def tensor_process():
        with tf.Session() as sess:
            print("Session started...", end="")
            tf.global_variables_initializer().run()
            state = sess.run(tensor_state)
            sess.close()

    n_batch = 3
    t_batch = np.array_split(t, n_batch)
    for n, i in enumerate(t_batch):
        print("Batch", (n+1), "Running...", end="")
        if n > 0:
            # prepend the previous time point (time step of 0.01)
            i = np.append(i[0]-0.01, i)
        print("Session started...", end="")
        init_state = tf.constant(state_vector, dtype=tf.float64)
        tensor_state = tf.contrib.integrate.odeint_fixed(dAdt, init_state, i)
        with Pool(1) as p:
            p.apply_async(tensor_process).get()  # the pickling error is raised here
        state_vector = state[-1, :]
        np.save("state.batch"+str(n+1), state)
        state = None
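
For reference, the batching scheme on its own (solve a chunk of the time grid, keep the last state, restart the next chunk from it) can be sketched without TensorFlow or multiprocessing. This is only an illustration under assumptions: `euler_fixed`, `split_times` and `solve_in_batches` are hypothetical names, and a plain fixed-step Euler integrator stands in for the TensorFlow solver, so it shows the restart logic and nothing about memory use.

```python
def euler_fixed(func, y0, times):
    """Fixed-step explicit Euler (a stand-in solver, not TensorFlow).

    Returns one state per entry of `times`, starting with y0.
    """
    states = [list(y0)]
    for t0, t1 in zip(times, times[1:]):
        y = states[-1]
        dy = func(y, t0)
        states.append([yi + (t1 - t0) * di for yi, di in zip(y, dy)])
    return states


def split_times(times, n_batch):
    """Split `times` into n_batch contiguous chunks of near-equal length."""
    size, extra = divmod(len(times), n_batch)
    chunks, start = [], 0
    for k in range(n_batch):
        stop = start + size + (1 if k < extra else 0)
        chunks.append(times[start:stop])
        start = stop
    return chunks


def solve_in_batches(func, y0, times, n_batch):
    """Solve chunk by chunk, restarting each chunk from the previous final state."""
    state_vector = list(y0)
    prev_t = None
    results = []
    for chunk in split_times(times, n_batch):
        if prev_t is not None:
            # Prepend the last solved time point so the new chunk
            # starts exactly where the previous one ended.
            chunk = [prev_t] + chunk
        states = euler_fixed(func, state_vector, chunk)
        state_vector = states[-1]  # initial state for the next chunk
        prev_t = chunk[-1]
        results.append(states)     # a real solver would save this to disk instead
    return results
```

Calling `solve_in_batches(f, y0, times, 3)` returns one list of states per batch; the first row of every batch after the first repeats the last row of the batch before it, which is the same overlap the `np.append(i[0]-0.01, i)` line produces in my code.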