pytorch实现gru分类
时间: 2023-07-07 20:46:28 浏览: 43
首先,需要导入 PyTorch 和其他必要的库。
``` python
import torch
import torch.nn as nn
import torch.optim as optim
```
接下来,需要定义 GRU 模型的结构。这里使用两层 GRU,每一层都有 128 个隐藏节点。
``` python
class GRU(nn.Module):
def __init__(self, input_size, hidden_size, output_size):
super(GRU, self).__init__()
self.hidden_size = hidden_size
self.gru = nn.GRU(input_size, hidden_size, num_layers=2, batch_first=True)
self.fc = nn.Linear(hidden_size, output_size)
def forward(self, x):
h0 = torch.zeros(2, x.size(0), self.hidden_size).to(device)
out, _ = self.gru(x, h0)
out = self.fc(out[:, -1, :])
return out
```
其中,`input_size` 是输入特征的维度,`hidden_size` 是隐藏层的维度,`output_size` 是输出的类别数。在 `__init__` 方法中,首先调用 `super` 方法初始化父类,并定义了 GRU 层和全连接层。在 `forward` 方法中,首先初始化隐藏状态 `h0`,然后将输入 `x` 和 `h0` 作为 GRU 的输入,得到 GRU 的输出 `out`。最后,将 `out` 的最后一个时间步的输出通过全连接层得到分类结果。
接下来,需要定义训练过程。首先,定义损失函数和优化器。
``` python
model = GRU(input_size, hidden_size, output_size).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters())
```
然后,进行训练。
``` python
for epoch in range(num_epochs):
for inputs, labels in dataloader:
inputs, labels = inputs.to(device), labels.to(device)
optimizer.zero_grad()
outputs = model(inputs)
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()
# 每训练 10 个 epoch 输出一次损失
if (epoch+1) % 10 == 0:
print('Epoch [{}/{}], Loss: {:.4f}'.format(epoch+1, num_epochs, loss.item()))
```
这里使用了 PyTorch 内置的数据加载器 `dataloader`,每次从中加载一个 batch 的数据进行训练。在训练过程中,首先将输入和标签放到 GPU 上,然后将梯度清零,计算输出和损失,反向传播更新参数。
最后,可以使用训练好的模型进行预测。
``` python
with torch.no_grad():
correct = 0
total = 0
for inputs, labels in test_dataloader:
inputs, labels = inputs.to(device), labels.to(device)
outputs = model(inputs)
_, predicted = torch.max(outputs.data, 1)
total += labels.size(0)
correct += (predicted == labels).sum().item()
print('Accuracy of the model on the test data: {} %'.format(100 * correct / total))
```
其中,`test_dataloader` 是用于测试的数据加载器。在测试过程中,首先将输入和标签放到 GPU 上,然后计算输出,并根据输出得到预测结果。最后,计算准确率并输出。
相关推荐
![zip](https://img-home.csdnimg.cn/images/20210720083736.png)
![pdf](https://img-home.csdnimg.cn/images/20210720083512.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![zip](https://img-home.csdnimg.cn/images/20210720083736.png)
![zip](https://img-home.csdnimg.cn/images/20210720083736.png)