class DropBlock_Ske(nn.Module): def __init__(self, num_point, block_size=7): super(DropBlock_Ske, self).__init__() self.keep_prob = 0.0 self.block_size = block_size self.num_point = num_point self.fc_1 = nn.Sequential( nn.Linear(in_features=25, out_features=25, bias=True), nn.ReLU(inplace=True), nn.Linear(in_features=25, out_features=25, bias=True), ) self.fc_2 = nn.Sequential( nn.Linear(in_features=25, out_features=25, bias=True), nn.ReLU(inplace=True), nn.Linear(in_features=25, out_features=25, bias=True), ) self.sigmoid = nn.Sigmoid() def forward(self, input, keep_prob, A): # n,c,t,v self.keep_prob = keep_prob if not self.training or self.keep_prob == 1: return input n, c, t, v = input.size() input_attention_mean = torch.mean(torch.mean(input, dim=2), dim=1).detach() # 32 25 input_attention_max = torch.max(input, dim=2)[0].detach() input_attention_max = torch.max(input_attention_max, dim=1)[0] # 32 25 avg_out = self.fc_1(input_attention_mean) max_out = self.fc_2(input_attention_max) out = avg_out + max_out input_attention_out = self.sigmoid(out).view(n, 1, 1, self.num_point) input_a = input * input_attention_out input_abs = torch.mean(torch.mean( torch.abs(input_a), dim=2), dim=1).detach() input_abs = input_abs / torch.sum(input_abs) * input_abs.numel() gamma = 0.024 M_seed = torch.bernoulli(torch.clamp( input_abs * gamma, min=0, max=1.0)).to(device=input.device, dtype=input.dtype) M = torch.matmul(M_seed, A) M[M > 0.001] = 1.0 M[M < 0.5] = 0.0 mask = (1 - M).view(n, 1, 1, self.num_point) return input * mask * mask.numel() / mask.sum()
时间: 2024-04-19 22:27:40 浏览: 124
这段代码定义了一个名为DropBlock_Ske的PyTorch模型类。该类继承自nn.Module基类,并实现了模型的前向传播逻辑。
在初始化方法中,定义了一些模型的属性,包括keep_prob、block_size、num_point等。然后,定义了两个全连接网络层fc_1和fc_2,并使用nn.Sequential组织网络层的结构。最后,使用nn.Sigmoid定义了一个Sigmoid激活函数。
在前向传播方法中,首先根据输入的keep_prob值判断是否需要执行DropBlock操作,如果不需要则直接返回输入。接着,获取输入张量的形状,并计算出输入张量的均值和最大值。通过两个全连接网络层对均值和最大值进行处理,将处理结果相加得到out。然后,使用Sigmoid激活函数对out进行处理,得到一个形状为(n, 1, 1, num_point)的张量input_attention_out。将input_attention_out与输入张量input相乘得到input_a。
接下来,计算input_a的绝对值的平均值,并将其除以总数并乘以元素个数,得到一个形状为(n,)的张量input_abs。定义了一个gamma值,并将input_abs与gamma相乘并经过torch.clamp函数进行限制,再经过torch.bernoulli函数进行伯努利采样,得到一个形状与输入相同的二值张量M_seed。使用torch.matmul函数将M_seed与A矩阵相乘得到M。
然后,将M中大于0.001的元素赋值为1.0,小于0.5的元素赋值为0.0。接着,将1减去M得到mask,将mask乘以输入张量input,并除以mask中的元素个数与总和,得到最终的输出张量。
这个模型类实现了DropBlock_Ske操作,其中包含了一些全连接网络层和一些基于概率的操作。它的具体功能和用途可能需要根据上下文来确定。
阅读全文
相关推荐
![rar](https://img-home.csdnimg.cn/images/20241231044955.png)
![zip](https://img-home.csdnimg.cn/images/20241231045053.png)
![application/x-rar](https://img-home.csdnimg.cn/images/20210720083606.png)
![application/x-rar](https://img-home.csdnimg.cn/images/20210720083606.png)
![application/x-rar](https://img-home.csdnimg.cn/images/20210720083606.png)
![application/x-rar](https://img-home.csdnimg.cn/images/20210720083606.png)
![pdf](https://img-home.csdnimg.cn/images/20241231044930.png)
![application/x-rar](https://img-home.csdnimg.cn/images/20210720083606.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![zip](https://img-home.csdnimg.cn/images/20241231045053.png)