使用PyTorch构建一个简单的LSTM,并读取CSV文件夹形式的数据集
时间: 2024-04-08 16:31:24 浏览: 19
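可以使用PyTorch生态中的torchtext库来读取CSV文件夹形式的数据集,并构建一个简单的LSTM模型。注意:下面用到的`Field`、`TabularDataset`、`Iterator`属于torchtext的旧版API(torchtext 0.8及更早版本位于`torchtext.data`,0.9–0.11版本移至`torchtext.legacy.data`,0.12起已被移除),请先确认所安装的版本。以下是一个示例代码: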
```python
import torch
import torch.nn as nn
from torchtext.data import Field, TabularDataset, Iterator
# Describe how each CSV column is processed: the text column is lower-cased
# and tokenized with spaCy; the label column is a single non-sequential value
# that is used without a vocabulary lookup.
TEXT = Field(tokenize='spacy', lower=True, sequential=True)
LABEL = Field(use_vocab=False, sequential=False)

# Load the train/test splits from the CSV folder. The `fields` list is applied
# to the CSV columns in order: first column -> `text`, second -> `label`.
# NOTE(review): skip_header is not passed, so the first row of each file is
# presumably read as data -- confirm the CSV files have no header row.
train_data, test_data = TabularDataset.splits(
    path='path_to_csv_folder',
    format='csv',
    train='train.csv',
    test='test.csv',
    fields=[('text', TEXT), ('label', LABEL)],
)

# Build the token vocabulary from the training split only.
TEXT.build_vocab(train_data)

# Wrap both splits in iterators that yield batches of 64 examples.
train_iterator, test_iterator = Iterator.splits(
    (train_data, test_data),
    batch_size=64,
    sort_within_batch=False,
    sort_key=lambda example: len(example.text),
)
# LSTM-based sequence classifier.
class LSTMModel(nn.Module):
    """Embedding -> single-layer LSTM -> linear layer.

    Args:
        input_dim: Number of rows in the embedding table (vocabulary size).
        embedding_dim: Size of each token embedding vector.
        hidden_dim: Size of the LSTM hidden state.
        output_dim: Number of values produced per sequence (e.g. classes).
    """

    def __init__(self, input_dim, embedding_dim, hidden_dim, output_dim):
        super().__init__()
        self.embedding = nn.Embedding(input_dim, embedding_dim)
        self.lstm = nn.LSTM(embedding_dim, hidden_dim)
        self.fc = nn.Linear(hidden_dim, output_dim)

    def forward(self, text):
        """Map a batch of token-index sequences to one score vector each.

        Args:
            text: Integer tensor of token indices. The LSTM is created with
                the default ``batch_first=False``, so the expected layout is
                ``(seq_len, batch)``.

        Returns:
            Tensor of shape ``(batch, output_dim)`` computed from the final
            hidden state of the LSTM.
        """
        embedded = self.embedding(text)
        _, (hidden, _) = self.lstm(embedded)
        # `hidden` is (num_layers, batch, hidden_dim); take the last layer's
        # final state. With the single layer used here this is the same
        # tensor that `hidden.squeeze(0)` would give.
        return self.fc(hidden[-1])
# Build the model, the loss function and the optimizer.
# The embedding table needs one row per vocabulary entry.
input_dim = len(TEXT.vocab)
embedding_dim, hidden_dim, output_dim = 100, 256, 2
model = LSTMModel(
    input_dim=input_dim,
    embedding_dim=embedding_dim,
    hidden_dim=hidden_dim,
    output_dim=output_dim,
)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
# Train for 10 epochs; after each epoch print the loss and accuracy averaged
# over the batches of that epoch.
num_batches = len(train_iterator)  # constant across epochs, so computed once
for epoch in range(10):
    epoch_loss = 0
    epoch_acc = 0
    model.train()
    for batch in train_iterator:
        optimizer.zero_grad()
        text, labels = batch.text, batch.label
        predictions = model(text)
        loss = criterion(predictions, labels)
        # Fraction of examples in this batch whose highest-scoring class
        # equals the label.
        acc = (torch.argmax(predictions, dim=1) == labels).float().mean()
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
        epoch_acc += acc.item()
    print(f'Epoch: {epoch+1}, Loss: {epoch_loss/num_batches:.4f}, Acc: {epoch_acc/num_batches:.4f}')
```
请将代码中的`path_to_csv_folder`替换为你的CSV文件夹的路径,该文件夹下需要包含`train.csv`和`test.csv`两个文件。这个示例代码假设每个CSV文件包含两列:第一列是文本数据,第二列是标签;由于标签字段设置了`use_vocab=False`且`output_dim = 2`,标签应为整数0或1。你可以根据你的数据集的结构进行相应的修改。
这段代码会训练一个简单的LSTM模型来对文本进行分类。你可以根据自己的需求进行更改和扩展。