torch.square
时间: 2023-07-11 10:29:21 浏览: 338
PyTorch的torch.cat用法
5星 · 资源好评率100%
`torch.square()` 是 PyTorch 中的一个函数,用于计算张量元素的平方。它的参数是一个张量,返回值也是一个张量,张量的每个元素都是对输入张量相应元素的平方。
下面是 `torch.square()` 的示例代码:
```python
import torch
x = torch.tensor([1, 2, 3])
y = torch.square(x)
print(y) # tensor([1, 4, 9])
```
在上面的代码中,我们首先定义了一个形状为 `(3,)` 的张量 `x`,然后使用 `torch.square()` 函数对其进行平方运算,得到一个新的张量 `y`,它的每个元素都是对 `x` 相应元素的平方。
`torch.square()` 函数还可以用于计算两个张量之间的平方差,例如:
```python
import torch
x = torch.tensor([1, 2, 3])
y = torch.tensor([2, 3, 4])
z = torch.square(x - y)
print(z) # tensor([1, 1, 1])
```
在上面的代码中,我们定义了两个形状相同的张量 `x` 和 `y`,然后使用 `torch.square()` 函数计算它们之间的平方差,得到一个新的张量 `z`,它的每个元素都是对 `x` 和 `y` 相应元素之差的平方。
阅读全文