torch.mean(x, dim=1, keepdim=true)
时间: 2023-04-30 16:05:54 浏览: 218
这是 PyTorch 的函数,用于计算张量 x 沿着维度 1 的平均值,并保持维度不变。即返回一个形状为 (1, x.shape[1]) 的张量,其中平均值沿着维度 1 方向排列。如果 keepdim 参数设置为 False,则返回一个形状为 (x.shape[1],) 的张量。
相关问题
def convert_otf2psf(otf, size): ker = torch.zeros(size).cuda() #psf = torch.fft.ifftn(otf, dim=(-3,-2,-1)) #psf = torch.stack((psf.real,psf.imag),-1) psf = torch.fft.ifft2(torch.complex(otf[...,0],otf[...,1]),dim=(-3,-2,-1)) psf = psf.real # circularly shift ksize = size[-1] centre = ksize//2 + 1 print(psf.size()) print(psf.type()) ker[:, :, (centre-1):, (centre-1):] = psf[:, :, :centre, :centre]#.mean(dim=1, keepdim=True) ker[:, :, (centre-1):, :(centre-1)] = psf[:, :, :centre, -(centre-1):]#.mean(dim=1, keepdim=True) ker[:, :, :(centre-1), (centre-1):] = psf[:, :, -(centre-1):, :centre]#.mean(dim=1, keepdim=True) ker[:, :, :(centre-1), :(centre-1)] = psf[:, :, -(centre-1):, -(centre-1):]#.mean(dim=1, keepdim=True) return ker这段代码报错, Warning: Casting complex values to real discards the imaginary part (function operator()),究竟是哪行代码的问题,该如何修改呢
这个警告信息来自于以下这行代码:
```
psf = psf.real
```
这行代码会将复数张量 `psf` 的实部提取出来,舍弃掉虚部。因此,如果你要保留复数张量的信息,可以不使用这行代码。
如果你需要将复数张量转换为实数张量,可以使用以下代码:
```
psf = torch.abs(psf)
```
这行代码会将复数张量 `psf` 的模长提取出来,转换为实数张量。
另外,如果你在使用 `torch.fft.ifft2()` 函数时遇到错误,可能是因为该函数的输入张量需要是一个 `torch.complex` 类型的张量。你可以使用以下代码将 `otf` 转换为 `torch.complex` 类型:
```
otf = torch.view_as_complex(otf)
```
这行代码会将 `otf` 转换为 `torch.complex` 类型,然后就可以将其作为 `torch.fft.ifft2()` 函数的输入了。
def convert_otf2psf(otf, size): ker = torch.zeros(size).cuda() psf = torch.irfft(otf, 3, onesided=False) # circularly shift ksize = size[-1] centre = ksize//2 + 1 ker[:, :, (centre-1):, (centre-1):] = psf[:, :, :centre, :centre]#.mean(dim=1, keepdim=True) ker[:, :, (centre-1):, :(centre-1)] = psf[:, :, :centre, -(centre-1):]#.mean(dim=1, keepdim=True) ker[:, :, :(centre-1), (centre-1):] = psf[:, :, -(centre-1):, :centre]#.mean(dim=1, keepdim=True) ker[:, :, :(centre-1), :(centre-1)] = psf[:, :, -(centre-1):, -(centre-1):]#.mean(dim=1, keepdim=True) return ker假设输入的otf是四维张量,那么 psf = torch.irfft(otf, 3, onesided=False)。在pytorch1.7版本之后报错,该如何修改可以保持整段代码不报错,输出维度没有问题
在pytorch1.7版本之后,`torch.irfft()`函数的参数顺序发生了改变,需要将`onesided`参数放到第一个位置。因此,可以将代码中的`torch.irfft(otf, 3, onesided=False)`修改为`torch.irfft(otf, onesided=False, signal_ndim=3)`,这样可以保持整段代码不报错,输出维度没有问题。
阅读全文