
Currently, I load a pretrained torchvision model using the following code:

import torchvision
torchvision.models.resnet101(pretrained=True)

However, I'd like to pass the model name as a string parameter and then load the pretrained model from that string. Pseudo-code that would do so looks something like:

model_name = 'resnet101'
torchvision.models.get(model_name)(pretrained=True)

Is there a way to accomplish this in a rather simple manner?

Shir
jakes

2 Answers


You can use torch.hub, which takes the model name as a string:

import torch
model_str = 'resnet50'
model = torch.hub.load('pytorch/vision', model_str, pretrained=True)

The names of all available models can be listed with:

torch.hub.list('pytorch/vision', force_reload=True)

Output (the exact list depends on the version of pytorch/vision fetched at the time):

['alexnet',
 'deeplabv3_mobilenet_v3_large',
 'deeplabv3_resnet101',
 'deeplabv3_resnet50',
 'densenet121',
 'densenet161',
 'densenet169',
 'densenet201',
 'fcn_resnet101',
 'fcn_resnet50',
 'googlenet',
 'inception_v3',
 'lraspp_mobilenet_v3_large',
 'mnasnet0_5',
 'mnasnet0_75',
 'mnasnet1_0',
 'mnasnet1_3',
 'mobilenet_v2',
 'mobilenet_v3_large',
 'mobilenet_v3_small',
 'resnet101',
 'resnet152',
 'resnet18',
 'resnet34',
 'resnet50',
 'resnext101_32x8d',
 'resnext50_32x4d',
 'shufflenet_v2_x0_5',
 'shufflenet_v2_x1_0',
 'squeezenet1_0',
 'squeezenet1_1',
 'vgg11',
 'vgg11_bn',
 'vgg13',
 'vgg13_bn',
 'vgg16',
 'vgg16_bn',
 'vgg19',
 'vgg19_bn',
 'wide_resnet101_2',
 'wide_resnet50_2']
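Because torch.hub.list returns plain strings, it can be used to validate a user-supplied name before calling torch.hub.load. Below is a minimal sketch; resolve_model_name is a hypothetical helper, and a hard-coded subset of the listing above stands in for a live torch.hub.list call so the example runs offline. It uses the standard library's difflib to suggest close matches for typos.

```python
import difflib


def resolve_model_name(name, available):
    """Return name unchanged if it is a known entrypoint; otherwise raise
    ValueError with up to three close-match suggestions (hypothetical helper)."""
    if name in available:
        return name
    # get_close_matches ranks candidates by SequenceMatcher similarity
    # and drops anything below the default cutoff of 0.6.
    suggestions = difflib.get_close_matches(name, available, n=3)
    hint = f" Did you mean: {', '.join(suggestions)}?" if suggestions else ""
    raise ValueError(f"Unknown model {name!r}.{hint}")


# Hard-coded subset of the hub listing, standing in for torch.hub.list('pytorch/vision').
available = ["resnet18", "resnet34", "resnet50", "resnet101", "vgg16"]

print(resolve_model_name("resnet50", available))  # prints resnet50
try:
    resolve_model_name("resnet5", available)
except ValueError as err:
    print(err)
```

With network access, the same helper would be combined with the calls from this answer, e.g. torch.hub.load('pytorch/vision', resolve_model_name(model_str, torch.hub.list('pytorch/vision')), pretrained=True).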
Shir

You can use getattr to look up the model constructor on torchvision.models by name:

import torchvision
model = getattr(torchvision.models, 'resnet101')(pretrained=True)
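A plain getattr raises a bare AttributeError for a misspelled name and will also happily return non-model attributes. A slightly more defensive variant, as a minimal sketch: get_model_constructor is a hypothetical helper, and a types.SimpleNamespace stands in for torchvision.models so the block runs without torchvision installed.

```python
import types


def get_model_constructor(models_module, name):
    """Look up a model builder by its string name on a module such as
    torchvision.models (hypothetical helper).

    Raises ValueError listing candidate names instead of a bare AttributeError.
    """
    constructor = getattr(models_module, name, None)
    if constructor is None or name.startswith("_") or not callable(constructor):
        # Candidate builders: public, lowercase, callable attributes.
        # This skips classes such as ResNet and non-callable submodules.
        candidates = sorted(
            attr
            for attr in dir(models_module)
            if not attr.startswith("_")
            and attr.islower()
            and callable(getattr(models_module, attr))
        )
        raise ValueError(f"Unknown model {name!r}; candidates: {candidates}")
    return constructor


# Stand-in for torchvision.models so this sketch runs without torchvision.
fake_models = types.SimpleNamespace(
    resnet18=lambda pretrained=False: f"resnet18(pretrained={pretrained})",
    resnet50=lambda pretrained=False: f"resnet50(pretrained={pretrained})",
)

print(get_model_constructor(fake_models, "resnet50")(pretrained=True))
```

With torchvision installed, the call would be get_model_constructor(torchvision.models, 'resnet101')(pretrained=True). Note that torchvision 0.13 and later deprecate pretrained=True in favor of a weights= argument, and 0.14 and later also provide a built-in torchvision.models.get_model(name, weights=...) for this lookup, so check which version you have installed.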

Umang Gupta