I tried implementing a custom data loader that will make a web request and will return a sample. My purpose of the program is to see if this idea would be faster than the original data loader. My web server code is run with
srun -n24 --mem = 12g python web.py
Which will then create 24 "workers" that run in the cluster. Then each worker will write its portname to a file to make itself known to the data loader that he exist. So, when the dataloader is called in the training loop. The data loader selects a random server from the files and send them a web request with an index. The web server will then load the sample and do augmentation and return via http response. From my view, i thought it would be faster than the original data loader as, each data loader worker would send a request to the webserver and get a sample. Thus, distributing data to different server so they load the images faster.
However, when i do a comparison with original data using COCO dataset. The original data loader takes 743.820 sec to complete loading an epoch while my custom data loader takes 1503.26 sec to complete. I couldn't figure out which part of my code is taking a long time, so i would like to ask for assistance. Please if my explaination is bad/not great please let me know. Any help is appreciated. Thankyou.
The following the code for starting webserver:
class PytorchDataHandler(BaseHTTPRequestHandler):
def do_GET(self):
self.send_response(200)
self.end_headers()
get_param = self.path
get_param = parse_qs(urlparse(get_param).query)
batch_list = [[],[]]
c_batches = []
index = get_param['index']
if index :
for data in index:
result = imagenet_data[int(data)]
batch_list[0].append(result[0])
batch_list[1].append(result[1])
c_batches.append(batch_list)
torch.save(batch_list, self.wfile)
else:
write_log('Empty Parameter')
def main():
sock = socket.socket(socket.AF_INET,socket.SOCK_DGRAM)
hostname = socket.gethostname()
n_hostname = hostname.split(".")
# Bind to random port
sock.bind(('0.0.0.0', 0))
# Get Port Number
PORT = int(sock.getsockname()[1])
current_dir = os.getcwd()
create_dir = os.path.join(current_dir, r'worker_file')
#filename = create_dir + '/' + str(n_hostname[0]) + '.cvl-tengig:' + str(PORT)
filename = create_dir + '/' + str(n_hostname[0]) + ':' + str(PORT)
os.makedirs(os.path.dirname(filename), exist_ok=True)
open_file = open(filename, 'w')
open_file.write(str(n_hostname[0]) + ':' + str(PORT))
open_file.close()
try :
SERVER = HTTPServer(('', PORT), PytorchDataHandler)
SERVER.serve_forever()
except KeyboardInterrupt:
print('Shutting down server, ^C')
os.remove(filename)
SERVER.socket.close()
if __name__ == '__main__':
main()
The code for custom data loader:
class DistData(Dataset):
def __init__(self, data, transform = None):
self.data = data
# Get file path
current_dir = os.getcwd()
create_dir = os.path.join(current_dir, r'worker_file')
# Get all item in file
self.arr = os.listdir(create_dir)
self.selected = []
def __getitem__(self, index):
# Select a random server
random_server = random.choice(self.arr)
# Remove selected server from the server list
self.arr.remove(random_server)
# Append selected server to the selected list
self.selected.append(random_server)
return self.post_request(index, random_server)
def __len__(self):
return len(self.data)
def post_request(self, index, random_server):
params = {'index': index}
url = 'http://' + random_server + '/get'
r = requests.get(url , params = params)
print("Response Time : {:<10} , worker : {:<10} ".format(r.elapsed.total_seconds(), torch.utils.data.get_worker_info().id ))
# Remove server from selected once there's a response
self.selected.remove(random_server)
# Add back to main server list after response
self.arr.append(random_server)
buffer = io.BytesIO(r.content)
response = torch.load(buffer)
return response
def train(net, device, trainloader, criterion, optimizer):
for epoch in range(2): # loop over the dataset multiple times
running_loss = 0.0
print('Epoch : {}'.format(epoch + 1))
print('----------------------------')
start_time = time.time()
total_time = 0
for i, data in enumerate(trainloader, 0):
inputs, labels = data
print("Train: Time taken to load batch {} is {}".format(i+1,time.time() - start_time))
total_time += time.time() - start_time
start_time = time.time()
print('Epoch : {} , Total Time Taken : {}'.format(epoch + 1, total_time))
print('Finished Training')
imagenet_data =torchvision.datasets.CocoCaptions('/db/shared/detection+classification/coco/train2017/' ,
'/db/shared/detection+classification/coco/annotations/captions_train2017.json')
training_set = DistData(imagenet_data)
trainloader = DataLoader(training_set, sampler = BatchSampler(RandomSampler(training_set), batch_size = 24, drop_last = False),
num_workers = 4)
train(trainloader)