Traceback (most recent call last): File "train.py", line 79, in <module> train(opt, data_loader, model, visualizer) File "train.py", line 27, in train ssimMetric = SSIM(results['Restored_Train'], results['Sharp_Train']) File "/home/chenantai/DeblurGAN-master/DeblurGAN-master4/util/metrics.py", line 22, in SSIM mu1 = F.conv2d(img1, window, padding = window_size/2, groups = channel) TypeError: conv2d(): argument 'input' (position 1) must be Tensor, not numpy.ndarray
时间: 2024-04-10 20:27:02 浏览: 92
linux 下 pip3 报错“ File “/usr/bin/pip3”, line 9, in from pip import mainImportError: canno”处理
5星 · 资源好评率100%
这个错误是因为在SSIM函数中的conv2d()函数调用中,输入的img1参数类型应为Tensor,而不是numpy.ndarray。你需要将img1转换为Tensor类型。
你可以使用torch.from_numpy()函数将numpy数组转换为Tensor。在调用conv2d()函数之前,将img1转换为Tensor类型,并且确保img1的形状和数据类型与使用时的要求相匹配。
例如,你可以这样修改代码:
```python
import torch
# ...
def SSIM(img1, img2):
# ...
img1_tensor = torch.from_numpy(img1).to(torch.float32)
# ...
mu1 = F.conv2d(img1_tensor, window, padding=window_size/2, groups=channel)
# ...
```
这样应该能解决这个TypeError。记得将其他相关变量也进行相应的类型转换。
阅读全文