transforms.Lambda
Date: 2024-05-01 07:17:04
`transforms.Lambda` is a torchvision transform that applies a user-defined function to its input. Its constructor takes a single argument, `lambd`, which must be callable. Calling the resulting transform on some input passes that input to the function and returns the function's result.
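The mechanism can be illustrated with a minimal sketch of a Lambda-style wrapper. The class name `LambdaSketch` is hypothetical and this is not torchvision's source. It only mirrors the idea described above: store the callable, reject anything that is not callable, and forward each input to it. Recent torchvision versions raise a `TypeError` for a non-callable argument in a similar way.

```python
# Illustrative re-implementation of the idea behind transforms.Lambda.
# NOT torchvision's source; LambdaSketch is a hypothetical name.
class LambdaSketch:
    def __init__(self, lambd):
        # Reject non-callables up front
        if not callable(lambd):
            raise TypeError(f"lambd should be callable, got {type(lambd).__name__}")
        self.lambd = lambd

    def __call__(self, x):
        # Forward the input to the stored callable and return its result
        return self.lambd(x)

    def __repr__(self):
        return f"{self.__class__.__name__}()"


double = LambdaSketch(lambda x: x * 2)
print(double(21))  # 42
```

The wrapper adds no behavior of its own, which is why any function with a one-argument signature can be dropped into a transform pipeline this way.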
The user-defined function can be any Python callable (a named function, a `lambda`, or an object that defines `__call__`) that accepts the input as its only argument and returns the transformed result. It can perform any operation on the input, such as arithmetic, preprocessing, data augmentation, or other custom logic.
Here is an example of using the transforms.Lambda transformation to apply a custom function to the input data:
```python
import torch
from torchvision import transforms

# Define a custom function to apply to the input data
def custom_transform(x):
    # Simple augmentation: add Gaussian noise with standard deviation 0.1
    x = x + 0.1 * torch.randn_like(x)
    return x

# Create a Lambda transformation with the custom function
lambda_transform = transforms.Lambda(custom_transform)

# Apply the Lambda transformation to some input data
# (here a batch of 3 RGB images of size 224x224)
input_data = torch.randn(3, 3, 224, 224)
output_data = lambda_transform(input_data)
```
In this example, `custom_transform` adds Gaussian noise with standard deviation 0.1 to its input. Wrapping it in `transforms.Lambda` produces the `lambda_transform` object, which is then called on `input_data`. The result, `output_data`, has the same shape as the input, with the noise added.