I'm trying to implement a triplet loss for a NetVLAD layer, using three inputs built from my dataloader as follows:
- a - batch with minimal augmentation
- p - same batch with more augmentations
- n - different batch from the same dataloader
However, when I try to obtain the VLAD embeddings from my model, I get an out-of-memory error after computing the first VLAD embedding. Is there a more memory-efficient way to run this code?
EDIT: Is there perhaps a better way to load the second batch I need for the negatives?
for i, (inputs, poses) in enumerate(train_loader):
    inputs = inputs.to(device)
    # poses = poses.to(device)
    print(inputs.shape)  # anchors
    positives = positive_tfs(inputs)
    positives = positives.to(device)
    print(positives.shape)
    negatives, _ = next(iter(train_loader))  # different random batch thanks to shuffle=True
    negatives = negatives.to(device)
    print(negatives.shape)
    # Zero the parameter gradient
    optimizer.zero_grad()
    # FORWARD
    a_vlad, pos_out, ori_out = model(inputs)
    print(a_vlad.shape)
    p_vlad = model(positives, get_pose=False)
    print(p_vlad.shape)
    n_vlad = model(negatives, get_pose=False)
    print(n_vlad.shape)
    # COMPUTE LOSS
    loss = triplet_loss(a_vlad, p_vlad, n_vlad)
    print(loss)
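
To make the EDIT concrete, here is a rough sketch of one possible alternative to calling next(iter(train_loader)) on every step, which builds a brand-new loader iterator each time. The idea is to walk two independent passes over the same loader in lockstep with zip(). ShuffledBatches is a hypothetical stand-in for a DataLoader with shuffle=True, used only so the sketch runs without PyTorch. With a real loader the loop header would presumably become `for (inputs, poses), (negatives, _) in zip(train_loader, train_loader):`, assuming the loader can serve two iterators at once (the default persistent_workers=False).

```python
import random
from typing import Iterator, List


class ShuffledBatches:
    """Hypothetical stand-in for a DataLoader with shuffle=True.

    Every pass over the object draws a fresh permutation of the items.
    """

    def __init__(self, items: List[int], batch_size: int, seed: int = 0) -> None:
        self.items = list(items)
        self.batch_size = batch_size
        self.rng = random.Random(seed)

    def __iter__(self) -> Iterator[List[int]]:
        order = self.items[:]
        self.rng.shuffle(order)  # new order for each pass
        for start in range(0, len(order), self.batch_size):
            yield order[start:start + self.batch_size]


loader = ShuffledBatches(list(range(8)), batch_size=4)

# zip() calls iter(loader) once per argument, so the anchors and the
# negatives come from two separate passes that are created only once.
pairs = list(zip(loader, loader))

for anchors, negatives in pairs:
    print("anchors:", anchors, "negatives:", negatives)
```

Two caveats with this sketch: a negative batch can still contain some of the same images as the anchor batch by chance, and it only changes how the negatives are loaded. All three forward passes still keep their activations for the backward pass, so on its own it may not resolve the out-of-memory error.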