循环神经网络 基础案例 pytoch
时间: 2024-09-24 13:26:58 浏览: 10
循环神经网络 (Recurrent Neural Network, RNN) 是一种深度学习模型,特别适合处理序列数据,如文本、语音或时间序列数据。它通过引入循环连接结构,使得网络能够在处理每个输入时保留之前的信息,这种特性被称为“记忆”。
基础的PyTorch实现RNN的一个例子通常会涉及LSTM (Long Short-Term Memory) 或 GRU (Gated Recurrent Unit),这两种是改进的RNN,能更好地解决长期依赖问题。下面是一个简单的LSTM基本案例:
```python
import torch
from torch import nn
# 定义一个简单的LSTM模型
class SimpleRNN(nn.Module):
def __init__(self, input_size, hidden_size, num_layers=1, output_size=1):
super(SimpleRNN, self).__init__()
self.rnn = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)
self.fc = nn.Linear(hidden_size, output_size)
def forward(self, x):
# x: [batch_size, sequence_length, input_size]
out, _ = self.rnn(x)
# out: [batch_size, sequence_length, hidden_size]
# 我们通常只取最后一个时间步的隐藏状态作为输出
out = self.fc(out[:, -1, :]) # shape: [batch_size, output_size]
return out
# 使用示例
input_size = 10
hidden_size = 20
model = SimpleRNN(input_size, hidden_size)
input_data = torch.randn(64, 50, input_size) # 假设我们有64个样本,每个样本50个时间步,每个时间步10维输入
output = model(input_data)
```
在这个案例中,`input_data`是序列数据,模型会对每个时间步应用LSTM层,并在最后通过全连接层得到输出。