
I'm writing a function that computes the sparsity of the weight matrices of the following fully connected network:

import torch
import torch.nn as nn

# assumes input_dim, hidden_dim and output_dim are defined earlier in the script

class FCN(nn.Module):
    def __init__(self):
        super(FCN, self).__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.relu2 = nn.ReLU()
        self.fc3 = nn.Linear(hidden_dim, hidden_dim)
        self.relu3 = nn.ReLU()
        self.fc4 = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        out = self.fc1(x)
        out = self.relu1(out)
        out = self.fc2(out)
        out = self.relu2(out)
        out = self.fc3(out)
        out = self.relu3(out)
        out = self.fc4(out)
        return out

The function I have written is the following:

def print_layer_sparsity(model):
    for name,module in model.named_modules():
        if 'fc' in name:
            zeros = 100. * float(torch.sum(model.name.weight == 0))
            tot = float(model.name.weight.nelement())
            print("Sparsity in {}.weight: {:.2f}%".format(name, zeros/tot))

But it gives me the following error:

torch.nn.modules.module.ModuleAttributeError: 'FCN' object has no attribute 'name'

It works fine when I enter the layer names manually, e.g.

(model.fc1.weight == 0), (model.fc2.weight == 0), (model.fc3.weight == 0), ...

but I'd like to make it independent of the network. In other words, given any sparse network, I'd like the function to print the sparsity of every layer. Any suggestions?

Thanks!!

Alfred

1 Answer

Try:

getattr(model, name).weight

In place of

model.name.weight

Your print_layer_sparsity function becomes:

import torch

def print_layer_sparsity(model):
    for name, module in model.named_modules():
        if 'fc' in name:
            # getattr looks up the attribute whose name is the string stored in `name`
            layer = getattr(model, name)
            zeros = 100. * float(torch.sum(layer.weight == 0))
            tot = float(layer.weight.nelement())
            print("Sparsity in {}.weight: {:.2f}%".format(name, zeros / tot))

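You can't write model.name because Python looks up an attribute literally called "name" on the model; it does not use the string stored in the variable name. The built-in getattr(obj, attr_name) function instead retrieves the attribute (member variable) of an object whose name is given as a string.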
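A minimal illustration of the difference, using a small hypothetical model called Tiny (not from the question) that has a single layer fc1:

```python
import torch.nn as nn

# Hypothetical one-layer model, used only to illustrate attribute lookup
class Tiny(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(3, 2)

model = Tiny()
name = "fc1"

# model.name would look for an attribute literally called "name", which doesn't exist
print(hasattr(model, "name"))              # False

# getattr uses the value of the variable, i.e. the string "fc1"
print(getattr(model, name) is model.fc1)   # True
print(getattr(model, name).weight.shape)   # torch.Size([2, 3])
```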

For more information, check out this answer.
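A further variant, sketched here under a few assumptions: since named_modules() already yields each module object, the getattr lookup can be skipped, and layers can be selected by type rather than by 'fc' appearing in their names. This also sidesteps a limitation of getattr(model, name): for nested modules, named_modules() returns dotted names such as "2.0", which are not attributes of the top-level model. Treating nn.Linear and the nn.Conv* layers as the layers of interest is an assumption; adjust the tuple for your network. The nested nn.Sequential model below is a hypothetical example with hand-set weights.

```python
import torch
import torch.nn as nn

# Layer types whose weights are reported (an assumption; extend as needed)
WEIGHTED_LAYERS = (nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d)

def print_layer_sparsity(model):
    """Print the percentage of zero entries in each selected layer's weight."""
    for name, module in model.named_modules():
        # `module` is already the layer object, so no getattr is needed,
        # and selecting by type does not depend on how layers are named.
        if isinstance(module, WEIGHTED_LAYERS):
            zeros = torch.sum(module.weight == 0).item()
            total = module.weight.nelement()
            print("Sparsity in {}.weight: {:.2f}%".format(name, 100.0 * zeros / total))

# Hypothetical nested model with deterministic weights
model = nn.Sequential(nn.Linear(4, 4), nn.ReLU(), nn.Sequential(nn.Linear(4, 2)))
with torch.no_grad():
    first, last = model[0], model[2][0]
    first.weight.fill_(1.0)
    first.weight[:2] = 0.0      # 8 of 16 entries zero
    last.weight.fill_(1.0)
    last.weight[:, 0] = 0.0     # 2 of 8 entries zero

print_layer_sparsity(model)
# Sparsity in 0.weight: 50.00%
# Sparsity in 2.0.weight: 25.00%
```

Note that the ReLU and the inner nn.Sequential are skipped automatically because they are not in WEIGHTED_LAYERS.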