GRU pytorch 文本分类
时间: 2023-10-21 12:28:48 浏览: 145
GRU(Gated Recurrent Unit)是一种循环神经网络(RNN)的变体,用于处理序列数据。在PyTorch中,你可以使用`torch.nn.GRU`类来创建GRU模型进行文本分类。
下面是一个简单的示例代码,演示如何使用GRU进行文本分类:
```python
import torch
import torch.nn as nn
class GRUClassifier(nn.Module):
def __init__(self, input_size, hidden_size, num_classes):
super(GRUClassifier, self).__init__()
self.hidden_size = hidden_size
self.gru = nn.GRU(input_size, hidden_size, batch_first=True)
self.fc = nn.Linear(hidden_size, num_classes)
def forward(self, x):
h0 = torch.zeros(1, x.size(0), self.hidden_size).to(x.device) # 初始化隐藏状态
out, _ = self.gru(x, h0) # 前向传播
out = self.fc(out[:, -1, :]) # 取最后一个时间步的输出
return out
# 定义模型参数
input_size = 100 # 输入特征大小
hidden_size = 128 # 隐藏层大小
num_classes = 2 # 分类类别数
# 创建一个GRU分类器模型
model = GRUClassifier(input_size, hidden_size, num_classes)
# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
# 进行模型训练
for epoch in range(num_epochs):
# 前向传播
outputs = model(inputs)
loss = criterion(outputs, labels)
# 反向传播和优化
optimizer.zero_grad()
loss.backward()
optimizer.step()
# 使用训练好的模型进行预测
predicted = model(inputs)
```
以上代码中,我们首先定义了一个名为`GRUClassifier`的自定义模型类,该类继承自`nn.Module`。在模型的构造函数中,我们定义了GRU层和全连接层。在前向传播函数中,我们将输入数据传递给GRU层,然后取最后一个时间步的输出作为分类器的输入,并返回预测结果。
你可以根据自己的数据集和任务需求调整模型的各个参数。希望这个示例能对你有所帮助!如果有任何问题,请随时提问。
阅读全文