python Clamp
时间: 2024-01-06 15:25:50 浏览: 129
`torch.clamp()`函数是PyTorch中的一个函数,用于将输入张量的元素限制在指定范围内。它的语法如下:
```python
torch.clamp(input, min, max, out=None) → Tensor
```
其中,`input`是输入的张量,`min`和`max`分别是限制的下界和上界。函数会将`input`中小于`min`的元素替换为`min`,大于`max`的元素替换为`max`,而在`min`和`max`之间的元素保持不变。
以下是一个使用`torch.clamp()`函数的例子:
```python
import torch
a = torch.randint(low=0, high=10, size=(10, 1))
print(a)
b = torch.clamp(a, 3, 9)
print(b)
```
输出结果为:
```
tensor([[7],
[2],
[6],
[4],
[5],
[0],
[8],
[9],
[1],
[3]])
tensor([[7],
[3],
[6],
[4],
[5],
[3],
[8],
[9],
[3],
[3]])
```
在上面的例子中,我们首先生成了一个形状为(10, 1)的随机整数张量`a`,然后使用`torch.clamp()`函数将`a`中小于3的元素替换为3,大于9的元素替换为9,得到了张量`b`。
阅读全文