I am implementing a triplet network in PyTorch where the three instances (sub-networks) share the same weights. Since the weights are shared, I implemented it as a single-instance network that is called three times to produce the anchor, positive, and negative embeddings. The embeddings are learned by optimizing the triplet loss. Here is a small snippet for illustration:
from dependencies import *

model = SingleSubNet()  # represents each instance in the triplet net (weights are shared)
for epoch in range(epochs):
    for anch, pos, neg in train_loader:  # each batch is an (anchor, positive, negative) triple
        optimizer.zero_grad()
        # one shared network, three forward passes
        fa, fp, fn = model(anch), model(pos), model(neg)
        loss = triplet_loss(fa, fp, fn)
        loss.backward()
        optimizer.step()
        # Do more stuff ...
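For concreteness, here is a minimal self-contained version of the same loop. The nn.Linear stand-in for SingleSubNet, the dimensions, and the random tensors are made up just so the snippet runs end to end:

import torch
import torch.nn as nn

# Toy stand-in for SingleSubNet; any embedding network behaves the same way.
model = nn.Linear(8, 4)
triplet_loss = nn.TripletMarginLoss(margin=1.0)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

# One fake batch of 16 anchors, positives, and negatives.
anch, pos, neg = torch.randn(16, 8), torch.randn(16, 8), torch.randn(16, 8)

optimizer.zero_grad()
fa, fp, fn = model(anch), model(pos), model(neg)  # three passes through one set of weights
loss = triplet_loss(fa, fp, fn)
loss.backward()
print(model.weight.grad)  # the single accumulated gradient on the shared weight
optimizer.step()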
My complete code works as expected. However, I do not understand what loss.backward() computes the gradient of in this case. I am confused because the loss has three gradients, one per embedding, in each learning step (the gradient formulas are here). I assume the gradients are summed before optimizer.step() is performed. But then it looks from the equations that if the gradients are summed, they cancel each other out and yield a zero update term. Of course, this cannot be true, since the network does learn meaningful embeddings in the end.
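To spell out the cancellation I mean, assuming the standard squared-Euclidean triplet loss in its active region (this is my reading of the linked formulas, not a quote from them):

$$L = \lVert f_a - f_p \rVert^2 - \lVert f_a - f_n \rVert^2 + \alpha$$
$$\frac{\partial L}{\partial f_a} = 2(f_n - f_p), \qquad \frac{\partial L}{\partial f_p} = 2(f_p - f_a), \qquad \frac{\partial L}{\partial f_n} = 2(f_a - f_n)$$

The three right-hand sides sum to zero, which is why it seems to me the summed update should vanish.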
Thanks in advance