优化这段代码:import torch.nn as nnclass MyLinear(nn.Module): def __init__(self, in_features, out_features): super(MyLinear, self).__init__() self.linears = nn.ModuleList([nn.Linear(in_features, out_features) for i in range(5)]) def forward(self, x): outputs = [] for i in range(5): outputs.append(self.linears[i](x[:,i,:])) return torch.stack(outputs, dim=1)
时间: 2024-03-13 19:46:35 浏览: 92
可以使用 PyTorch 内置的函数 `torch.chunk()` 将输入张量沿着指定的维分成多个张量然后将它们分别输入到每个线性层中,最后再将它们沿着新的维度拼接在一起。这样就可以避免使用循环和列表,从而提高代码的效率。
下面是优化后的代码:
```
import torch.nn as nn
class MyLinear(nn.Module):
def __init__(self, in_features, out_features):
super(MyLinear, self).__init__()
self.linears = nn.ModuleList([nn.Linear(in_features, out_features) for i in range(5)])
def forward(self, x):
chunks = torch.chunk(x, 5, dim=1)
outputs = [linear(chunk) for linear, chunk in zip(self.linears, chunks)]
return torch.cat(outputs, dim=1)
```
此代码使用 `torch.chunk()` 将输入张量 `x` 沿着维度 `1` 分成 $5$ 个张量,然后使用列表推导式将每个张量分别输入到每个线性层中,得到 $5$ 个输出张量,最后使用 `torch.cat()` 将它们沿着新的维度 `1` 拼接在一起,得到最终的输出张量。
阅读全文