To clarify the title, I currently have a machine learning program where I feed data through a PyTorch DataLoader object. Each data sample is around 335KB, and I have many of them, so the system RAM keeps hitting its limit when I try to load all the samples at once.
The idea that I had was to pre-make the input features, save them in smaller chunks, and load them sequentially into the DataLoader object.
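For context, the loading side of that plan would look roughly like the sketch below. `ChunkDataset` and the `feature_*.pt` naming are my own assumptions for illustration, not an existing API:

```python
import glob
import os

import torch
from torch.utils.data import Dataset


class ChunkDataset(Dataset):
    """Loads one pre-saved feature chunk per __getitem__, so only the
    chunks the DataLoader is currently using need to be in RAM."""

    def __init__(self, savepath):
        # Hypothetical naming scheme: one saved object per .pt file.
        self.files = sorted(glob.glob(os.path.join(savepath, 'feature_*.pt')))

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        # Each chunk is read from disk on demand instead of kept resident.
        return torch.load(self.files[idx])
```

The point is that a `DataLoader` wrapped around this would only ever materialize the chunks it is currently batching.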
Right now the feature-creating and saving function I have looks like this (imports shown for completeness):

    import json
    import os

    import torch

    def create_and_save_data(files, savepath):
        for file in files:
            with open(file=file) as f:
                original_data = json.load(fp=f)
            features = create_features(original_data)
            chunked_features = split_into_chunks(features)
            for idx, feature in enumerate(chunked_features):
                save_filename = os.path.join(savepath, f'feature_{idx}.pt')
                torch.save(obj=feature, f=save_filename)
Processing one file takes around 20-30G of RAM. I expected that after processing one file, RAM usage would increase by that amount, drop back to its starting point, then rise again for the next file, repeating in a sawtooth pattern. However, I've noticed that the memory never returns to its original level and instead keeps accumulating an extra 20-30G for each file.
Is there a way to achieve what I want without using this much RAM? I'm a little confused because I thought this approach would work, and I'm wondering why the program doesn't release the memory it used for one file before moving on to the next. Thanks.
Edit
I've taken a look at the question How can I explicitly free memory in Python? and have tried explicitly calling gc.collect(). This brings the RAM down by only about 1G, rather than back to the original amount.
Running it this way actually gives me the following error:
Traceback (most recent call last):
File "./main.py", line 171, in <module>
main(args)
File "./main.py", line 55, in main
create_and_save_features(data_files=train_data_files, savepath=train_feature_dir, tokenizer=tokenizer)
File "/home/user/github/research/preprocess.py", line 88, in create_and_save_features
torch.save(obj=feature, f=filename)
File "/home/user/anaconda3/envs/userconda/lib/python3.8/site-packages/torch/serialization.py", line 373, in save
return
File "/home/user/anaconda3/envs/userconda/lib/python3.8/site-packages/torch/serialization.py", line 259, in __exit__
self.file_like.write_end_of_file()
RuntimeError: [enforce fail at inline_container.cc:274] . unexpected pos 2341598464 vs 2341598384