Data.TensorDataset
Date: 2024-05-06 12:19:50
`Data.TensorDataset` (that is, `torch.utils.data.TensorDataset`, often reached through `import torch.utils.data as Data`) is a PyTorch class that wraps one or more tensors as a dataset. It is a subclass of `Dataset`, and each tensor passed to it holds one data variable. For example, in an image classification task, one tensor may hold the images and another the corresponding labels.
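`TensorDataset` makes indexing straightforward: `dataset[i]` returns a tuple containing the `i`-th element of every wrapped tensor, and `len(dataset)` returns the size of the first dimension. Unlike many torchvision datasets, however, it does not accept a `transform` argument. Any preprocessing has to be applied to the tensors before they are wrapped, or inside a custom `Dataset` that wraps the `TensorDataset`.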
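The sketch below illustrates the indexing behavior on a tiny dataset. Because `TensorDataset` has no `transform` parameter, it also shows one way to apply a per-sample transform: a small wrapper class. `TransformedTensorDataset` and the scaling lambda are hypothetical names chosen for illustration, not part of PyTorch.

```python
import torch
from torch.utils.data import Dataset, TensorDataset

# Two tensors that share a first dimension of 5 samples
features = torch.arange(10, dtype=torch.float32).reshape(5, 2)
labels = torch.tensor([0, 1, 0, 1, 1])
dataset = TensorDataset(features, labels)

# len() reports the size of the first dimension
print(len(dataset))  # 5

# Indexing returns a tuple with one entry per wrapped tensor
x, y = dataset[2]
print(x, y)  # tensor([4., 5.]) tensor(0)


# Hypothetical wrapper (not a PyTorch class): TensorDataset has no
# transform argument, so the per-sample transform is applied here instead.
class TransformedTensorDataset(Dataset):
    def __init__(self, base, transform=None):
        self.base = base
        self.transform = transform

    def __len__(self):
        return len(self.base)

    def __getitem__(self, index):
        x, y = self.base[index]
        if self.transform is not None:
            x = self.transform(x)  # transform only the features, not the label
        return x, y


# Assumed example transform: divide features by their maximum value (9.0)
scaled = TransformedTensorDataset(dataset, transform=lambda t: t / 9.0)
print(scaled[2][0])
```

A wrapper keeps the original tensors untouched and applies the transform lazily on each access; for a fixed, deterministic transform it is usually simpler to transform the whole tensor once before constructing the `TensorDataset`.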
To create a `TensorDataset` object, pass one or more tensors to its constructor. All tensors must have the same size along the first dimension, which is interpreted as the number of samples in the dataset; their remaining dimensions may differ.
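A quick way to see the first-dimension requirement is to compare a valid construction with a mismatched one. This is a minimal sketch; it assumes the size check is implemented as an `assert` (as in the PyTorch versions I am aware of), so it surfaces as an `AssertionError` and would be skipped if Python runs with `-O`.

```python
import torch
from torch.utils.data import TensorDataset

inputs = torch.randn(100, 3, 32, 32)
targets = torch.randint(0, 10, (100,))

# Valid: both tensors have 100 samples; the other dimensions differ freely
ok = TensorDataset(inputs, targets)
print(len(ok))  # 100

# Invalid: 100 inputs but only 99 targets
bad_targets = torch.randint(0, 10, (99,))
try:
    TensorDataset(inputs, bad_targets)
    mismatch_detected = False
except AssertionError as err:
    # Assumption: the constructor checks sizes with an assert statement
    mismatch_detected = True
    print("rejected:", err)
```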
Here's an example of how to create a Data.TensorDataset object:
```python
import torch
from torch.utils.data import TensorDataset
# create two tensors for the input and target data
inputs = torch.randn(100, 3, 32, 32)
targets = torch.randint(0, 10, (100,))
# create a TensorDataset object
dataset = TensorDataset(inputs, targets)
```
In this example, `inputs` holds 100 random 3×32×32 image-shaped samples and `targets` holds 100 integer class labels in the range 0 to 9. Passing both tensors to the constructor produces `dataset`, where `dataset[i]` is the pair `(inputs[i], targets[i])`. The resulting object can be passed to `torch.utils.data.DataLoader` to produce (optionally shuffled) mini-batches for training or evaluating a model.
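The following sketch feeds the same dataset to a `DataLoader`. The batch size of 32 is an arbitrary choice for illustration. With 100 samples, the loader yields three full batches and one final batch of 4, because `drop_last` defaults to `False`.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

inputs = torch.randn(100, 3, 32, 32)
targets = torch.randint(0, 10, (100,))
dataset = TensorDataset(inputs, targets)

# batch_size=32 is an illustrative choice; shuffle reorders samples each epoch
loader = DataLoader(dataset, batch_size=32, shuffle=True)

batch_sizes = []
for batch_inputs, batch_targets in loader:
    # The default collate function stacks samples along a new first dimension,
    # so images arrive as (batch, 3, 32, 32) and labels as (batch,)
    batch_sizes.append(batch_inputs.shape[0])

print(len(loader))   # 4
print(batch_sizes)   # [32, 32, 32, 4]
```

Passing `drop_last=True` to `DataLoader` would discard the final partial batch, which some training loops prefer when they assume a constant batch size.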