Finding the maximum batch size is a cumbersome and often time-consuming process. There are approaches that offer approximate estimates, such as (How to calculate optimal batch size), but I came up with a method to iteratively find the largest batch size my GPU can handle without running out of memory.
My current approach runs a subprocess that calls my training script experiment.py with a decreasing batch size until a run succeeds, and that batch size is reported as the maximum. If you want to use this method and you have a large dataset whose preprocessing takes a while, I recommend working on a small subset of the dataset and skipping the preprocessing steps, so the loop over multiple runs finishes faster (I sketch this below, after the experiment.py snippet). Here is the snippet from main.py, where I chose 500 as a batch size that will definitely fail:
import subprocess

def run_experiment(batch_size):
    """Run experiment.py with the given batch size and report whether it succeeded."""
    try:
        subprocess.run(
            ["python3", "experiment.py", "--batch_size", str(batch_size)],
            check=True, text=True, capture_output=True,
        )
        print(f"Batch size {batch_size} succeeded")
        return True
    except subprocess.CalledProcessError as e:
        if "CUDA out of memory" in e.stderr:
            print(f"Batch size {batch_size} failed due to CUDA out of memory")
        else:
            print(f"An unexpected error occurred: {e.stderr}")
        return False

def find_max_batch_size(start, step):
    """Step down from `start` in increments of `step`; return the first batch size that runs."""
    for i in range(start, 0, -step):
        if run_experiment(i):
            return i
    return 0  # no batch size in this range succeeded

# Refine the search with progressively smaller step sizes
batch_size = find_max_batch_size(500, 100)
batch_size = find_max_batch_size(batch_size + 100, 10)
batch_size = find_max_batch_size(batch_size + 10, 1)
print("Maximum batch size:", batch_size)
In experiment.py, I've added command-line argument parsing for the batch size:
import argparse
parser = argparse.ArgumentParser()
parser.add_argument("--batch_size", type=int, required=True, help="Batch size for the experiment")
args = parser.parse_args()
batch_size = args.batch_size
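To show how the pieces fit together, here is a rough sketch of how the parsed batch size and the dataset-subset trick mentioned above could look inside experiment.py. This assumes PyTorch; the stand-in dataset, its tensor shapes, and the 1000-sample subset size are made up for illustration and would be replaced by the real data pipeline:

import torch
from torch.utils.data import DataLoader, Subset, TensorDataset

# Stand-in dataset with made-up shapes; in the real script this is the actual
# Dataset, so the per-batch memory footprint matches the real experiment.
full_dataset = TensorDataset(
    torch.randn(2000, 3, 32, 32),      # fake inputs
    torch.randint(0, 10, (2000,)),     # fake labels
)

# Train on a small subset so each trial run launched by main.py finishes quickly;
# the batch size (and therefore the peak GPU memory per step) is unchanged.
small_dataset = Subset(full_dataset, range(1000))

train_loader = DataLoader(small_dataset, batch_size=batch_size, shuffle=True)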
While this works to a degree, it is not very efficient. I'm wondering whether there is a better way to dynamically calculate or estimate the maximum batch size a GPU can handle, based on its available memory.
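For context, the available memory I'm referring to is what the driver reports for the device; with PyTorch it can be queried roughly like this (just an illustration of the quantity I'd like to base the estimate on):

import torch

if torch.cuda.is_available():
    # Free and total memory of the current CUDA device, in bytes.
    free_bytes, total_bytes = torch.cuda.mem_get_info()
    print(f"Free: {free_bytes / 1e9:.2f} GB out of {total_bytes / 1e9:.2f} GB")

What I don't know is how to turn that number into a reliable batch-size estimate, since the per-sample memory cost depends on the model, its activations, and the optimizer state.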
Any insights or suggestions?