self.groupers.append( pointnet2_utils.QueryAndGroup(radius, nsample, use_xyz=use_xyz) if npoint is not None else pointnet2_utils.GroupAll(use_xyz) )这是self.groupers的定义语句块,那self.groupers[0]是什么样子的
时间: 2023-11-29 14:05:23 浏览: 22
假设此前已经执行了类似于以下代码的语句:
```
# Start with an empty ModuleList; every grouper appended to it later is
# registered as a submodule of the owning layer.
self.groupers = nn.ModuleList([])
```
那么,在执行完如下代码后:
```
# With a centroid count, group neighbours by radius/nsample around the sampled
# centroids; without one (npoint is None), use GroupAll, which takes no
# radius/nsample.
if npoint is not None:
    grouper = pointnet2_utils.QueryAndGroup(radius, nsample, use_xyz=use_xyz)
else:
    grouper = pointnet2_utils.GroupAll(use_xyz)
self.groupers.append(grouper)
```
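`self.groupers[0]` 就是这次追加进去的对象:当 `npoint` 不为 `None` 时是 `QueryAndGroup`(用 `radius` 和 `nsample` 初始化),当 `npoint` 为 `None` 时是 `GroupAll`(此时 `radius` 和 `nsample` 不会被使用);两种情况下都会传入 `use_xyz`。注意这只在 `self.groupers` 原本为空时成立——在 `PointnetSAModuleMSG` 的循环中,`self.groupers[0]` 对应第一个尺度,即用 `radii[0]` 和 `nsamples[0]` 构造的分组器。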
相关问题
class PointnetSAModuleMSG(_PointnetSAModuleBase): """Pointnet set abstraction layer with multiscale grouping""" def __init__(self, *, npoint: int, radii: List[float], nsamples: List[int], mlps: List[List[int]], bn: bool = True, use_xyz: bool = True, pool_method='max_pool', instance_norm=False): """ :param npoint: int :param radii: list of float, list of radii to group with :param nsamples: list of int, number of samples in each ball query :param mlps: list of list of int, spec of the pointnet before the global pooling for each scale :param bn: whether to use batchnorm :param use_xyz: :param pool_method: max_pool / avg_pool :param instance_norm: whether to use instance_norm """ super().__init__() assert len(radii) == len(nsamples) == len(mlps) self.npoint = npoint self.groupers = nn.ModuleList() self.mlps = nn.ModuleList() for i in range(len(radii)): radius = radii[i] nsample = nsamples[i] self.groupers.append( pointnet2_utils.QueryAndGroup(radius, nsample, use_xyz=use_xyz) if npoint is not None else pointnet2_utils.GroupAll(use_xyz) ) mlp_spec = mlps[i] if use_xyz: mlp_spec[0] += 3 self.mlps.append(pt_utils.SharedMLP(mlp_spec, bn=bn, instance_norm=instance_norm)) self.pool_method = pool_method我想将以上模块输出的张量的shape在模块尾部打印出来方便检查,你可以提供详细的代码吗?
张量的 `.shape` 属性返回 `torch.Size`(它是 `tuple` 的子类),可以直接用 `print` 输出。以下是在模块 `forward` 尾部打印输出张量形状的代码示例:
```
class PointnetSAModuleMSG(_PointnetSAModuleBase):
    """Pointnet set abstraction layer with multiscale grouping.

    Same layer as the standard MSG module, except that ``forward`` can print the
    shapes of its outputs (controlled by ``print_shapes``) to help debugging.
    The actual grouping / MLP / pooling work is left to
    ``_PointnetSAModuleBase.forward`` so it is not re-implemented (and diverged
    from) here.
    """

    def __init__(
        self,
        *,
        npoint: int,
        radii: List[float],
        nsamples: List[int],
        mlps: List[List[int]],
        bn: bool = True,
        use_xyz: bool = True,
        pool_method='max_pool',
        instance_norm=False,
        print_shapes: bool = True,
    ):
        """
        :param npoint: int, number of centroids; None selects a GroupAll grouper
        :param radii: list of float, list of radii to group with
        :param nsamples: list of int, number of samples in each ball query
        :param mlps: list of list of int, spec of the pointnet before the global
            pooling for each scale; the lists are copied, never modified in place
        :param bn: whether to use batchnorm
        :param use_xyz: when True, 3 is added to each MLP's input width
            (presumably for the xyz channels the grouper appends)
        :param pool_method: max_pool / avg_pool
        :param instance_norm: whether to use instance_norm
        :param print_shapes: print output shapes at the end of every forward
        """
        super().__init__()
        assert len(radii) == len(nsamples) == len(mlps)
        self.npoint = npoint
        self.print_shapes = print_shapes
        self.groupers = nn.ModuleList()
        self.mlps = nn.ModuleList()
        for radius, nsample, spec in zip(radii, nsamples, mlps):
            self.groupers.append(
                pointnet2_utils.QueryAndGroup(radius, nsample, use_xyz=use_xyz)
                if npoint is not None else pointnet2_utils.GroupAll(use_xyz)
            )
            # Copy so the caller's config list is not mutated; mutating it would
            # add another +3 every time the same config is reused.
            mlp_spec = list(spec)
            if use_xyz:
                mlp_spec[0] += 3
            self.mlps.append(pt_utils.SharedMLP(mlp_spec, bn=bn, instance_norm=instance_norm))
        self.pool_method = pool_method

    def forward(self, xyz, features=None, new_xyz=None):
        """Run the base-class set abstraction, then optionally print output shapes.

        :param xyz: point coordinates; the base-class docstring gives (B, N, 3)
        :param features: per-point descriptors or None, passed unchanged to the
            base class (its docstring says (B, N, C) — TODO confirm against the
            grouper implementation)
        :param new_xyz: optional precomputed centroids; when None the base class
            samples ``self.npoint`` of them with furthest point sampling, or
            leaves it None when ``self.npoint`` is None
        :return: (new_xyz, new_features) exactly as returned by the base class;
            per its inline comments new_features is (B, sum_k mlps[k][-1], npoint)
        """
        new_xyz, new_features = super().forward(xyz, features, new_xyz)
        if self.print_shapes:
            # new_xyz is None for a group-all layer (npoint is None); guard .shape.
            print("new_xyz shape:", None if new_xyz is None else new_xyz.shape)
            print("new_features shape:", new_features.shape)
        return new_xyz, new_features
```
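在 `forward` 方法末尾添加了打印语句,可以直接输出张量的形状信息。注意:当 `npoint` 为 `None` 时 `new_xyz` 为 `None`,此时访问 `new_xyz.shape` 会抛出 `AttributeError`,所以打印前需要先判断。另外,父类 `_PointnetSAModuleBase` 中分组器的调用方式是 `grouper(xyz, new_xyz, features)`,并且只返回一个 `(B, C, npoint, nsample)` 张量,池化使用的是 `F.max_pool2d`/`F.avg_pool2d`。因此更稳妥的做法是:在子类 `forward` 中直接调用 `super().forward(xyz, features, new_xyz)`,再打印返回值的形状,而不是重写整套分组与池化逻辑。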
self.SA_modules.append( nn.Sequential( PointnetSAModuleMSG( npoint=cfg.RPN.SA_CONFIG.NPOINTS[k], radii=cfg.RPN.SA_CONFIG.RADIUS[k], nsamples=cfg.RPN.SA_CONFIG.NSAMPLE[k], mlps=mlps, use_xyz=use_xyz, bn=cfg.RPN.USE_BN ), SelfAttention(channel_out) ) )这是SA_modules的定义代码块,而 for i in range(len(self.SA_modules)): li_xyz, li_features = self.SA_modules[i](l_xyz[i], l_features[i]) l_xyz.append(li_xyz) l_features.append(li_features)是SA_modules的调用代码块,而这是PointnetSAModuleMSG类的父类的代码:class _PointnetSAModuleBase(nn.Module): def __init__(self): super().__init__() self.npoint = None self.groupers = None self.mlps = None self.pool_method = 'max_pool' def forward(self, xyz: torch.Tensor, features: torch.Tensor = None, new_xyz=None) -> (torch.Tensor, torch.Tensor): """ :param xyz: (B, N, 3) tensor of the xyz coordinates of the features :param features: (B, N, C) tensor of the descriptors of the the features :param new_xyz: :return: new_xyz: (B, npoint, 3) tensor of the new features' xyz new_features: (B, npoint, \sum_k(mlps[k][-1])) tensor of the new_features descriptors """ new_features_list = [] xyz_flipped = xyz.transpose(1, 2).contiguous() if new_xyz is None: new_xyz = pointnet2_utils.gather_operation( xyz_flipped, pointnet2_utils.furthest_point_sample(xyz, self.npoint) ).transpose(1, 2).contiguous() if self.npoint is not None else None for i in range(len(self.groupers)): new_features = self.groupers[i](xyz, new_xyz, features) # (B, C, npoint, nsample) new_features = self.mlps[i](new_features) # (B, mlp[-1], npoint, nsample) if self.pool_method == 'max_pool': new_features = F.max_pool2d( new_features, kernel_size=[1, new_features.size(3)] ) # (B, mlp[-1], npoint, 1) elif self.pool_method == 'avg_pool': new_features = F.avg_pool2d( new_features, kernel_size=[1, new_features.size(3)] ) # (B, mlp[-1], npoint, 1) else: raise NotImplementedError new_features = new_features.squeeze(-1) # (B, mlp[-1], npoint) new_features_list.append(new_features) return new_xyz, torch.cat(new_features_list, dim=1);运行时程序报错提示我在调用SA_modules时传递了三个参数,现在看来应该是多出了参数channel_out,我该怎么修改代码才能让SA_modules顺利接受三个参数并正常运行
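先说明报错的根本原因:`nn.Sequential` 的 `forward(input)` 只接受一个输入,而调用处 `self.SA_modules[i](l_xyz[i], l_features[i])` 传入了两个位置参数;报错信息里的"三个参数"把 `self` 也算在内,并不是多出了 `channel_out`。此外,`PointnetSAModuleMSG` 返回的 `(new_xyz, new_features)` 元组也无法被 `nn.Sequential` 直接交给 `SelfAttention`。因此,仅仅给 `PointnetSAModuleMSG` 增加 `channel_out` 参数并不能消除这个错误。原回答给出的修改如下:在PointnetSAModuleMSG类中增加一个名为channel_out的参数,并在`__init__`函数中对其进行初始化: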
```
class PointnetSAModuleMSG(_PointnetSAModuleBase):
    """PointNet set-abstraction layer with multi-scale grouping (MSG).

    ``channel_out`` is optional and only used as a consistency check: when given,
    it must equal the total output width (the sum of every scale's last MLP
    width), i.e. what a following ``SelfAttention(channel_out)`` would receive.
    It deliberately does NOT overwrite the last scale's MLP width: doing so made
    the concatenated width ``sum(other scales) + channel_out`` instead of
    ``channel_out``.
    """

    def __init__(self, npoint, radii, nsamples, mlps, use_xyz=True, bn=True,
                 channel_out=None, pool_method='max_pool', instance_norm=False):
        """
        :param npoint: number of centroids; when None a GroupAll grouper is used
            instead of QueryAndGroup
        :param radii: list of float, ball-query radius for each scale
        :param nsamples: list of int, number of samples in each ball query
        :param mlps: per-scale SharedMLP channel specs; copied, never modified
        :param use_xyz: when True, 3 is added to each MLP's input width
            (presumably for the xyz channels the grouper appends)
        :param bn: whether SharedMLP uses batchnorm
        :param channel_out: optional expected total output width (see class doc)
        :param pool_method: 'max_pool' or 'avg_pool', read by the base forward
        :param instance_norm: forwarded to SharedMLP
        :raises ValueError: if the per-scale lists differ in length, or if
            channel_out disagrees with the MLP specs
        """
        super().__init__()
        if not (len(radii) == len(nsamples) == len(mlps)):
            raise ValueError(
                "radii, nsamples and mlps must have equal length, got "
                f"{len(radii)}, {len(nsamples)}, {len(mlps)}"
            )
        self.npoint = npoint
        self.use_xyz = use_xyz
        self.channel_out = channel_out
        self.pool_method = pool_method
        self.groupers = nn.ModuleList()
        self.mlps = nn.ModuleList()
        total_out = 0
        for radius, nsample, spec in zip(radii, nsamples, mlps):
            # QueryAndGroup lives in pointnet2_utils (pt_utils provides SharedMLP);
            # npoint=None means group-all, matching the base-class forward.
            self.groupers.append(
                pointnet2_utils.QueryAndGroup(radius, nsample, use_xyz=use_xyz)
                if npoint is not None else pointnet2_utils.GroupAll(use_xyz)
            )
            # Copy so the caller's config lists are not mutated in place.
            mlp_spec = list(spec)
            if use_xyz:
                mlp_spec[0] += 3
            self.mlps.append(pt_utils.SharedMLP(mlp_spec, bn=bn, instance_norm=instance_norm))
            total_out += mlp_spec[-1]
        if channel_out is not None and channel_out != total_out:
            raise ValueError(
                f"channel_out={channel_out} does not match the concatenated "
                f"MLP output width {total_out}"
            )
```
然后,在SA_modules的定义代码块中,你需要修改SelfAttention的调用方式,将channel_out作为参数传递给SelfAttention:
```
# nn.Sequential.forward() accepts exactly one input, so it cannot carry the
# (xyz, features) pair that PointnetSAModuleMSG takes and returns. This subclass
# threads both through explicitly. Subclassing nn.Sequential keeps the child
# names "0" and "1", so existing state_dict keys / checkpoints still load.
# Define this class once at module level, not inside the construction loop.
class SAWithAttention(nn.Sequential):
    """Apply a set-abstraction module, then an attention module to its features."""

    def forward(self, xyz, features=None):
        sa_module, attention = self[0], self[1]
        new_xyz, new_features = sa_module(xyz, features)
        # NOTE(review): assumes SelfAttention accepts the SA feature tensor as-is
        # (base-class comments give (B, C, npoint)) — confirm its expected layout.
        return new_xyz, attention(new_features)


# PointnetSAModuleMSG is called without channel_out: the keyword-only version of
# the class does not accept it, and channel_out only describes the width that
# SelfAttention receives.
self.SA_modules.append(
    SAWithAttention(
        PointnetSAModuleMSG(
            npoint=cfg.RPN.SA_CONFIG.NPOINTS[k],
            radii=cfg.RPN.SA_CONFIG.RADIUS[k],
            nsamples=cfg.RPN.SA_CONFIG.NSAMPLE[k],
            mlps=mlps,
            use_xyz=use_xyz,
            bn=cfg.RPN.USE_BN,
        ),
        SelfAttention(channel_out),
    )
)
```
最后,在调用SA_modules的代码块中,你需要将channel_out传递给SA_modules:
```
# Each SA_modules[i] must take (xyz, features) and return (new_xyz, new_features).
# Do not pass extra keyword arguments such as channel_out: a plain
# nn.Sequential.forward() takes a single input and no keywords.
for i, sa_module in enumerate(self.SA_modules):
    li_xyz, li_features = sa_module(l_xyz[i], l_features[i])
    l_xyz.append(li_xyz)
    l_features.append(li_features)
```
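要让 SA_modules 顺利接受 `(l_xyz[i], l_features[i])` 这两个输入并正常运行,关键在于不能直接用 `nn.Sequential` 串联两个模块:它的 `forward` 只接受一个位置参数,也不接受 `channel_out` 之类的关键字参数。应当改用带自定义 `forward` 的模块(例如继承 `nn.Sequential` 并重写 `forward`):先把 `(xyz, features)` 传给 `PointnetSAModuleMSG`,再把得到的 `new_features` 交给 `SelfAttention`,最后返回 `(new_xyz, new_features)`。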
相关推荐
![zip](https://img-home.csdnimg.cn/images/20210720083736.png)
![pdf](https://img-home.csdnimg.cn/images/20210720083512.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![zip](https://img-home.csdnimg.cn/images/20210720083736.png)
![zip](https://img-home.csdnimg.cn/images/20210720083736.png)
![zip](https://img-home.csdnimg.cn/images/20210720083736.png)
![java](https://img-home.csdnimg.cn/images/20210720083646.png)
![zip](https://img-home.csdnimg.cn/images/20210720083736.png)
![zip](https://img-home.csdnimg.cn/images/20210720083736.png)
![rar](https://img-home.csdnimg.cn/images/20210720083606.png)
![zip](https://img-home.csdnimg.cn/images/20210720083736.png)