Chinese Text Classification with PyTorch: Example Code
Date: 2023-12-02 20:00:37
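PyTorch is a widely used deep learning framework that can be applied to Chinese text classification. Below is a simple example that classifies Chinese text.
First, import the necessary libraries and modules: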
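```python
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
# Field, TabularDataset and BucketIterator belong to torchtext's legacy API:
# available as `torchtext.data` up to torchtext 0.8, moved to
# `torchtext.legacy.data` in 0.9-0.11, and removed in 0.12.
from torchtext import data
from torchtext.vocab import Vectors
```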
Next, define a class that builds the text classification model:
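```python
class TextClassifier(nn.Module):
    def __init__(self, vocab_size, embedding_dim, hidden_dim, output_dim):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        # Two-layer bidirectional GRU; expects input of shape [seq_len, batch, embedding_dim]
        self.rnn = nn.GRU(embedding_dim, hidden_dim, num_layers=2, bidirectional=True)
        # Forward and backward final states are concatenated, hence hidden_dim * 2
        self.fc = nn.Linear(hidden_dim * 2, output_dim)

    def forward(self, text):
        # text: [seq_len, batch] tensor of token indices
        embedded = self.embedding(text)
        # hidden: [num_layers * 2, batch, hidden_dim], the final state of each layer and direction
        _, hidden = self.rnn(embedded)
        # Concatenate the last layer's forward (hidden[-2]) and backward (hidden[-1]) states
        hidden = torch.cat((hidden[-2, :, :], hidden[-1, :, :]), dim=1)
        return self.fc(hidden)
```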
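The linear layer in the model expects a vector of size `hidden_dim * 2`, so it helps to check what a two-layer bidirectional GRU actually returns. The following is a minimal sketch on random input; the sizes `seq_len = 7` and `batch_size = 4` are arbitrary values chosen for illustration, not taken from the example above.

```python
import torch
import torch.nn as nn

seq_len, batch_size, embedding_dim, hidden_dim = 7, 4, 100, 256
rnn = nn.GRU(embedding_dim, hidden_dim, num_layers=2, bidirectional=True)

# Random stand-in for the output of the embedding layer
embedded = torch.randn(seq_len, batch_size, embedding_dim)
output, hidden = rnn(embedded)

# output: top-layer states at every time step, both directions concatenated
print(tuple(output.shape))  # (7, 4, 512)
# hidden: final state for each (layer, direction) pair, 2 layers * 2 directions
print(tuple(hidden.shape))  # (4, 4, 256)

# hidden[-2] is the last layer's forward state, hidden[-1] its backward state
final = torch.cat((hidden[-2], hidden[-1]), dim=1)
print(tuple(final.shape))   # (4, 512)
```

Each time step of `output` already has size `hidden_dim * 2`, so concatenating two time steps of `output` would produce `hidden_dim * 4` features and would not fit the linear layer. Concatenating the two final states from `hidden` gives the expected `hidden_dim * 2`.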
Then load and preprocess the data. This step uses the `torchtext` library:
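```python
import jieba  # Chinese word segmentation; `tokenize` must be a callable

TEXT = data.Field(tokenize=jieba.lcut)
LABEL = data.LabelField()
# Assumes data.csv has columns (text, label) and no header row (else pass skip_header=True)
dataset = data.TabularDataset('data.csv', format='csv', fields=[('text', TEXT), ('label', LABEL)])
train_data, test_data = dataset.split(split_ratio=0.9)
TEXT.build_vocab(train_data, vectors=Vectors('vec.txt'))
LABEL.build_vocab(train_data)
# BucketIterator batches examples of similar length, so it needs a sort_key
train_iterator, test_iterator = data.BucketIterator.splits(
    (train_data, test_data), batch_size=64, shuffle=True,
    sort_key=lambda ex: len(ex.text))
```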
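To make the preprocessing step less of a black box, here is a rough sketch in plain Python of what building a vocabulary, mapping tokens to indices, and padding a batch amount to. The helper names (`build_vocab`, `numericalize`, `pad_batch`) and the toy sentences are hypothetical and written for illustration only; this approximates the idea and does not reproduce `torchtext`'s implementation.

```python
from collections import Counter

def build_vocab(tokenized_texts, min_freq=1, specials=("<unk>", "<pad>")):
    """Map each token to an integer index, special tokens first."""
    counts = Counter(tok for toks in tokenized_texts for tok in toks)
    itos = list(specials)
    # Most frequent tokens first; ties broken by token for determinism
    for tok, freq in sorted(counts.items(), key=lambda kv: (-kv[1], kv[0])):
        if freq >= min_freq:
            itos.append(tok)
    stoi = {tok: i for i, tok in enumerate(itos)}
    return stoi, itos

def numericalize(tokens, stoi):
    """Convert tokens to indices; unseen tokens map to <unk>."""
    unk = stoi["<unk>"]
    return [stoi.get(tok, unk) for tok in tokens]

def pad_batch(sequences, pad_id):
    """Right-pad every sequence to the length of the longest one."""
    max_len = max(len(seq) for seq in sequences)
    return [seq + [pad_id] * (max_len - len(seq)) for seq in sequences]

# Toy pre-tokenized sentences; in practice these would come from a tokenizer such as jieba
corpus = [["我", "喜欢", "电影"], ["电影", "很", "好看"], ["我", "不", "喜欢"]]
stoi, itos = build_vocab(corpus)

new_sentences = [["我", "喜欢"], ["电影", "很", "无聊"]]  # "无聊" is not in the vocabulary
batch = [numericalize(sent, stoi) for sent in new_sentences]
padded = pad_batch(batch, stoi["<pad>"])
print(padded)
```

The padded lists all have the same length, so they can be stacked into a single tensor. Note that this sketch produces batch-major rows, whereas the model above expects input of shape `[seq_len, batch]`, so a transpose would be needed before feeding it in.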
Next, define the model hyperparameters and optimizer, then train and evaluate:
```python
vocab_size = len(TEXT.vocab)
embedding_dim = 100  # assumed to match the dimension of the vectors in vec.txt
hidden_dim = 256
output_dim = len(LABEL.vocab)

model = TextClassifier(vocab_size, embedding_dim, hidden_dim, output_dim)
# Initialize the embedding layer with the pretrained vectors loaded by build_vocab
model.embedding.weight.data.copy_(TEXT.vocab.vectors)
optimizer = optim.Adam(model.parameters())
criterion = nn.CrossEntropyLoss()

# Training
model.train()
for epoch in range(10):
    for batch in train_iterator:
        text, label = batch.text, batch.label
        optimizer.zero_grad()
        output = model(text)
        loss = criterion(output, label)
        loss.backward()
        optimizer.step()

# Evaluation: switch to eval mode and disable gradient tracking
model.eval()
correct = 0
total = 0
with torch.no_grad():
    for batch in test_iterator:
        text, label = batch.text, batch.label
        output = model(text)
        predicted = torch.argmax(output, dim=1)
        correct += (predicted == label).sum().item()
        total += label.size(0)
accuracy = correct / total
print(f'Accuracy: {accuracy:.4f}')
```
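If accuracy needs to be checked more than once, for example after every epoch, the evaluation loop can be wrapped in a function. The sketch below is one possible way to do that. The `evaluate` helper and the `FirstTokenModel` stand-in are hypothetical names introduced here so that the block runs on its own without a dataset.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

def evaluate(model, batches):
    """Return accuracy over an iterable of (text, label) tensor pairs."""
    was_training = model.training
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():  # no gradients are needed for evaluation
        for text, label in batches:
            predicted = model(text).argmax(dim=1)
            correct += (predicted == label).sum().item()
            total += label.size(0)
    model.train(was_training)  # restore the previous mode
    return correct / total if total else 0.0

class FirstTokenModel(nn.Module):
    """Hypothetical stand-in model: predicts the first token id as the class."""
    def __init__(self, num_classes):
        super().__init__()
        self.num_classes = num_classes

    def forward(self, text):
        # text: [seq_len, batch] -> one-hot "logits" of shape [batch, num_classes]
        return F.one_hot(text[0], self.num_classes).float()

text = torch.tensor([[0, 1, 2, 1]])  # seq_len = 1, batch = 4
label = torch.tensor([0, 1, 2, 0])   # the last example is misclassified
acc = evaluate(FirstTokenModel(3), [(text, label)])
print(f'Accuracy: {acc:.4f}')  # Accuracy: 0.7500
```

With the `torchtext` iterator from the example, the batches could be passed as `((b.text, b.label) for b in test_iterator)`.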
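That covers the basic workflow for Chinese text classification with PyTorch: load and preprocess the data, build the model, then train and evaluate it. This is only a simple example; adjust and extend it to suit your own needs.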