Implementing attention-based feature fusion in PyTorch
Below is a sample PyTorch implementation of attention-based feature fusion:
```python
import torch
import torch.nn as nn


class Attention(nn.Module):
    def __init__(self, input_size):
        super(Attention, self).__init__()
        self.input_size = input_size
        self.linear = nn.Linear(input_size, input_size)
        self.softmax = nn.Softmax(dim=1)

    def forward(self, input):
        # input shape: (batch_size, seq_len, input_size)
        energy = self.linear(input)
        # energy shape: (batch_size, seq_len, input_size)
        energy = torch.tanh(energy)
        # energy shape: (batch_size, seq_len, input_size)
        attention = self.softmax(energy.sum(dim=2, keepdim=True))
        # attention shape: (batch_size, seq_len, 1)
        weighted_input = torch.mul(input, attention)
        # weighted_input shape: (batch_size, seq_len, input_size)
        weighted_sum = weighted_input.sum(dim=1)
        # weighted_sum shape: (batch_size, input_size)
        return weighted_sum


class FeatureFusion(nn.Module):
    def __init__(self, input_size1, input_size2, hidden_size):
        super(FeatureFusion, self).__init__()
        self.linear1 = nn.Linear(input_size1, hidden_size)
        self.linear2 = nn.Linear(input_size2, hidden_size)
        self.attention = Attention(hidden_size)

    def forward(self, input1, input2):
        # input1 shape: (batch_size, seq_len1, input_size1)
        # input2 shape: (batch_size, seq_len2, input_size2)
        hidden1 = self.linear1(input1)
        # hidden1 shape: (batch_size, seq_len1, hidden_size)
        hidden2 = self.linear2(input2)
        # hidden2 shape: (batch_size, seq_len2, hidden_size)
        fused_hidden = torch.cat((hidden1, hidden2), dim=1)
        # fused_hidden shape: (batch_size, seq_len1 + seq_len2, hidden_size)
        fused_hidden = self.attention(fused_hidden)
        # fused_hidden shape: (batch_size, hidden_size)
        return fused_hidden
```
In the code above, the Attention class implements the attention mechanism and the FeatureFusion class fuses the two sets of features. In Attention, the input features are first projected into a new space by a linear layer and squashed into [-1, 1] with tanh; the scores are then summed over the feature dimension and normalized with softmax to obtain one weight per sequence position. The input is multiplied by these weights and summed over the sequence dimension to produce a weighted sum. In FeatureFusion, the two feature sequences are first projected to the same hidden dimension by separate linear layers, concatenated along the sequence dimension, and then passed through Attention to obtain the fused feature vector.
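For completeness, here is a minimal usage sketch. The batch size, sequence lengths, and feature dimensions below are arbitrary placeholders chosen only for illustration, and the snippet assumes the Attention and FeatureFusion classes above are defined in the same file:

```python
import torch

# Hypothetical dimensions, chosen only for illustration.
batch_size, seq_len1, seq_len2 = 4, 10, 20
input_size1, input_size2, hidden_size = 32, 64, 128

# Two feature sequences with different lengths and feature dimensions.
feat1 = torch.randn(batch_size, seq_len1, input_size1)
feat2 = torch.randn(batch_size, seq_len2, input_size2)

fusion = FeatureFusion(input_size1, input_size2, hidden_size)
fused = fusion(feat1, feat2)
print(fused.shape)  # torch.Size([4, 128]) -> (batch_size, hidden_size)
```

Note that this fusion works even when the two inputs have different sequence lengths, because both are projected to the same hidden size before concatenation and the attention pooling collapses the sequence dimension.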