I am trying to understand how to use the PyTorch JIT (TorchScript). Here is my demo convolution block:
class convBlock(nn.Module):
    """Encoder stage: Conv1d -> BatchNorm1d -> ReLU -> MaxPool1d.

    The pooling layer is built with ``return_indices=True`` so that the
    argmax positions can later be handed to a ``MaxUnpool1d``.

    The convolution has a single input channel (``Conv1d(1, 64, ...)``);
    the input is presumably shaped (batch, 1, length) -- TODO confirm.
    """

    def __init__(self):
        super().__init__()
        # Submodules are registered in the original order (conv, batch,
        # relu, maxPool) so state_dict keys and parameter order are unchanged.
        self.conv = nn.Conv1d(
            in_channels=1,
            out_channels=64,
            kernel_size=3,
            stride=1,
            padding=1,
            bias=False,  # BatchNorm follows, so a conv bias would be redundant
        )
        self.batch = nn.BatchNorm1d(num_features=64)
        self.relu = nn.ReLU()
        self.maxPool = nn.MaxPool1d(
            kernel_size=3, stride=2, padding=1, return_indices=True
        )

    def forward(self, input_1D):
        """Run the stage and return ``(pooled, indices)`` from the max-pool."""
        activated = self.relu(self.batch(self.conv(input_1D)))
        pooled, indices = self.maxPool(activated)
        return pooled, indices
and here is my deconvolution block:
class deconvBlock(nn.Module):
    """Decoder stage: MaxUnpool1d -> ConvTranspose1d -> BatchNorm1d -> ReLU.

    Intended to undo ``convBlock``: ``forward`` receives the pooled tensor
    together with the indices returned by ``convBlock``'s ``MaxPool1d``.

    The ``MaxUnpool1d`` must use the same ``kernel_size``, ``stride`` and
    ``padding`` as the ``MaxPool1d`` that produced the indices. Otherwise
    the indices point past the end of the unpooled output and
    ``max_unpool1d`` raises a RuntimeError about an invalid max index.
    """

    def __init__(self):
        super(deconvBlock, self).__init__()
        self.deconv = nn.ConvTranspose1d(
            64, 32, kernel_size=3, stride=1, padding=1, bias=False
        )
        self.batchNorm = nn.BatchNorm1d(32)
        self.relu = nn.ReLU()
        # Mirror convBlock's MaxPool1d(kernel_size=3, stride=2, padding=1).
        # The previous stride=1, padding=0 gave an unpooled length of only
        # L_pooled + 2, much shorter than the pre-pool length the indices
        # refer to, so unpooling failed at runtime.
        # NOTE(review): the default unpooled length is 2 * L_pooled - 1.
        # That equals the pre-pool length only when it is odd. For even
        # lengths, self.unpool would need output_size (the pre-pool size),
        # which this forward signature does not receive.
        self.unpool = nn.MaxUnpool1d(kernel_size=3, stride=2, padding=1)

    def forward(self, input_1D, idmat):
        """Unpool with ``idmat``, then deconvolve, normalise and activate.

        Args:
            input_1D: pooled tensor; must have 64 channels to match
                ``ConvTranspose1d(64, ...)``.
            idmat: indices returned alongside it by ``MaxPool1d``.

        Returns:
            Tensor with 32 channels after BatchNorm1d and ReLU.
        """
        input_1D = self.unpool(input_1D, idmat)
        input_1D = self.deconv(input_1D)
        input_1D = self.batchNorm(input_1D)
        return self.relu(input_1D)
so my full model is:
class modelFull(nn.Module):
    """Encoder/decoder pair: ``convBlock`` followed by ``deconvBlock``.

    The max-pool indices produced by the encoder are passed straight to
    the decoder's unpooling layer.
    """

    def __init__(self):
        super().__init__()
        self.bll = convBlock()       # encoder stage
        self.deconv = deconvBlock()  # decoder stage

    def forward(self, x):
        """Encode ``x`` and immediately decode it using the pooling indices."""
        pooled, pool_indices = self.bll(x)
        return self.deconv(pooled, pool_indices)
Then I tried:
# Compile the model with TorchScript; torch.jit.script also compiles every
# submodule (bll, deconv and their layers) recursively.
torch.jit.script(modelFull())
but scripting fails with the following error:
RuntimeError Traceback (most recent call last)
Cell In[17], line 1
----> 1 torch.jit.script(modelFull())
File ~/anaconda3/envs/mpa/lib/python3.10/site-packages/torch/jit/_script.py:1284, in script(obj, optimize, _frames_up, _rcb, example_inputs)
1282 if isinstance(obj, torch.nn.Module):
1283 obj = call_prepare_scriptable_func(obj)
-> 1284 return torch.jit._recursive.create_script_module(
1285 obj, torch.jit._recursive.infer_methods_to_compile
1286 )
1288 if isinstance(obj, dict):
1289 return create_script_dict(obj)
File ~/anaconda3/envs/mpa/lib/python3.10/site-packages/torch/jit/_recursive.py:480, in create_script_module(nn_module, stubs_fn, share_types, is_tracing)
478 if not is_tracing:
479 AttributeTypeIsSupportedChecker().check(nn_module)
--> 480 return create_script_module_impl(nn_module, concrete_type, stubs_fn)
File ~/anaconda3/envs/mpa/lib/python3.10/site-packages/torch/jit/_recursive.py:542, in create_script_module_impl(nn_module, concrete_type, stubs_fn)
539 script_module._concrete_type = concrete_type
541 # Actually create the ScriptModule, initializing it with the function we just defined
--> 542 script_module = torch.jit.RecursiveScriptModule._construct(cpp_module, init_fn)
544 # Compile methods if necessary
545 if concrete_type not in concrete_type_store.methods_compiled:
...
394 property_defs = [p.def_ for p in property_stubs]
395 property_rcbs = [p.resolution_callback for p in property_stubs]
--> 397 concrete_type._create_methods_and_properties(property_defs, property_rcbs, method_defs, method_rcbs, method_defaults)
RuntimeError: Can't redefine method: forward on class: __torch__.torch.nn.modules.pooling.MaxUnpool1d (of Python compilation unit at: 0x5e93700)
UPDATE: After removing the unpooling layer, I get this error instead:
RuntimeError Traceback (most recent call last)
Cell In[3], line 1
----> 1 torch.jit.script(Spice_model2())
File ~/anaconda3/envs/mpa/lib/python3.10/site-packages/torch/jit/_script.py:1284, in script(obj, optimize, _frames_up, _rcb, example_inputs)
1282 if isinstance(obj, torch.nn.Module):
1283 obj = call_prepare_scriptable_func(obj)
-> 1284 return torch.jit._recursive.create_script_module(
1285 obj, torch.jit._recursive.infer_methods_to_compile
1286 )
1288 if isinstance(obj, dict):
1289 return create_script_dict(obj)
File ~/anaconda3/envs/mpa/lib/python3.10/site-packages/torch/jit/_recursive.py:480, in create_script_module(nn_module, stubs_fn, share_types, is_tracing)
478 if not is_tracing:
479 AttributeTypeIsSupportedChecker().check(nn_module)
--> 480 return create_script_module_impl(nn_module, concrete_type, stubs_fn)
File ~/anaconda3/envs/mpa/lib/python3.10/site-packages/torch/jit/_recursive.py:542, in create_script_module_impl(nn_module, concrete_type, stubs_fn)
539 script_module._concrete_type = concrete_type
541 # Actually create the ScriptModule, initializing it with the function we just defined
--> 542 script_module = torch.jit.RecursiveScriptModule._construct(cpp_module, init_fn)
544 # Compile methods if necessary
545 if concrete_type not in concrete_type_store.methods_compiled:
File ~/anaconda3/envs/mpa/lib/python3.10/site-packages/torch/jit/_script.py:614, in RecursiveScriptModule._construct(cpp_module, init_fn)
601 """
602 Construct a RecursiveScriptModule that's ready for use. PyTorch
603 code should use this to construct a RecursiveScriptModule instead
(...)
611 init_fn: Lambda that initializes the RecursiveScriptModule passed to it.
612 """
613 script_module = RecursiveScriptModule(cpp_module)
--> 614 init_fn(script_module)
616 # Finalize the ScriptModule: replace the nn.Module state with our
617 # custom implementations and flip the _initializing bit.
618 RecursiveScriptModule._finalize_scriptmodule(script_module)
File ~/anaconda3/envs/mpa/lib/python3.10/site-packages/torch/jit/_recursive.py:520, in create_script_module_impl..init_fn(script_module)
517 scripted = orig_value
518 else:
519 # always reuse the provided stubs_fn to infer the methods to compile
--> 520 scripted = create_script_module_impl(orig_value, sub_concrete_type, stubs_fn)
522 cpp_module.setattr(name, scripted)
523 script_module._modules[name] = scripted
File ~/anaconda3/envs/mpa/lib/python3.10/site-packages/torch/jit/_recursive.py:542, in create_script_module_impl(nn_module, concrete_type, stubs_fn)
539 script_module._concrete_type = concrete_type
541 # Actually create the ScriptModule, initializing it with the function we just defined
--> 542 script_module = torch.jit.RecursiveScriptModule._construct(cpp_module, init_fn)
544 # Compile methods if necessary
545 if concrete_type not in concrete_type_store.methods_compiled:
File ~/anaconda3/envs/mpa/lib/python3.10/site-packages/torch/jit/_script.py:614, in RecursiveScriptModule._construct(cpp_module, init_fn)
601 """
602 Construct a RecursiveScriptModule that's ready for use. PyTorch
603 code should use this to construct a RecursiveScriptModule instead
(...)
611 init_fn: Lambda that initializes the RecursiveScriptModule passed to it.
612 """
613 script_module = RecursiveScriptModule(cpp_module)
--> 614 init_fn(script_module)
616 # Finalize the ScriptModule: replace the nn.Module state with our
617 # custom implementations and flip the _initializing bit.
618 RecursiveScriptModule._finalize_scriptmodule(script_module)
File ~/anaconda3/envs/mpa/lib/python3.10/site-packages/torch/jit/_recursive.py:520, in create_script_module_impl..init_fn(script_module)
517 scripted = orig_value
518 else:
519 # always reuse the provided stubs_fn to infer the methods to compile
--> 520 scripted = create_script_module_impl(orig_value, sub_concrete_type, stubs_fn)
522 cpp_module.setattr(name, scripted)
523 script_module._modules[name] = scripted
File ~/anaconda3/envs/mpa/lib/python3.10/site-packages/torch/jit/_recursive.py:542, in create_script_module_impl(nn_module, concrete_type, stubs_fn)
539 script_module._concrete_type = concrete_type
541 # Actually create the ScriptModule, initializing it with the function we just defined
--> 542 script_module = torch.jit.RecursiveScriptModule._construct(cpp_module, init_fn)
544 # Compile methods if necessary
545 if concrete_type not in concrete_type_store.methods_compiled:
File ~/anaconda3/envs/mpa/lib/python3.10/site-packages/torch/jit/_script.py:614, in RecursiveScriptModule._construct(cpp_module, init_fn)
601 """
602 Construct a RecursiveScriptModule that's ready for use. PyTorch
603 code should use this to construct a RecursiveScriptModule instead
(...)
611 init_fn: Lambda that initializes the RecursiveScriptModule passed to it.
612 """
613 script_module = RecursiveScriptModule(cpp_module)
--> 614 init_fn(script_module)
616 # Finalize the ScriptModule: replace the nn.Module state with our
617 # custom implementations and flip the _initializing bit.
618 RecursiveScriptModule._finalize_scriptmodule(script_module)
File ~/anaconda3/envs/mpa/lib/python3.10/site-packages/torch/jit/_recursive.py:520, in create_script_module_impl..init_fn(script_module)
517 scripted = orig_value
518 else:
519 # always reuse the provided stubs_fn to infer the methods to compile
--> 520 scripted = create_script_module_impl(orig_value, sub_concrete_type, stubs_fn)
522 cpp_module.setattr(name, scripted)
523 script_module._modules[name] = scripted
File ~/anaconda3/envs/mpa/lib/python3.10/site-packages/torch/jit/_recursive.py:546, in create_script_module_impl(nn_module, concrete_type, stubs_fn)
544 # Compile methods if necessary
545 if concrete_type not in concrete_type_store.methods_compiled:
--> 546 create_methods_and_properties_from_stubs(concrete_type, method_stubs, property_stubs)
547 # Create hooks after methods to ensure no name collisions between hooks and methods.
548 # If done before, hooks can overshadow methods that aren't exported.
549 create_hooks_from_stubs(concrete_type, hook_stubs, pre_hook_stubs)
File ~/anaconda3/envs/mpa/lib/python3.10/site-packages/torch/jit/_recursive.py:397, in create_methods_and_properties_from_stubs(concrete_type, method_stubs, property_stubs)
394 property_defs = [p.def_ for p in property_stubs]
395 property_rcbs = [p.resolution_callback for p in property_stubs]
--> 397 concrete_type._create_methods_and_properties(property_defs, property_rcbs, method_defs, method_rcbs, method_defaults)
RuntimeError: isIntList() INTERNAL ASSERT FAILED at "/opt/conda/conda-bld/pytorch_1682343967769/work/aten/src/ATen/core/ivalue_inl.h":1938, please report a bug to PyTorch. Expected IntList but got Int
How can I solve this? I am new to PyTorch.