import torch.nn as nn
时间: 2024-09-25 17:06:24 浏览: 66
当你在Python编程中引入`torch.nn`模块时,你实际上是导入了PyTorch库中的神经网络(Neural Network)功能部分。`nn`是PyTorch神经网络模块的别名,它提供了一系列预定义的层(Layers)、激活函数(Activations)、损失函数(Loss Functions)以及优化器(Optimizers)等用于构建深度学习模型的基础组件。
在这个上下文中,`nn.Module`是一个基础类,所有的自定义神经网络层都应该继承自这个类。例如,你可以创建一个新的卷积层`Conv2d`或者全连接层`Linear`,它们都会自动包含训练时所需的前向传播和反向传播(Forward and Backward Propagation)机制。
常见的操作包括定义网络结构、设置超参数、初始化权重和偏置、以及定义计算流程(如`forward()`方法)。此外,`nn`模块还包含了像`MSELoss`, `CrossEntropyLoss`这样的常用损失函数,以及`Adam`, `SGD`等优化器。
```python
# 示例:定义一个简单的全连接网络
class MyModel(nn.Module):
def __init__(self, input_size, hidden_size, output_size):
super(MyModel, self).__init__()
self.fc1 = nn.Linear(input_size, hidden_size)
self.fc2 = nn.Linear(hidden_size, output_size)
def forward(self, x):
x = F.relu(self.fc1(x)) # 使用ReLU激活
x = self.fc2(x)
return x
model = MyModel(input_size=784, hidden_size=512, output_size=10) # 创建网络实例
```
阅读全文