torch.clamp函数
时间: 2023-06-21 18:03:48 浏览: 211
torch.clamp函数是PyTorch中的一个函数,用于将张量的值限制在一个指定的范围内。具体来说,它将张量中的每个元素限制在[min, max]的范围内,小于min的元素被替换为min,大于max的元素被替换为max。其函数原型如下:
```
torch.clamp(input, min, max, out=None) -> Tensor
```
其中,input表示要进行限制的张量,min表示下限值,max表示上限值,out表示输出张量(可以为空)。如果out不为空,则将结果存储在out中,否则返回一个新的张量。
示例代码:
```python
import torch
a = torch.randn(3, 3)
print(a)
# 将张量a中的元素限制在[-1, 1]的范围内
b = torch.clamp(a, -1, 1)
print(b)
```
输出结果:
```
tensor([[ 0.3649, -0.7078, -0.2539],
[-0.5877, -0.3386, -0.5367],
[ 0.4279, 0.9133, -1.3270]])
tensor([[ 0.3649, -0.7078, -0.2539],
[-0.5877, -0.3386, -0.5367],
[ 0.4279, 0.9133, -1.0000]])
```
在上面的示例中,我们首先使用torch.randn函数生成了一个形状为(3, 3)的随机张量a,然后使用torch.clamp函数将a中的元素限制在[-1, 1]的范围内,得到了新的张量b。可以看到,b中的元素都在[-1, 1]之间。
阅读全文