I've tried many approaches to training models concurrently with the multiprocessing module on a SageMaker instance.
Either multiple processes spawn and then crash when they reach the train() call, or the processes run sequentially (sometimes crashing after one process runs its course).
I've run a few multiprocessing examples that generate random arrays to confirm that multiprocessing works in general, and they run fine. This leads me to believe the problem lies either in how I'm handling the data or in how TensorFlow interacts with the async functionality.
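A minimal sketch of that kind of sanity check follows. The function names and array sizes are illustrative assumptions, not the exact test that was run. Each task builds a random array in a separate process and returns only small, picklable values to the parent.

```python
import multiprocessing

import numpy as np


def make_random_array(seed):
    """Worker: build a small random array and return its shape and sum."""
    rng = np.random.default_rng(seed)  # per-task seed so results are reproducible
    arr = rng.random((100, 10))
    return arr.shape, float(arr.sum())


def run_sanity_check(num_processes=4):
    """Fan independent tasks out to a process pool and collect the results."""
    with multiprocessing.Pool(num_processes) as pool:
        # map() blocks until every task has finished (or re-raises a worker error)
        return pool.map(make_random_array, range(num_processes))


if __name__ == "__main__":
    for shape, total in run_sanity_check():
        print(shape, round(total, 2))
```

If a script like this runs but the training version does not, the difference is in what the workers do (data, TensorFlow), not in the pool itself.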
I stripped everything down to mimic the code here, but my print statements after the map_async call (before the wait) never print:
```python
import copy
import math
import multiprocessing

import numpy as np

# master_in and getNetModel() are defined elsewhere (not shown).


def trainNetworks(numProcesses=1):
    if __name__ == '__main__':
        batches = []
        netInCt = len(master_in.network_inputs)
        batchSize = math.ceil(netInCt / numProcesses)
        for i in range(numProcesses):
            batchStartIdx = (i * batchSize)  # inclusive idx
            batchEndIdx = ((i + 1) * batchSize) if (((i + 1) * batchSize) < netInCt) else netInCt  # exclusive idx
            netBatch = copy.deepcopy(master_in.network_inputs[batchStartIdx:batchEndIdx])
            batches.append(netBatch)
        po = multiprocessing.Pool(numProcesses)
        batchOut = po.map_async(trainNetBatch, batches)#.get()
        # get will start the processes and execute them
        print('abc')
        po.wait()
        print('zyx')
        po.terminate()
        # pickle master output
        print('All done, pickling')
        return 0


def trainNetBatch(netBatch):
    for netIn in netBatch:
        print('!!!starting ', batchStartIdx)
        x = np.array(netIn.table)
        y = np.array(netIn.labels)
        from sklearn.model_selection import train_test_split
        x_train, x_test, y_train, y_test = train_test_split(x, y, test_size = 0.2, random_state = 0)
        ann = getNetModel()
        print('!!!fitting ', batchStartIdx)
        # train network
        ann.fit(x_train, y_train, batch_size = 32, epochs = 100, verbose = 0)
        print('!donefit ', batchStartIdx)
```
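The batching arithmetic in trainNetworks can be isolated into a small pure function for checking on its own. This is a sketch; split_into_batches is a hypothetical name. It uses the same inclusive start and exclusive, clamped end as the original loop (min() replaces the conditional expression).

```python
import math


def split_into_batches(items, num_processes):
    """Split items into num_processes contiguous slices of size ceil(n / num_processes).

    Start index is inclusive; end index is exclusive and clamped to len(items),
    matching the index arithmetic in trainNetworks.
    """
    n = len(items)
    batch_size = math.ceil(n / num_processes)
    batches = []
    for i in range(num_processes):
        start = i * batch_size
        end = min((i + 1) * batch_size, n)
        batches.append(items[start:end])
    return batches


print(split_into_batches(list(range(10)), 3))  # [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]
print(split_into_batches(list(range(5)), 4))   # [[0, 1], [2, 3], [4], []]
```

With this scheme some trailing batches can be empty when the item count is small relative to the process count, so a worker may receive nothing to train. The deepcopy is omitted here because each batch is pickled when it is sent to a pool worker, so every process receives its own copy anyway.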
The instance's CPU and memory usage never exceed ~30%, but execution hangs.
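One commonly suggested workaround when TensorFlow and multiprocessing hang together is to start workers with the "spawn" start method and import TensorFlow inside the worker function rather than at module level, so that no TensorFlow state is inherited from a forked parent. This is an assumption about the cause, not a confirmed diagnosis. The sketch below keeps it dependency-free: train_batch and train_all are hypothetical names, and a placeholder computation stands in for building and fitting a model.

```python
import multiprocessing


def train_batch(batch):
    """Hypothetical worker. Heavy ML libraries would be imported here, inside the child.

    With the "spawn" start method each child runs a fresh interpreter, so
    anything imported in this function is initialized per process.
    """
    # Assumption: the real code would do `import tensorflow as tf` here and
    # build/fit a model per item; a placeholder keeps the sketch runnable.
    results = []
    for item in batch:
        results.append(item * 2)  # placeholder for model construction + fit()
    return results


def train_all(batches, num_processes=2):
    ctx = multiprocessing.get_context("spawn")  # fresh interpreter per worker
    with ctx.Pool(num_processes) as pool:
        return pool.map(train_batch, batches)


if __name__ == "__main__":
    # The guard matters with "spawn": children re-import this module.
    print(train_all([[1, 2], [3]]))
```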
In a few cases, training has run for a few models, and then a syntax error from BEFORE the fit() call halted execution. How could the code get past that error to call fit() in the first place?
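For context on how errors surface with map_async: it returns an AsyncResult, and per the multiprocessing documentation an exception raised inside a worker is re-raised in the parent only when get() is called. wait() just blocks until the result is ready, and successful() reports whether the call completed without raising. Note that wait() is a method of the AsyncResult, not of the Pool object. The sketch below uses a hypothetical flaky_worker in place of trainNetBatch.

```python
import multiprocessing


def flaky_worker(x):
    """Hypothetical worker that fails on one input, standing in for a bug in trainNetBatch."""
    if x == 2:
        raise ValueError(f"bad input {x}")
    return x * x


def run_async(items, num_processes=2):
    with multiprocessing.Pool(num_processes) as pool:
        async_result = pool.map_async(flaky_worker, items)
        async_result.wait()  # blocks until the result is ready; does not raise
        print("successful:", async_result.successful())
        return async_result.get()  # re-raises a worker's exception here, in the parent


if __name__ == "__main__":
    print(run_async([1, 3]))
    try:
        run_async([1, 2, 3])
    except ValueError as exc:
        print("worker failed:", exc)
```

Because other tasks keep running in their own processes, some workers can reach fit() while another has already failed; the failure only becomes visible in the parent when get() is called.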