TypeError: Linear.__init__() got an unexpected keyword argument 'device'
时间: 2023-10-01 11:06:00 浏览: 214
这个错误通常是因为 PyTorch 版本过低导致的。在较旧的版本中,`nn.Linear` 类确实不支持 `device` 参数。您可以尝试更新 PyTorch 到最新版本,或者使用以下方法解决此问题:
```python
import torch.nn as nn
import torch
class MyLinear(nn.Linear):
def __init__(self, in_features, out_features, bias=True):
super().__init__(in_features, out_features, bias)
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def forward(self, input):
return super().forward(input.to(self.device))
```
这里我们创建了一个名为 `MyLinear` 的子类,它继承自 `nn.Linear`。在 `__init__` 方法中,我们首先调用 `super().__init__(in_features, out_features, bias)` 来初始化父类,然后手动设置了 `device` 属性,该属性根据 CUDA 是否可用而设置为 `"cuda"` 或 `"cpu"`。在 `forward` 方法中,我们首先将输入张量移动到我们指定的设备,然后调用父类的 `forward` 方法,最后返回输出张量。
阅读全文