```python
self.SA_modules.append(
    nn.Sequential(
        PointnetSAModuleMSG(
            npoint=cfg.RPN.SA_CONFIG.NPOINTS[k],
            radii=cfg.RPN.SA_CONFIG.RADIUS[k],
            nsamples=cfg.RPN.SA_CONFIG.NSAMPLE[k],
            mlps=mlps,
            use_xyz=use_xyz,
            bn=cfg.RPN.USE_BN
        ),
        SelfAttention(channel_out)
    )
)
```
这是 `SA_modules` 的定义代码块,而
```python
for i in range(len(self.SA_modules)):
    li_xyz, li_features = self.SA_modules[i](l_xyz[i], l_features[i])
    l_xyz.append(li_xyz)
    l_features.append(li_features)
```
是 `SA_modules` 的调用代码块。运行时程序报错:
```
  File "/root/autodl-tmp/project/tools/../lib/net/pointnet2_msg.py", line 154, in forward
    li_xyz, li_features = self.SA_modules[i](l_xyz[i], l_features[i])
  File "/root/miniconda3/lib/python3.8/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
TypeError: forward() takes 2 positional arguments but 3 were given
```
我想我必须传入三个参数给 `SA_modules`,第三个参数应该是 `channel_out`。我该如何解决该问题?请给出详细的代码。
时间: 2023-11-22 19:52:09 浏览: 151
浅谈keras通过model.fit_generator训练模型(节省内存)
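根据错误提示,报错的 `forward()` 只接受两个位置参数(`self` 和一个输入),而调用 `self.SA_modules[i](l_xyz[i], l_features[i])` 时实际传入了三个(`self`、`l_xyz[i]`、`l_features[i]`)。这里被调用的是 `nn.Sequential` 的 `forward()`:`nn.Sequential` 只接受一个输入,并把每个子模块的返回值作为唯一参数传给下一个子模块。所以它无法把 `(xyz, features)` 两个参数传给 `PointnetSAModuleMSG`,也无法把其返回的元组拆开后再交给 `SelfAttention`。问题并不在于缺少 `channel_out`:它是 `SelfAttention` 的构造参数,在 `SelfAttention(channel_out)` 中已经传入,前向调用时通常无需再传。解决方法是不要用多个参数调用 `nn.Sequential`,而是把两个子模块分别保存(例如放进 `nn.ModuleList`,或写一个自定义的 `nn.Module` 包装类)。然后在 `forward` 中先调用 `PointnetSAModuleMSG` 得到 `li_xyz, li_features`,再把 `li_features` 交给 `SelfAttention` 处理。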
以下是一个可能的解决方案:
```
class SelfAttention(nn.Module):
    """Dot-product self-attention across the points of a feature map.

    Each output point is a softmax-weighted sum of the value projections of
    all points in the same sample. The weights come from key/query
    projections. Query, key and value are 1x1 ``Conv1d`` projections of the
    input features, and the result goes through ``BatchNorm1d``.

    The module only transforms features; it never touches point coordinates.
    It cannot be placed after a module that returns ``(xyz, features)``
    inside ``nn.Sequential``, because ``nn.Sequential.forward`` accepts a
    single input and forwards exactly one value between stages. Call the
    stages explicitly instead.

    Args:
        in_channels: channel count ``C_in`` of the input features.
        out_channels: channel count ``C_out`` of the output features.
            Defaults to ``in_channels`` so ``SelfAttention(channel_out)`` works.
    """

    def __init__(self, in_channels, out_channels=None):
        super().__init__()
        if out_channels is None:
            out_channels = in_channels
        self.in_channels = in_channels
        self.out_channels = out_channels
        # Query, key and value are all projected from the *input* features,
        # so every projection must take ``in_channels`` inputs. Declaring
        # conv2/conv3 with ``out_channels`` inputs broke whenever C_in != C_out.
        self.conv1 = nn.Conv1d(in_channels, out_channels, 1)  # query
        self.conv2 = nn.Conv1d(in_channels, out_channels, 1)  # key
        self.conv3 = nn.Conv1d(in_channels, out_channels, 1)  # value
        self.softmax = nn.Softmax(dim=-1)
        self.bn = nn.BatchNorm1d(out_channels)

    def forward(self, xyz, features, channel_out=None):
        """Apply self-attention to ``features``.

        Args:
            xyz: point coordinates. Accepted so the call matches
                ``(xyz, features)``-style modules; it is not used and may be
                ``None``.
            features: tensor of shape ``(B, C_in, N)``.
            channel_out: optional expected output width. The width is fixed at
                construction time, so if given it must equal ``out_channels``.

        Returns:
            Tensor of shape ``(B, C_out, N)``. Only the features are returned;
            the caller keeps its own ``xyz``.

        Raises:
            ValueError: if ``features`` is not 3-D, or ``channel_out`` differs
                from ``out_channels``.
        """
        if features.dim() != 3:
            raise ValueError(
                f"features must have shape (B, C, N), got {tuple(features.shape)}"
            )
        if channel_out is not None and channel_out != self.out_channels:
            raise ValueError(
                f"channel_out={channel_out} does not match "
                f"out_channels={self.out_channels}"
            )
        q = self.conv1(features)  # (B, C_out, N)
        k = self.conv2(features)  # (B, C_out, N)
        v = self.conv3(features)  # (B, C_out, N)
        # attn[b, i, j] = <k_i, q_j>. The softmax over j turns row i into the
        # weights output point i assigns to every point j.
        attn = self.softmax(torch.bmm(k.permute(0, 2, 1), q))  # (B, N, N)
        # out[:, :, i] = sum_j attn[i, j] * v[:, :, j]
        out = torch.bmm(v, attn.permute(0, 2, 1))  # (B, C_out, N)
        return self.bn(out)
```
然后,在 `PointnetSAModuleMSG` 中,您需要将 `SelfAttention` 的 `forward()` 方法的第三个参数设置为 `channel_out`,如下所示:
```
class PointnetSAModuleMSG(nn.Module):
    """Multi-scale set abstraction with a self-attention layer per scale.

    For every radius ``radii[i]`` the block holds a ``PointnetSAModule`` and
    a ``SelfAttention`` layer. The pair is stored in an ``nn.ModuleList``,
    not an ``nn.Sequential``. ``nn.Sequential.forward`` takes exactly one
    input, so calling it as ``seq(xyz, features, ...)`` raises
    ``TypeError: forward() takes 2 positional arguments but N were given``.
    ``nn.ModuleList`` registers the same sub-module indices (``0`` and
    ``1``), so ``state_dict`` keys are unchanged, and ``forward`` calls the
    two stages explicitly.

    Args:
        npoint: per-scale number of sampled points, indexed as ``npoint[i]``.
        radii: per-scale ball-query radii; its length sets the number of scales.
        nsamples: per-scale group sizes, indexed as ``nsamples[i]``.
        mlps: channel widths. ``mlps[:-1]`` is passed to each
            ``PointnetSAModule``, and ``mlps[-1]`` is the attention width.
        use_xyz: forwarded to ``PointnetSAModule`` and the radius grouping.
        bn: forwarded to ``PointnetSAModule``.
    """

    def __init__(self, npoint, radii, nsamples, mlps, use_xyz=True, bn=True):
        super().__init__()
        self.npoint = npoint
        self.radii = radii
        self.nsamples = nsamples
        self.use_xyz = use_xyz
        self.bn = bn
        # forward() needs the attention width, but ``mlps`` is only a local of
        # __init__, so keep it on the instance.
        self.attn_channels = mlps[-1]

        # NOTE(review): this chain starts from mlps[-1] rather than the input
        # feature width, and forward() indexes it per *scale* (i over radii),
        # not per layer. Kept as-is; confirm the intended layout.
        self.mlp_convs = nn.ModuleList()
        last_channel = mlps[-1]
        for out_channel in mlps:
            self.mlp_convs.append(nn.Conv1d(last_channel, out_channel, 1))
            last_channel = out_channel

        # One (set-abstraction, attention) pair per scale. ModuleList is a pure
        # container with no forward(), so the stages are called one by one.
        self.sa_modules = nn.ModuleList()
        for i in range(len(radii)):
            self.sa_modules.append(
                nn.ModuleList([
                    PointnetSAModule(
                        npoint[i], radii[i], nsamples[i], mlps[:-1], use_xyz, bn
                    ),
                    SelfAttention(mlps[-1], mlps[-1]),
                ])
            )

    def forward(self, xyz, features):
        """Run every scale and collect its outputs.

        Args:
            xyz: point coordinates (presumably ``(B, N, 3)`` — TODO confirm
                against ``sample_and_group_all`` / ``sample_and_group_radius``).
            features: point features in whatever layout those helpers expect.

        Returns:
            ``(l_xyz, l_features)``: lists whose first entry comes from
            ``sample_and_group_all``. The remaining entries are the per-scale
            outputs, in the order of ``radii``.
        """
        new_xyz, new_features = sample_and_group_all(xyz, features)  # presumably (B, 1, 3), (B, C, 1)
        l_xyz, l_features = [new_xyz], [new_features]
        for i, (sa_module, attention) in enumerate(self.sa_modules):
            # presumably (B, npoint[i], 3), (B, C, npoint[i], nsamples[i])
            li_xyz, li_features = sample_and_group_radius(
                xyz, features, self.npoint[i], self.radii[i], self.nsamples[i], self.use_xyz
            )
            # Flatten the group dimension: (B, C, npoint[i] * nsamples[i]).
            li_features = li_features.view(li_features.size(0), li_features.size(1), -1)
            li_features = self.mlp_convs[i](li_features)
            # NOTE(review): li_features now has npoint[i] * nsamples[i] columns
            # while li_xyz holds npoint[i] points, and PointnetSAModule
            # presumably regroups on its own. Verify this pre-grouping is intended.
            # The SA module is unpacked as (xyz, features), as the call site
            # assumed.
            li_xyz, li_features = sa_module(li_xyz, li_features)
            # SelfAttention returns only the refined features; xyz is kept.
            li_features = attention(li_xyz, li_features, self.attn_channels)
            l_xyz.append(li_xyz)
            l_features.append(li_features)
        return l_xyz, l_features
```
这应该可以解决您遇到的问题。
阅读全文