代码为import torch import torch.nn as nn class STAE(nn.Module): def __init__(self): super(STAE, self).__init__() self.c1 = nn.Sequential( nn.Conv1d(1, 64, kernel_size=3, stride=1, padding=1), nn.BatchNorm1d(64), nn.MaxPool1d(kernel_size=2, stride=2), nn.Conv1d(64, 128, kernel_size=3, stride=1, padding=1), nn.BatchNorm1d(128), nn.MaxPool1d(kernel_size=2, stride=2), nn.Conv1d(128, 128, kernel_size=3, stride=1, padding=1), nn.BatchNorm1d(128), nn.MaxPool1d(kernel_size=2, stride=2), nn.Conv1d(128, 256, kernel_size=3, stride=1, padding=0), nn.BatchNorm1d(256), nn.MaxPool1d(kernel_size=2, stride=2), nn.Conv1d(256, 256, kernel_size=3, stride=1, padding=1), nn.BatchNorm1d(256), nn.MaxPool1d(kernel_size=2, stride=2), )#31*256 长*通道数(原先1000*1) self.r1 = nn.Sequential( nn.LSTM(input_size=30, hidden_size=64, batch_first=True), ) self.l1 = nn.Linear(64, 1) def forward(self,x): x=self.c1(x) x=self.r1(x) return x if __name__ == '__main__': a=torch.ones(1,1,1000) net=STAE() output=net(a) print(torch.Size(a))。代码出错Traceback (most recent call last): File "D:\Program Files\JetBrains\PyCharm 2023.1\PycharmProject\test\STAE.py", line 38, in <module> print(torch.Size(a)) TypeError: torch.Size() takes an iterable of 'int' (item 0 is 'Tensor')。请问如何修改
时间: 2023-09-28 20:12:05 浏览: 92
PyTorch里面的torch.nn.Parameter()详解
3星 · 编辑精心推荐
你需要修改代码中的print(torch.Size(a))这行代码,将其修改为print(a.size())即可。因为torch.Size()函数需要传入一个整数类型的可迭代对象,而a是一个Tensor类型的对象,所以会抛出TypeError异常。而a.size()函数可以直接返回a的形状信息,是一个torch.Size类型的对象,可以直接打印输出。
阅读全文