
I'm trying to access or save the training log so that I can plot the losses and other metrics displayed while training my torchvision Mask R-CNN model.

model = torchvision.models.detection.maskrcnn_resnet50_fpn(pretrained=True)

Epoch: [0]  [0/4]  eta: 0:01:11  lr: 0.000334  loss: 1.4949 (1.4949)  loss_classifier: 0.4294 (0.4294)  loss_box_reg: 0.3279 (0.3279)  loss_mask: 0.5124 (0.5124)  loss_objectness: 0.0564 (0.0564)  loss_rpn_box_reg: 0.1689 (0.1689)  time: 17.8587  data: 10.8542  max mem: 5687

As far as I know, the metrics are produced by the metric_logger (engine.py), and I should be able to save them by adding a custom_callback arg to my model.train():

model.train(custom_callbacks = callbacks)
for epoch in range(num_epochs):
    print ('+++++++++++++++++++++++++++')
    # train for one epoch, printing every 10 iterations
    train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq=10)

At this point nothing works, except that a TensorBoard page shows up with no metrics to display. I'd be happy just appending the losses to a list.

Thanks

Gil
  • Why do you think that you "_should be able to save them by adding a custom_callback arg to my model.train()_"? – Berriel Jun 06 '20 at 21:32
  • Hello Berriel: I've read that passing the argument *custom_callback* to the model.train() function makes PyTorch return the values of the different losses. Unfortunately, this method does not work at all – Gil Jun 06 '20 at 21:38
  • Are you using any kind of framework on top of PyTorch? [`.train()`](https://stackoverflow.com/questions/51433378/what-does-model-train-do-in-pytorch) is supposed to accept `mode:bool` only. Trying to figure out where you read that. It is really hard to provide any guidance, since the relevant code is hidden inside your custom `train_one_epoch`. – Berriel Jun 06 '20 at 22:02

1 Answer


I was facing the same problem and used monkey patching to override the train_one_epoch function so that it returns the metric_logger instance:

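from engine import train_one_epoch, evaluate
import math

# Redefine train_one_epoch locally; this shadows the version imported above.
def train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq):
    model.train()
    ...  # rest of the original engine.py body, unchanged
    return metric_logger  # hand the logger back to the caller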
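
Once the function returns metric_logger, the per-epoch values can be appended to plain lists. Below is a minimal sketch, assuming the logger exposes a `meters` dict whose entries have a `global_avg` attribute, as the MetricLogger in torchvision's detection reference `utils.py` appears to. The helper `record_epoch` and the two stand-in classes are hypothetical names used only so the snippet runs without torchvision.

```python
from collections import defaultdict


def record_epoch(history, metric_logger):
    """Append each meter's running average for this epoch to history."""
    # Assumption: metric_logger.meters maps a metric name to an object
    # that exposes a global_avg attribute.
    for name, meter in metric_logger.meters.items():
        history[name].append(meter.global_avg)
    return history


# Hypothetical stand-ins so the sketch runs on its own; in real use, pass
# the metric_logger returned by the patched train_one_epoch instead.
class _FakeMeter:
    def __init__(self, global_avg):
        self.global_avg = global_avg


class _FakeLogger:
    def __init__(self, **averages):
        self.meters = {k: _FakeMeter(v) for k, v in averages.items()}


history = defaultdict(list)
record_epoch(history, _FakeLogger(loss=1.49, loss_mask=0.51))  # "epoch 0"
record_epoch(history, _FakeLogger(loss=1.10, loss_mask=0.40))  # "epoch 1"
```

In the training loop this would amount to calling `record_epoch(history, train_one_epoch(...))` once per epoch; `history["loss"]` is then an ordinary list that can be plotted or saved.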
Shashi Gharti