
I'm working on training a CLIP model. Here's the source code of the model: https://github.com/openai/CLIP/blob/main/clip/model.py

Basically, the CLIP object is constructed like this:

class CLIP(nn.Module):
    ...
    def encode_image(self, image):
        return self.visual(image.type(self.dtype))

    def encode_text(self, text):
        x = ...
        ...
        return x

    def forward(self, image, text):
        image_features = self.encode_image(image)
        text_features = self.encode_text(text)
        ...
        return logits_per_image, logits_per_text

The forward method expects a pair of image and text. Since I want to repurpose CLIP for another task (text-text pairs), I'm not using CLIP's forward; instead I call other methods defined inside CLIP. My training code looks like this:

for k in range(epochs):
    for batch in dataloader:
        x, y = batch
        y1 = model.encode_text(x[first_text_part])
        y2 = model.encode_text(x[second_text_part])
        <calculate loss, backward, step, etc>
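
To be concrete, here's a fuller sketch of the same loop with the elided steps filled in by a placeholder contrastive loss (the exact loss I use is omitted; this only shows the structure, and logit_scale is CLIP's own learned temperature parameter):

import torch
import torch.nn.functional as F

optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)

for k in range(epochs):
    for batch in dataloader:
        x, y = batch
        # encode both text parts and L2-normalize the embeddings
        y1 = F.normalize(model.encode_text(x[first_text_part]), dim=-1)
        y2 = F.normalize(model.encode_text(x[second_text_part]), dim=-1)

        # placeholder contrastive loss over the batch, CLIP-style:
        # matching pairs sit on the diagonal of the similarity matrix
        logits = y1 @ y2.t() * model.logit_scale.exp()
        labels = torch.arange(logits.size(0), device=logits.device)
        loss = F.cross_entropy(logits, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()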

The problem is, after one epoch all the gradients turn out to be NaN, even though the loss itself is not NaN.
My suspicion is that PyTorch is only able to propagate gradients through the forward method.
Some sources say that forward is not that special (https://discuss.pytorch.org/t/must-use-forward-function-in-nn-module/50943/3), but other sources say that PyTorch code must go through forward (https://stackoverflow.com/a/58660175/12082666).
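
To narrow down where the NaN first appears, I plan to wrap one training step in autograd's anomaly detection (just a debugging sketch; compute_loss is a hypothetical stand-in for my actual loss):

import torch

with torch.autograd.detect_anomaly():
    y1 = model.encode_text(x[first_text_part])
    y2 = model.encode_text(x[second_text_part])
    loss = compute_loss(y1, y2)  # hypothetical stand-in for my actual loss
    loss.backward()              # raises an error at the op that produced the NaN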

The question is: can we train a PyTorch network without using the forward method?

Vinson Ciawandy

1 Answer

forward() in PyTorch is nothing special. It just builds the computation graph of your network when called. Backpropagation doesn't rely on forward() itself, because the gradients are propagated through that graph, which autograd records from the tensor operations regardless of which method executes them.

The only difference is that in the PyTorch source, calling the module (its __call__ method) runs forward() together with all the hooks registered on the nn.Module. Calling another method such as encode_text directly simply skips those hooks, not the autograd graph.
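
A minimal sketch to verify this, using a toy module (not CLIP) with an encode() method instead of forward():

import torch
import torch.nn as nn

class Toy(nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(4, 4)

    def encode(self, x):  # deliberately not forward()
        return self.proj(x)

model = Toy()
a = model.encode(torch.randn(2, 4))  # the graph is recorded here anyway
b = model.encode(torch.randn(2, 4))
loss = (a - b).pow(2).mean()
loss.backward()
print(model.proj.weight.grad)  # populated: gradients flowed without forward()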

Nivesh Gadipudi