pytorch基础使用
时间: 2024-09-18 12:04:44 浏览: 41
Testing out basic pytorch
PyTorch是一款开源的深度学习框架,它基于Python语言,并提供动态计算图的能力,这使得模型构建和调试相对直观。以下是PyTorch的一些基本使用:
1. **安装与导入**:首先需要安装PyTorch库,可以使用pip进行安装。导入`torch`库后,可以创建张量(类似于NumPy数组,但支持GPU运算)。
```python
import torch
```
2. **张量操作**:张量是PyTorch的核心数据结构,你可以对它们进行基本的数学运算、索引和切片等操作。
```python
x = torch.tensor([1, 2, 3]) # 创建一个一维张量
y = torch.tensor([[4, 5], [6, 7]]) # 创建一个二维张量
```
3. **自动求导**:PyTorch内置了autograd机制,用于自动计算梯度,这对于训练神经网络至关重要。
```python
z = x * y
grad = torch.autograd.grad(z.sum(), [x, y]) # 计算z关于x和y的梯度
```
4. **定义和训练模型**:通过定义nn.Module类来创建模型,然后设置优化器(如SGD或Adam)和损失函数(如交叉熵),进行前向传播、反向传播并更新权重。
```python
class Net(nn.Module):
def __init__(self): ... # 初始化模型
def forward(self, input): ... # 定义前向传播
model = Net()
optimizer = torch.optim.SGD(model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()
input_tensor = torch.randn(1, 3, 224, 224) # 假设输入图片
output = model(input_tensor)
loss = criterion(output, target) # 假设target是对应标签
loss.backward() # 反向传播
optimizer.step() # 更新权重
```
阅读全文