利用pytorch写一个self-attention的二分类代码
时间: 2023-05-12 19:06:16 浏览: 268
基于attention+ gru文本分类代码
5星 · 资源好评率100%
好的,以下是一个利用 PyTorch 实现 self-attention 的二分类代码:
```python
import torch
import torch.nn as nn
class SelfAttention(nn.Module):
def __init__(self, hidden_size):
super(SelfAttention, self).__init__()
self.hidden_size = hidden_size
self.projection = nn.Sequential(
nn.Linear(hidden_size, 64),
nn.ReLU(True),
nn.Linear(64, 1)
)
def forward(self, encoder_outputs):
# encoder_outputs: [batch_size, sequence_length, hidden_size]
energy = self.projection(encoder_outputs)
weights = nn.functional.softmax(energy.squeeze(-1), dim=1)
outputs = (encoder_outputs * weights.unsqueeze(-1)).sum(dim=1)
return outputs
class Classifier(nn.Module):
def __init__(self, hidden_size):
super(Classifier, self).__init__()
self.attention = SelfAttention(hidden_size)
self.fc = nn.Linear(hidden_size, 2)
def forward(self, inputs):
# inputs: [batch_size, sequence_length, hidden_size]
outputs = self.attention(inputs)
outputs = self.fc(outputs)
return outputs
# Example usage
model = Classifier(hidden_size=256)
inputs = torch.randn(32, 10, 256)
outputs = model(inputs)
print(outputs.shape) # torch.Size([32, 2])
```
这个代码实现了一个基于 self-attention 的二分类器,其中 `SelfAttention` 类实现了 self-attention 的计算,`Classifier` 类则将 self-attention 的输出通过一个全连接层映射到二分类的结果。
阅读全文