torch.clamp()
Date: 2023-10-12 10:11:12
torch.clamp() is a PyTorch function that clamps every element of the input tensor into the range [min, max]. It takes the following arguments:
- input (Tensor): The input tensor to be clamped.
- min (Number or Tensor, optional): The lower bound. Elements smaller than min are replaced by min. If a tensor is passed, it must be broadcastable with the input tensor.
- max (Number or Tensor, optional): The upper bound. Elements larger than max are replaced by max. If a tensor is passed, it must be broadcastable with the input tensor. At least one of min and max must be given.
The function returns a new tensor with the same shape as the input tensor, in which each element x is replaced by min(max(x, min), max). If min is greater than max, every element is therefore set to max.
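To illustrate the tensor-bound case, here is a minimal sketch that passes per-column bounds of shape (3,) to a (2, 3) input, relying on broadcasting, and checks the result against the element-wise formula min(max(x, min), max) written with torch.maximum and torch.minimum. The variable names and values are illustrative only.

```python
import torch

# A 2x3 input; the bounds hold one value per column (shape (3,)),
# so they broadcast across both rows of x.
x = torch.tensor([[-2.0, 0.5, 4.0],
                  [1.0, -3.0, 0.0]])
lo = torch.tensor([-1.0, 0.0, 1.0])  # per-column lower bounds
hi = torch.tensor([0.5, 2.0, 3.0])   # per-column upper bounds

y = torch.clamp(x, min=lo, max=hi)

# Element-wise reference: raise each value to its lower bound,
# then cut it down to its upper bound.
ref = torch.minimum(torch.maximum(x, lo), hi)

print(y)
print(torch.equal(y, ref))  # True
```

Because the bounds broadcast, the output keeps the shape of x, and each column is clamped to its own range.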
For example, if we have a tensor x:
```
import torch

x = torch.tensor([-1.0, 0.0, 1.0, 2.0, 3.0])
```
We can clamp its values between 0 and 2 using:
```
clamped_x = torch.clamp(x, min=0, max=2)
```
The resulting tensor clamped_x will be:
```
tensor([0., 0., 1., 2., 2.])
```
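The sketch below reuses the same input and shows a few variations that follow from the argument descriptions above: clamping with only one bound, passing inverted bounds (min > max), and the tensor-method forms Tensor.clamp and the in-place Tensor.clamp_. The variable names are illustrative.

```python
import torch

x = torch.tensor([-1.0, 0.0, 1.0, 2.0, 3.0])

# Lower bound only: negative values become 0 (the same values as torch.relu(x)).
lower_only = torch.clamp(x, min=0)

# Upper bound only: values above 2 are cut down to 2.
upper_only = torch.clamp(x, max=2)

# Inverted bounds (min > max): every element ends up equal to max.
inverted = torch.clamp(x, min=2, max=1)

# Method form: x.clamp(...) returns a new tensor,
# while x.clamp_(...) modifies x itself in place.
same_as_example = x.clamp(0, 2)
x.clamp_(min=0, max=2)
```

The in-place variant avoids allocating a new tensor, but it overwrites the original values of x.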