这是类的定义,class SelfAttention(nn.Module): def __init__(self, in_channels, reduction=4): super(SelfAttention, self).__init__() self.avg_pool = nn.AdaptiveAvgPool1d(1) self.fc1 = nn.Conv1d(in_channels, in_channels // reduction, 1, bias=False) self.relu = nn.ReLU(inplace=True) self.fc2 = nn.Conv1d(in_channels // reduction, in_channels, 1, bias=False) self.sigmoid = nn.Sigmoid() def forward(self, x): b, c, n = x.size() y = self.avg_pool(x) y = self.fc1(y) y = self.relu(y) y = self.fc2(y) y = self.sigmoid(y) return x * y.expand_as(x),而这是使用该类的位置:self.SA_modules.append( nn.Sequential( PointnetSAModuleMSG( npoint=cfg.RPN.SA_CONFIG.NPOINTS[k], radii=cfg.RPN.SA_CONFIG.RADIUS[k], nsamples=cfg.RPN.SA_CONFIG.NSAMPLE[k], mlps=mlps, use_xyz=use_xyz, bn=cfg.RPN.USE_BN, channel_out=channel_out ), SelfAttention(channel_out=channel_out) ) ),但是运行时报错:File "/root/autodl-tmp/project/tools/../lib/net/pointnet2_msg.py", line 125, in __init__ SelfAttention(channel_out=channel_out) TypeError: __init__() got an unexpected keyword argument 'channel_out',为什么呢?如何解决
时间: 2023-07-15 09:12:05 浏览: 56
在你的代码中,SelfAttention类的初始化方法中只接收两个参数in_channels和reduction,而你在使用该类时传递了一个名为channel_out的参数,导致了TypeError。
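最直接的解决方法是在调用处改用类已有的参数名,即写成 SelfAttention(in_channels=channel_out)。如果希望保留channel_out这个关键字,也可以修改SelfAttention类的定义,为其增加一个名为channel_out的参数,并在初始化方法中接收和使用该参数。需要注意:forward中会用y.expand_as(x)把权重扩展到与x相同的形状再相乘,因此channel_out必须等于in_channels(或者为1),否则前向传播时会报形状不匹配的错误。修改后的类定义如下: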
```
class SelfAttention(nn.Module):
    """Re-weight the channels of a 3-D tensor with a learned sigmoid gate.

    The input is average-pooled along its last dimension, passed through a
    two-layer 1x1-convolution bottleneck, squashed by a sigmoid and then
    multiplied back onto the input.

    Args:
        in_channels: Number of channels (dimension 1) of the input tensor.
        channel_out: Number of gate channels produced by the second
            convolution. Defaults to ``in_channels``. The gate is expanded to
            the input's shape in ``forward``, so only ``in_channels`` or ``1``
            can work.
        reduction: Divisor applied to ``in_channels`` to get the bottleneck
            width.

    Raises:
        ValueError: If ``channel_out`` is neither ``in_channels`` nor ``1``.
    """

    def __init__(self, in_channels, channel_out=None, reduction=4):
        super(SelfAttention, self).__init__()
        if channel_out is None:
            channel_out = in_channels
        # Fail at construction time instead of deep inside forward():
        # expand_as() can only stretch a size-1 dimension.
        if channel_out not in (1, in_channels):
            raise ValueError(
                "channel_out must equal in_channels (%d) or 1, got %d"
                % (in_channels, channel_out)
            )
        hidden = in_channels // reduction
        self.avg_pool = nn.AdaptiveAvgPool1d(1)
        self.fc1 = nn.Conv1d(in_channels, hidden, 1, bias=False)
        self.relu = nn.ReLU(inplace=True)
        self.fc2 = nn.Conv1d(hidden, channel_out, 1, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return ``x`` scaled by the gate computed from ``x`` itself.

        Args:
            x: Tensor with exactly three dimensions; dimension 1 must have
                ``in_channels`` entries.

        Returns:
            A tensor with the same shape as ``x``.

        Raises:
            ValueError: If ``x`` is not 3-dimensional.
        """
        if x.dim() != 3:
            raise ValueError(
                "expected a 3-D input tensor, got %d dimension(s)" % x.dim()
            )
        y = self.avg_pool(x)  # last dimension collapsed to length 1
        y = self.fc1(y)
        y = self.relu(y)
        y = self.fc2(y)
        y = self.sigmoid(y)
        return x * y.expand_as(x)
```
然后在使用SelfAttention类时,需要传递in_channels和channel_out参数,例如:
```
# Build the set-abstraction layer, then the attention block that follows it.
sa_layer = PointnetSAModuleMSG(
    npoint=cfg.RPN.SA_CONFIG.NPOINTS[k],
    radii=cfg.RPN.SA_CONFIG.RADIUS[k],
    nsamples=cfg.RPN.SA_CONFIG.NSAMPLE[k],
    mlps=mlps,
    use_xyz=use_xyz,
    bn=cfg.RPN.USE_BN,
    channel_out=channel_out,
)
attention = SelfAttention(in_channels=channel_out, channel_out=channel_out)
# NOTE(review): nn.Sequential hands a single positional input from one stage
# to the next -- confirm that matches how PointnetSAModuleMSG is called and
# what it returns before relying on this wrapper.
self.SA_modules.append(nn.Sequential(sa_layer, attention))
```