Attention-based Pooling Code
Below is an example implementation of Attention-based Pooling in PyTorch:
```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class AttentionPooling(nn.Module):
    def __init__(self, input_dim, hidden_dim):
        super(AttentionPooling, self).__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        # Project each time step into a hidden space, then score it with a
        # single learned vector.
        self.W = nn.Linear(input_dim, hidden_dim)
        self.U = nn.Linear(hidden_dim, 1)

    def forward(self, inputs):
        # inputs: (seq_len, batch_size, input_dim)
        scores = self.U(torch.tanh(self.W(inputs)))   # (seq_len, batch_size, 1)
        weights = F.softmax(scores, dim=0)            # normalize over the sequence dimension
        pooled = torch.sum(inputs * weights, dim=0)   # weighted sum -> (batch_size, input_dim)
        return pooled


# Example usage
input_dim = 100
hidden_dim = 50
seq_len = 10
batch_size = 32

inputs = torch.randn(seq_len, batch_size, input_dim)
attention_pooling = AttentionPooling(input_dim, hidden_dim)
output = attention_pooling(inputs)
print(output.shape)  # torch.Size([32, 100]) -> (batch_size, input_dim)
```
In this example, we define a class named AttentionPooling that contains all the logic needed for Attention-based Pooling. After instantiating AttentionPooling, we pass the input through its forward() method (by calling the module) to compute the result. Finally, we print the shape of the pooled representation, which is (batch_size, input_dim), since the attention weights are applied directly to the original inputs.
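When sequences in a batch are padded to a common length, a common refinement is to mask the padded positions before the softmax so they receive zero attention weight. The sketch below is an illustration continuing the example above, not part of the original code; the per-sample `lengths` tensor is a hypothetical stand-in for whatever lengths your data actually provides.

```python
import torch
import torch.nn.functional as F

# Continuing the example above: `attention_pooling`, `inputs`, `seq_len`,
# and `batch_size` are the objects defined in the previous snippet.
lengths = torch.randint(1, seq_len + 1, (batch_size,))            # hypothetical valid lengths per sample
mask = torch.arange(seq_len).unsqueeze(1) < lengths.unsqueeze(0)  # (seq_len, batch_size), True = real token

scores = attention_pooling.U(torch.tanh(attention_pooling.W(inputs)))  # (seq_len, batch_size, 1)
scores = scores.masked_fill(~mask.unsqueeze(-1), float("-inf"))         # padded steps get zero weight
weights = F.softmax(scores, dim=0)                                      # normalize over valid steps only
pooled = torch.sum(inputs * weights, dim=0)                             # (batch_size, input_dim)
print(pooled.shape)  # torch.Size([32, 100])
```

Because the masked positions are set to negative infinity before the softmax, their weights become exactly zero, so padding tokens do not contribute to the pooled representation.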