self.SA_modules.append( nn.Sequential( PointnetSAModuleMSG( npoint=cfg.RPN.SA_CONFIG.NPOINTS[k], radii=cfg.RPN.SA_CONFIG.RADIUS[k], nsamples=cfg.RPN.SA_CONFIG.NSAMPLE[k], mlps=mlps[:2], use_xyz=use_xyz, bn=cfg.RPN.USE_BN, ), SelfAttention(channel_out) ) )根据以上代码,我希望我在k等于3时传入SelfAttention的参数为channel_out的一半,而其余情况下传入的参数就是channel_out
时间: 2023-12-03 10:02:12 浏览: 132
浅谈keras通过model.fit_generator训练模型(节省内存)
可以按照如下方式修改代码,实现你的需求:
```
self.SA_modules.append(
nn.Sequential(
PointnetSAModuleMSG(
npoint=cfg.RPN.SA_CONFIG.NPOINTS[k],
radii=cfg.RPN.SA_CONFIG.RADIUS[k],
nsamples=cfg.RPN.SA_CONFIG.NSAMPLE[k],
mlps=mlps[:2],
use_xyz=use_xyz,
bn=cfg.RPN.USE_BN,
),
SelfAttention(int(channel_out/2)) if k == 3 else SelfAttention(channel_out)
)
)
```
这里使用了三目运算符,如果 `k == 3`,则传入 `int(channel_out/2)`,否则传入 `channel_out`。
阅读全文