#StackBounty: #python #pytorch #layer How to find the names of the input layers for an intermediate layer in a PyTorch model?

Bounty: 50

I have a complicated PyTorch model. How can I print the names (or IDs) of the layers that are connected to a given layer’s input? To start, I want to find them for the Concat layer. See the example code below:

class Concat(nn.Module):
    """Module wrapper around ``torch.cat`` with a fixed concatenation axis.

    The axis given at construction time is stored in ``self.d`` and reused
    on every forward call.
    """

    def __init__(self, dimension=1):
        super(Concat, self).__init__()
        # Axis along which the incoming tensors are joined.
        self.d = dimension

    def forward(self, x):
        """Join the sequence of tensors ``x`` along ``self.d``."""
        axis = self.d
        return torch.cat(x, dim=axis)


class SomeModel(nn.Module):
    """Small two-branch example network.

    Both branches convolve the same input (one of them followed by batch
    norm), their ReLU outputs are merged by the ``Concat`` submodule, then
    average-pooled, flattened and fed to a single linear layer.

    The linear layer expects 8192 flattened features, which matches e.g.
    3x32x32 inputs: 64 + 64 channels, pooled from 32x32 down to 8x8.
    """

    def __init__(self):
        super().__init__()
        # Both convolutions share the same 3x3 / stride 1 / padding 1 setup,
        # which keeps the spatial size of the input unchanged.
        conv_cfg = dict(kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        self.conv1 = nn.Conv2d(3, 64, **conv_cfg)
        self.bn1 = nn.BatchNorm2d(64)
        self.conv2 = nn.Conv2d(3, 64, **conv_cfg)
        self.conc = Concat(1)
        self.linear = nn.Linear(8192, 1)

    def forward(self, x):
        """Run ``x`` through both branches and return the linear layer output."""
        branch_a = F.relu(self.bn1(self.conv1(x)))
        branch_b = F.relu(self.conv2(x))
        # The Concat module receives the two branch outputs as a list.
        merged = self.conc([branch_a, branch_b])
        pooled = F.avg_pool2d(merged, 4)
        # Collapse everything except the batch dimension.
        flat = pooled.view(pooled.size(0), -1)
        return self.linear(flat)


if __name__ == '__main__':
    # Build the model, show its structure, and run one dummy forward pass.
    model = SomeModel()
    print(model)
    dummy_input = torch.randn(1, 3, 32, 32)
    y = model(dummy_input)
    print(y.size())

    # Walk every registered submodule and report those whose class name
    # contains 'Concat'.
    for name, m in model.named_modules():
        cls_name = m.__class__.__name__
        if 'Concat' not in cls_name:
            continue
        print(name, m, cls_name)
        # Here print names of all input layers for Concat


Get this bounty!!!

Leave a Reply

This site uses Akismet to reduce spam. Learn how your comment data is processed.