用 PyTorch 写一段代码，实现 3 层 LSTM。
时间: 2023-04-08 10:04:13 浏览: 63
可以的，以下是用 PyTorch 实现 3 层 LSTM 的代码（层数由 num_layers 参数控制，示例中设为 3）：
import torch
import torch.nn as nn
class LSTM(nn.Module):
    """Stacked (multi-layer) LSTM wrapper around ``nn.LSTM`` with batch-first I/O.

    Args:
        input_size: Number of features in each input time step.
        hidden_size: Number of features in each layer's hidden state.
        num_layers: Number of stacked LSTM layers.
        dropout: Dropout probability applied by ``nn.LSTM`` to the outputs of
            every layer except the last (only active in training mode).
            Defaults to 0.0, i.e. no dropout, matching the previous behavior.
    """

    def __init__(self, input_size, hidden_size, num_layers, dropout=0.0):
        super().__init__()
        # Kept as public attributes so existing callers can still read them.
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        # batch_first=True: batched inputs/outputs are (batch, seq, feature).
        self.lstm = nn.LSTM(
            input_size,
            hidden_size,
            num_layers,
            batch_first=True,
            dropout=dropout,
        )

    def forward(self, x):
        """Run the stacked LSTM over ``x`` and return the top layer's outputs.

        Args:
            x: Tensor of shape ``(batch, seq_len, input_size)``, or an
                unbatched ``(seq_len, input_size)`` tensor.

        Returns:
            Tensor of shape ``(batch, seq_len, hidden_size)`` (or
            ``(seq_len, hidden_size)`` for unbatched input) holding the last
            layer's hidden state at every time step. The final (h_n, c_n)
            state is discarded.
        """
        # No initial state is passed. nn.LSTM then starts from zeros that it
        # allocates with x's device AND dtype, and with the right rank for
        # batched or unbatched input. Hand-built torch.zeros(...) would be
        # float32 and always 3-D, which breaks e.g. .double() models and
        # 2-D (unbatched) inputs.
        out, _ = self.lstm(x)
        return out
# Example usage: push a random batch through a 3-layer LSTM and print the
# resulting output shape.
input_size, hidden_size, num_layers = 10, 20, 3
seq_len, batch_size = 5, 2

x = torch.randn((batch_size, seq_len, input_size))
lstm = LSTM(input_size=input_size, hidden_size=hidden_size, num_layers=num_layers)
out = lstm(x)

# Expected: (batch_size, seq_len, hidden_size), i.e. torch.Size([2, 5, 20]).
print(out.shape)