基于pytorch对给定表格数据利用LSTM模型进行时间序列分类以后,对数据进行预测分类的代码
时间: 2024-05-01 10:23:25 浏览: 148
基于pytorch搭建多特征LSTM时间序列预测.zip
5星 · 资源好评率100%
以下是一个基于PyTorch的LSTM模型,用于时间序列分类和预测的代码。请注意,此代码仅供参考,您需要根据自己的数据和问题进行适当的修改和调整。
```python
import torch
import torch.nn as nn
import numpy as np
# 定义LSTM模型
class LSTMClassifier(nn.Module):
def __init__(self, input_dim, hidden_dim, output_dim):
super(LSTMClassifier, self).__init__()
self.hidden_dim = hidden_dim
self.lstm = nn.LSTM(input_dim, hidden_dim, batch_first=True)
self.fc = nn.Linear(hidden_dim, output_dim)
def forward(self, x):
lstm_out, _ = self.lstm(x)
out = self.fc(lstm_out[:, -1, :])
return out
# 定义训练函数
def train(model, X_train, y_train, X_val, y_val, num_epochs, learning_rate):
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
train_loss = []
val_loss = []
for epoch in range(num_epochs):
# 训练模型
model.train()
optimizer.zero_grad()
output = model(X_train)
loss = criterion(output, y_train)
loss.backward()
optimizer.step()
train_loss.append(loss.item())
# 验证模型
model.eval()
output = model(X_val)
loss = criterion(output, y_val)
val_loss.append(loss.item())
# 打印训练过程
if (epoch+1) % 10 == 0:
print(f'Epoch [{epoch+1}/{num_epochs}], Train Loss: {train_loss[-1]:.4f}, Val Loss: {val_loss[-1]:.4f}')
return model, train_loss, val_loss
# 定义预测函数
def predict(model, X_test):
model.eval()
with torch.no_grad():
output = model(X_test)
_, predicted = torch.max(output.data, 1)
return predicted
# 生成示例数据
X_train = np.random.rand(100, 10, 5)
y_train = np.random.randint(0, 5, size=(100,))
X_val = np.random.rand(20, 10, 5)
y_val = np.random.randint(0, 5, size=(20,))
X_test = np.random.rand(10, 10, 5)
# 定义模型参数
input_dim = X_train.shape[2]
hidden_dim = 32
output_dim = len(np.unique(y_train))
# 训练模型并进行预测
model = LSTMClassifier(input_dim, hidden_dim, output_dim)
model, train_loss, val_loss = train(model, torch.FloatTensor(X_train), torch.LongTensor(y_train),
torch.FloatTensor(X_val), torch.LongTensor(y_val), 50, 0.01)
predicted = predict(model, torch.FloatTensor(X_test))
print(predicted)
```
在上面的代码中,我们定义了一个LSTMClassifier类作为我们的模型,并使用CrossEntropyLoss作为损失函数,Adam优化器作为优化器。我们还定义了train函数用于训练模型,并返回训练后的模型和训练/验证损失。最后,我们定义了predict函数用于使用训练好的模型进行预测。
在生成示例数据后,我们传递数据和模型参数来训练模型并进行预测。请注意,我们使用了PyTorch张量而不是NumPy数组来处理数据。
阅读全文