
I am using PyTorch for image classification. After training the model, I want to save it.

I don't understand what is meant by

```python
the_model = TheModelClass(*args, **kwargs)
```

This line of code is given on the PyTorch website (https://pytorch.org/docs/master/notes/serialization.html).

melpomene
lodepas
  • This is a generic way to create a new instance of the class `TheModelClass` with some positional arguments (`args`) and named arguments (`kwargs`). It is not specific to PyTorch but to Python. See [this SO post](https://stackoverflow.com/questions/287085/what-do-args-and-kwargs-mean). – Léopold Houdin Sep 22 '18 at 09:28
  • Possible duplicate of [Understanding kwargs in Python](https://stackoverflow.com/questions/1769403/understanding-kwargs-in-python) – Salih Karagoz Sep 22 '18 at 11:02
  • Possible duplicate of [What does \*\* (double star/asterisk) and \* (star/asterisk) do for parameters?](https://stackoverflow.com/questions/36901/what-does-double-star-asterisk-and-star-asterisk-do-for-parameters) – Daniel Pryden Sep 22 '18 at 11:05
  • You could read this https://ke-zhang-rd.github.io/krystal/starred.html – Ke Zhang Jan 15 '21 at 12:21
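As the comments point out, `*args` and `**kwargs` are plain Python argument unpacking. Below is a minimal sketch; `TheModelClass` and its parameters (`in_features`, `out_features`, `dropout`, `bias`) are hypothetical stand-ins for whatever your own model's constructor accepts.

```python
class TheModelClass:
    """Hypothetical stand-in for a user-defined model class."""

    def __init__(self, in_features, out_features, dropout=0.0, bias=True):
        self.in_features = in_features
        self.out_features = out_features
        self.dropout = dropout
        self.bias = bias


# Positional arguments collected in a tuple, keyword arguments in a dict.
args = (784, 10)
kwargs = {"dropout": 0.5}

# `*args` unpacks the tuple into positional parameters and `**kwargs`
# unpacks the dict into keyword parameters, so this call is equivalent to
# TheModelClass(784, 10, dropout=0.5).
the_model = TheModelClass(*args, **kwargs)

print(the_model.in_features, the_model.out_features, the_model.dropout, the_model.bias)
# 784 10 0.5 True
```

In the PyTorch docs, `args` and `kwargs` are simply placeholders for the arguments your own model class needs.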

1 Answer


The line `the_model = TheModelClass(*args, **kwargs)` means you must first create an instance of your model class, constructed with the same arguments as the model you trained. You can then use that object to load the parameters saved on disk (the `state_dict`). For example:

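```python
import torch

# `data`, `params`, `OwnGCN` and `PATH` come from my own training code.
in_feats = data.x.shape[1]
n_hidden = params["n_hidden"]
n_classes = 2
# Rebuild the model with the same constructor arguments used for training.
best_model = OwnGCN(in_c=in_feats, hid_c=n_hidden, out_c=n_classes)
best_model.load_state_dict(torch.load(PATH))  # copy saved weights into it
```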
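The same pattern end to end, as a minimal sketch: save the trained model's `state_dict`, then rebuild the model from its class with the same arguments and load the parameters back. `TheModelClass`, its layer sizes and the temporary `PATH` are hypothetical; `torch.save`, `torch.load`, `state_dict()` and `load_state_dict()` are the standard PyTorch calls.

```python
import os
import tempfile

import torch
from torch import nn


class TheModelClass(nn.Module):
    """Hypothetical small classifier standing in for the real model."""

    def __init__(self, in_features, n_classes, hidden=32):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(in_features, hidden),
            nn.ReLU(),
            nn.Linear(hidden, n_classes),
        )

    def forward(self, x):
        return self.net(x)


args = (16, 3)          # positional constructor arguments
kwargs = {"hidden": 8}  # keyword constructor arguments

# After training: save only the learned parameters (the state_dict).
trained = TheModelClass(*args, **kwargs)
PATH = os.path.join(tempfile.mkdtemp(), "model.pt")
torch.save(trained.state_dict(), PATH)

# Later: the state_dict holds tensors, not the architecture, so first
# rebuild the model with the same constructor arguments...
the_model = TheModelClass(*args, **kwargs)
# ...then copy the saved parameters into it.
the_model.load_state_dict(torch.load(PATH))
the_model.eval()  # put dropout/batch-norm layers (if any) in inference mode
```

Because the saved file contains only parameters, the constructor call (and therefore `*args` / `**kwargs`) has to be repeated before loading.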
lvjiujin