forward() takes 2 positional arguments but 3 were given

问题描述:

在forward中明明正确数量的参数,却报错:forward() takes 2 positional arguments but 3 were given;

问题分析:

使用nn.Sequential()定义的网络,只接受单输入

例如:

self.backbone=nn.Sequential(nn.lstm(input_size=20, hidden_size=40, num_layers=2),

                                    nn.linear(in_features=40, out_features=2))

def forward(self, input):

        h0 = torch.randn(hidden_layers, batch_size, hidden)

        c0 = torch.randn(hidden_layers, batch_size, hidden)

        output, _ = self.backbone(input)  (对)

         output, _ = self.backbone(input, (h0, c0)   (错误,因为nn.Sequential()定义的网络,只接受单输入)

本文来自网络,不代表协通编程立场,如若转载,请注明出处:https://net2asp.com/d327561128.html