
I have two files, x.py and y.py.

Inside y.py there are two classes, A and B. Class A uses class B inside its run method.

In x.py, I import class A and run it:

from y import A
obj = A()
obj.run()

I got this error: AttributeError: Can't get attribute 'B' on <module '__main__' from 'x.py'>

I found a solution here: AttributeError: Can't get attribute on <module '__main__' from 'manage.py'>, which simply fixes the error by also importing class B in x.py:

from y import A, B
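Why the extra import helps: the error says pickle tried to look up `B` on the `__main__` module, so the saved file records the class as `__main__.B`. Pickle stores a class by its module and name, not by its code, and resolves that name again at load time. The minimal stdlib sketch below shows this; the in-memory module `script_main` is hypothetical and stands in for the `__main__` of whatever script originally saved the object (the saving code isn't shown in the question, so that part is an assumption).

```python
import pickle
import sys
import types

# Hypothetical stand-in for the __main__ module of the script that saved the object.
script_main = types.ModuleType("script_main")
sys.modules["script_main"] = script_main

# A class whose recorded home is script_main, like a class defined at the top of a script.
B = type("B", (), {"__module__": "script_main"})
script_main.B = B

data = pickle.dumps(B())
# Only the module and class *names* are stored, not the class definition.
print(b"script_main" in data)  # True

# Loading where that module has no attribute B fails,
# just like loading from x.py without importing B.
del script_main.B
try:
    pickle.loads(data)
    failed = False
except AttributeError as exc:
    failed = True
    print("AttributeError:", exc)

# Making B reachable under the recorded name again lets the same bytes load.
script_main.B = B
restored = pickle.loads(data)
print(type(restored).__name__)  # B
```

`from y import A, B` in x.py has the same effect as the last step: when x.py runs as `__main__`, the import binds `B` in that namespace, so the lookup of `__main__.B` succeeds.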

This doesn't feel like the right approach if I want to build a Python library. Users shouldn't have to import every class a library uses internally just to call one of its classes.

How can I solve this the "right way" from a software engineering perspective?

Edit: here is an example.

Class B is RoBerta_CLS.

file x.py:

# from y import RoBerta_CLS
from y import A

if __name__ == '__main__':
    obj = A(model_path='/home/PATH/models/DistilRoBERTa/')

file y.py:

import torch, os
import torch.nn as nn
from transformers import RobertaForSequenceClassification
device = 'cuda'

class RoBerta_CLS(torch.nn.Module):

    def __init__(self, model_params):
        super().__init__()
        self.encoder = RobertaForSequenceClassification.from_pretrained(model_params['MODEL'], num_labels=1)
        self.encoder = self.encoder.to(device)

    def save_pretrained(self, output_model_file):
        torch.save(self, os.path.join(output_model_file, 'pytorch_model.pt'))
        print('saved..')
    
    @staticmethod
    def from_pretrained(output_model_file):
        model = torch.load(os.path.join(output_model_file, 'pytorch_model.pt'))
        print('loaded..')
        return model


class A:
    
    def __init__(self, model_path=''):
        self.model = RoBerta_CLS.from_pretrained(model_path)
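If model files saved this way already exist, one library-side workaround is to redirect the recorded module name while unpickling, so callers of A don't need to import RoBerta_CLS themselves. The sketch below overrides the standard `pickle.Unpickler.find_class` hook with a name mapping. The module names `old_script` and `new_lib` are hypothetical; in the question's setup the mapping would presumably be `{"__main__": "y"}`, and hooking this into `torch.load` is not shown here.

```python
import io
import pickle
import sys
import types


class RenamingUnpickler(pickle.Unpickler):
    """Unpickler that looks classes up in a different module than the one recorded."""

    def __init__(self, file, module_map):
        super().__init__(file)
        self.module_map = module_map  # e.g. {"__main__": "y"} for the case in the question

    def find_class(self, module, name):
        # Swap the recorded module name if it is in the map, then resolve normally.
        return super().find_class(self.module_map.get(module, module), name)


def load_with_renames(data, module_map):
    return RenamingUnpickler(io.BytesIO(data), module_map).load()


# Demo with two hypothetical in-memory modules:
# "old_script" plays the saving script's __main__, "new_lib" plays y.py.
old_script = types.ModuleType("old_script")
new_lib = types.ModuleType("new_lib")
sys.modules["old_script"] = old_script
sys.modules["new_lib"] = new_lib

B = type("B", (), {"__module__": "old_script"})
old_script.B = B
data = pickle.dumps(B())  # recorded as old_script.B
del old_script.B          # the saving script's namespace is gone

new_lib.B = B             # the class now lives in the library module
obj = load_with_renames(data, {"old_script": "new_lib"})
print(type(obj).__name__)  # B
```

This only rescues old files; re-saving them in a form that doesn't depend on where the class was defined is generally cleaner.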

That's very interesting; I don't get the same error here. Could you post the code that triggers it?

Update #1:

Here's mine:

x.py

# from y import RoBerta_CLS
from y import A

if __name__ == '__main__':
    obj = A(model_path='/home/PATH/models/DistilRoBERTa/')
    print("if you see this, it means the program ran without errors")

y.py

device = 'cuda'

class RoBerta_CLS(object):

    def __init__(self, model_params):
        pass

    def save_pretrained(self, output_model_file):
        pass
    
    @staticmethod
    def from_pretrained(output_model_file):
        pass

class A:
    
    def __init__(self, model_path=''):
        self.model = RoBerta_CLS.from_pretrained(model_path)
Running it:

$ python3 x.py
if you see this, it means the program ran without errors
  • Maybe my issue is something else. I just added an example. – Minions Feb 18 '22 at 23:48
  • I updated my code using your example without PyTorch, and it worked. I believe it's a PyTorch-related issue, not a Python one. – pandascope Feb 18 '22 at 23:57
  • I found where the error comes from: this line `model = torch.load(os.path.join(output_model_file, 'pytorch_model.pt'))`, but it's not clear why! – Minions Feb 18 '22 at 23:59
  • Nice, at least we're getting somewhere. Does `os.path.join(output_model_file, 'pytorch_model.pt')` exist? – pandascope Feb 19 '22 at 00:02
  • Yes, and it works when I uncomment this line in `x.py`: `from y import RoBerta_CLS` – Minions Feb 19 '22 at 00:03
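That line fails because `torch.save(self, ...)` pickles the whole `RoBerta_CLS` object, so `torch.load` has to find the class again under the module and name recorded when the file was written. A common alternative, and the one PyTorch's saving/loading tutorial recommends, is to save only the `state_dict()` (a dict of tensors) and rebuild the model from code when loading. Below is a sketch with a small hypothetical `TinyClassifier` standing in for `RoBerta_CLS`; the real class would also need its `model_params` to rebuild the encoder before loading the weights.

```python
import os
import tempfile

import torch
import torch.nn as nn


class TinyClassifier(nn.Module):
    """Hypothetical stand-in for RoBerta_CLS, kept tiny so the sketch runs quickly."""

    def __init__(self, in_features=4):
        super().__init__()
        self.head = nn.Linear(in_features, 1)

    def save_pretrained(self, output_dir):
        # Save only the parameters (a dict of tensors), not the pickled class.
        torch.save(self.state_dict(), os.path.join(output_dir, "pytorch_model.pt"))

    @classmethod
    def from_pretrained(cls, output_dir, **init_kwargs):
        # Rebuild the architecture from code, then load the saved weights into it.
        model = cls(**init_kwargs)
        state = torch.load(os.path.join(output_dir, "pytorch_model.pt"), map_location="cpu")
        model.load_state_dict(state)
        return model


with tempfile.TemporaryDirectory() as tmp:
    original = TinyClassifier()
    original.save_pretrained(tmp)
    restored = TinyClassifier.from_pretrained(tmp)
    # Every parameter tensor should round-trip unchanged.
    same = all(
        torch.equal(a, b)
        for a, b in zip(original.state_dict().values(), restored.state_dict().values())
    )
    print(same)  # True
```

Because the file now holds only tensors, loading no longer depends on which module defined the class, and it also works with `torch.load(..., weights_only=True)`, which recent PyTorch releases use by default.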