textcnn文本分类代码
时间: 2023-08-23 07:08:21 浏览: 92
以下是一个基于TextCNN的文本分类代码示例,使用PyTorch框架:
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
class TextCNN(nn.Module):
def __init__(self, vocab_size, embedding_dim, num_filters, filter_sizes, output_dim, dropout):
super().__init__()
self.embedding = nn.Embedding(vocab_size, embedding_dim)
self.conv_layers = nn.ModuleList([
nn.Conv2d(in_channels=1, out_channels=num_filters, kernel_size=(fs, embedding_dim))
for fs in filter_sizes
])
self.fc = nn.Linear(len(filter_sizes) * num_filters, output_dim)
self.dropout = nn.Dropout(dropout)
def forward(self, text):
# text = [batch size, sent len]
text = text.permute(1, 0) # text = [sent len, batch size]
embedded = self.embedding(text) # embedded = [sent len, batch size, emb dim]
embedded = embedded.unsqueeze(1) # embedded = [batch size, 1, sent len, emb dim]
conved = [F.relu(conv(embedded)).squeeze(3) for conv in self.conv_layers] # conved = [batch size, num filters, sent len - filter_sizes[n] + 1]
pooled = [F.max_pool1d(conv, conv.shape[2]).squeeze(2) for conv in conved] # pooled = [batch size, num filters]
cat = self.dropout(torch.cat(pooled, dim=1)) # cat = [batch size, num filters * len(filter_sizes)]
output = self.fc(cat) # output = [batch size, output dim]
return output
```
这里定义了一个名为TextCNN的类,它继承自nn.Module类,表示这是一个神经网络模型。在init方法中,定义了模型中要使用的各个层,包括Embedding层、多个卷积层、全连接层和Dropout层。在forward方法中,将输入的文本数据text先通过Embedding层将其转换为词向量,然后将词向量转换为二维张量,并输入到多个卷积层中进行卷积操作。最后,将卷积结果通过max_pooling操作,将每个卷积核的输出压缩为一个标量,然后将多个卷积核的输出结果拼接在一起,输入到全连接层中进行分类。最后,将全连接层的输出结果返回。
文本分类模型的训练过程可以参考以下代码:
```python
model = TextCNN(len(vocab), embedding_dim, num_filters, filter_sizes, output_dim, dropout)
optimizer = torch.optim.Adam(model.parameters())
criterion = nn.CrossEntropyLoss()
model.train()
for epoch in range(num_epochs):
for batch in train_iterator:
optimizer.zero_grad()
text, labels = batch.text.to(device), batch.label.to(device)
output = model(text)
loss = criterion(output, labels)
loss.backward()
optimizer.step()
model.eval()
with torch.no_grad():
correct = 0
total = 0
for batch in test_iterator:
text, labels = batch.text.to(device), batch.label.to(device)
output = model(text)
predictions = torch.argmax(output, dim=1)
correct += (predictions == labels).sum().item()
total += labels.shape[0]
accuracy = correct / total
print(f'Test accuracy: {accuracy:.3f}')
```
首先,定义了一个TextCNN对象作为文本分类的模型。然后,定义了一个Adam优化器和交叉熵损失函数。接着,使用train_iterator遍历训练集数据,将每个batch的文本数据text和标签数据labels转移到GPU上,并使用模型计算输出结果output。然后,计算输出结果output和真实标签数据labels之间的交叉熵损失,并进行反向传播和模型参数优化。最后,使用test_iterator遍历测试集数据,计算模型在测试集上的准确率。
阅读全文