Attention Mechanism Code for Text Classification
The attention mechanism for text classification is a technique for improving classification performance: by assigning different weights to different parts of the input text, it lets the model focus on the most informative parts. Below is a simple code example of an attention mechanism for text classification:
```python
import torch
import torch.nn as nn
class Attention(nn.Module):
    def __init__(self, hidden_size):
        super(Attention, self).__init__()
        self.hidden_size = hidden_size
        self.att_weights = nn.Parameter(torch.Tensor(hidden_size, hidden_size))
        self.att_bias = nn.Parameter(torch.Tensor(hidden_size))
        self.softmax = nn.Softmax(dim=-1)
        nn.init.uniform_(self.att_weights, -0.1, 0.1)
        nn.init.zeros_(self.att_bias)

    def forward(self, inputs):
        # inputs: [batch_size, seq_len, hidden_size]
        # Project the inputs and apply a non-linearity
        keys = torch.tanh(torch.matmul(inputs, self.att_weights) + self.att_bias)
        # Pairwise attention scores between positions: [batch_size, seq_len, seq_len]
        att_scores = torch.bmm(keys, inputs.transpose(1, 2))
        # Apply softmax over the last dimension to get attention weights
        att_weights = self.softmax(att_scores)
        # Apply attention weights to the input sequence: [batch_size, seq_len, hidden_size]
        weighted_inputs = torch.bmm(att_weights, inputs)
        # Sum the weighted inputs along the sequence length dimension: [batch_size, hidden_size]
        outputs = torch.sum(weighted_inputs, dim=1)
        return outputs

class TextClassifier(nn.Module):
    def __init__(self, vocab_size, embedding_dim, hidden_size, num_classes):
        super(TextClassifier, self).__init__()
        # The embeddings feed the attention layer directly in this simplified model,
        # so embedding_dim must equal hidden_size
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.attention = Attention(hidden_size)
        self.fc = nn.Linear(hidden_size, num_classes)

    def forward(self, inputs):
        # inputs: [batch_size, seq_len] token indices
        embedded = self.embedding(inputs)   # [batch_size, seq_len, embedding_dim]
        outputs = self.attention(embedded)  # [batch_size, hidden_size]
        logits = self.fc(outputs)           # [batch_size, num_classes]
        return logits
```
This code defines a text classification model with an attention mechanism. The core is the Attention class, which computes attention weights over the input sequence, applies them to the inputs, and pools the weighted sequence into a single representation. The TextClassifier class then uses Attention to build a complete text classification model consisting of an embedding layer, an attention layer, and a fully connected layer.
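To make the shapes concrete, here is a minimal usage sketch. The vocabulary size, dimensions, sequence length, and labels below are illustrative placeholders, not values prescribed by the model above; note that embedding_dim is set equal to hidden_size because the embeddings feed the attention layer directly.
```python
# Hypothetical hyperparameters chosen only for illustration
model = TextClassifier(vocab_size=10000, embedding_dim=128, hidden_size=128, num_classes=4)

# A dummy batch of 32 sequences, each 50 token ids long
dummy_batch = torch.randint(0, 10000, (32, 50))
logits = model(dummy_batch)  # [32, 4]

# Standard cross-entropy training step with random labels
labels = torch.randint(0, 4, (32,))
loss = nn.CrossEntropyLoss()(logits, labels)
loss.backward()
```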