flatten pytorch
Date: 2023-09-27 10:07:55
In PyTorch, you can flatten a tensor using the `view` method or the `flatten` function.
Using the `view` method:
```python
import torch
# Create a tensor with a shape of (batch_size, channels, height, width)
input_tensor = torch.randn(10, 3, 32, 32)
# Flatten the tensor using the view method
flattened_tensor = input_tensor.view(input_tensor.size(0), -1)
# Print the shape of the flattened tensor
print(flattened_tensor.shape)  # torch.Size([10, 3072])
```
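Here is a quick sketch of what the `-1` and the "view" in `view` mean in practice, reusing the shape from the example above. PyTorch infers the `-1` dimension from the remaining element count (3 * 32 * 32 = 3072). The result shares storage with the input instead of copying it, so a write through one tensor is visible in the other.

```python
import torch

input_tensor = torch.randn(10, 3, 32, 32)
flattened = input_tensor.view(input_tensor.size(0), -1)

# -1 is inferred from the remaining elements: 3 * 32 * 32 = 3072
print(flattened.shape)  # torch.Size([10, 3072])

# view() does not copy: both tensors point at the same memory
print(flattened.data_ptr() == input_tensor.data_ptr())  # True

# Writing through the flattened view also changes the original tensor
flattened[0, 0] = 42.0
print(input_tensor[0, 0, 0, 0].item())  # 42.0
```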
Using the `flatten` function:
```python
import torch
# Create a tensor with a shape of (batch_size, channels, height, width)
input_tensor = torch.randn(10, 3, 32, 32)
# Flatten the tensor using the flatten function
flattened_tensor = torch.flatten(input_tensor, start_dim=1)
# Print the shape of the flattened tensor
print(flattened_tensor.shape)  # torch.Size([10, 3072])
```
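`torch.flatten(input, start_dim=0, end_dim=-1)` merges every dimension from `start_dim` through `end_dim`, inclusive. With the defaults, it therefore collapses the whole tensor into one dimension. The sketch below uses a (10, 3, 32, 32) tensor again to show how the chosen range changes the output shape. The method form `Tensor.flatten` takes the same arguments.

```python
import torch

x = torch.randn(10, 3, 32, 32)

# Defaults start_dim=0, end_dim=-1: every dimension is merged into one
print(torch.flatten(x).shape)  # torch.Size([30720])

# start_dim=1: keep the batch dimension, merge channels, height and width
print(torch.flatten(x, start_dim=1).shape)  # torch.Size([10, 3072])

# start_dim=1, end_dim=2: merge only channels and height, keep width
print(torch.flatten(x, start_dim=1, end_dim=2).shape)  # torch.Size([10, 96, 32])

# The tensor method accepts the same arguments
print(x.flatten(1).shape)  # torch.Size([10, 3072])
```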
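For this input, both approaches produce the same tensor, with shape (batch_size, channels * height * width), i.e. `torch.Size([10, 3072])`. The batch dimension is preserved and the remaining dimensions are merged into one. The two are not interchangeable in every case, though. `view` never copies data, so it requires a memory layout compatible with the new shape and raises a `RuntimeError` otherwise (for example, on a permuted tensor). `torch.flatten` returns a view when possible and copies the data when it is not.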
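The following quick check shows the two approaches agreeing on a contiguous input and diverging on a non-contiguous one. Here the non-contiguous tensor comes from `permute`, which is just one of several operations that break contiguity. On that tensor, `view` raises a `RuntimeError`, while `torch.flatten` still succeeds by copying.

```python
import torch

x = torch.randn(10, 3, 32, 32)

# Contiguous input: both approaches give identical results
a = x.view(x.size(0), -1)
b = torch.flatten(x, start_dim=1)
print(torch.equal(a, b))  # True

# permute() reorders dimensions without moving data, so y is non-contiguous
y = x.permute(0, 2, 3, 1)
print(y.is_contiguous())  # False

# view() cannot express this layout as (10, 3072) without copying
view_failed = False
try:
    y.view(y.size(0), -1)
except RuntimeError as err:
    view_failed = True
    print("view failed:", err)

# torch.flatten still works; it copies the data when a view is impossible
z = torch.flatten(y, start_dim=1)
print(z.shape)  # torch.Size([10, 3072])

# Making the tensor contiguous first lets view() produce the same result
print(torch.equal(y.contiguous().view(y.size(0), -1), z))  # True
```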