pytorch 中常用深度学习模型
时间: 2025-01-03 09:40:48 浏览: 8
### PyTorch 中常用的深度学习模型及其应用
#### 卷积神经网络 (CNN)
卷积神经网络是一种专门设计用来处理具有网格结构的数据(如图像)的深层神经网络架构。这类模型通过局部连接和权值共享来减少参数数量并提高泛化能力。
```python
import torch.nn as nn
class CNNModel(nn.Module):
def __init__(self):
super(CNNModel, self).__init__()
self.conv_layer = nn.Sequential(
nn.Conv2d(1, 32, kernel_size=3),
nn.ReLU(),
nn.MaxPool2d(kernel_size=2)
)
self.fc_layer = nn.Linear(32 * 13 * 13, 10)
def forward(self, x):
out = self.conv_layer(x)
out = out.view(out.size(0), -1)
out = self.fc_layer(out)
return out
```
这种类型的模型广泛应用于图像分类、目标检测等领域[^1]。
#### 循环神经网络 (RNN)
循环神经网络适合于序列数据分析,因为它们可以记住之前的信息,并将其用于当前时间步上的预测。LSTM 和 GRU 是两种改进版 RNN 结构,在解决长期依赖方面表现更好。
```python
import torch.nn as nn
class LSTMModel(nn.Module):
def __init__(self, input_dim, hidden_dim, layer_dim, output_dim):
super(LSTMModel, self).__init__()
self.hidden_dim = hidden_dim
self.layer_dim = layer_dim
self.lstm = nn.LSTM(input_dim, hidden_dim, layer_dim, batch_first=True)
self.fc = nn.Linear(hidden_dim, output_dim)
def forward(self, x):
h0 = torch.zeros(self.layer_dim, x.size(0), self.hidden_dim).requires_grad_()
c0 = torch.zeros(self.layer_dim, x.size(0), self.hidden_dim).requires_grad_()
out, _ = self.lstm(x, (h0.detach(), c0.detach()))
out = self.fc(out[:, -1, :])
return out
```
这些模型常被用作自然语言处理任务中的基础构件,比如文本生成、情感分析等[^2]。
#### 变分自编码器 (VAE)
变分自编码器是一类无监督学习算法,旨在捕捉输入数据的概率分布特征。该框架由两部分组成:编码器负责将原始样本映射到潜在空间;解码器则尝试从未知分布中采样重建原图。
```python
import torch.nn.functional as F
class VAE(nn.Module):
def __init__(self):
super(VAE, self).__init__()
# Encoder layers
...
# Decoder layers
...
def encode(self, x):
mu, logvar = ... # Compute mean and variance from encoder outputs
return mu, logvar
def reparameterize(self, mu, logvar):
std = torch.exp(0.5*logvar)
eps = torch.randn_like(std)
return mu + eps*std
def decode(self, z):
recon_x = ... # Reconstruct image using decoder network
return recon_x
def forward(self, x):
mu, logvar = self.encode(x.view(-1, 784))
z = self.reparameterize(mu, logvar)
return self.decode(z), mu, logvar
```
此模型可用于降维可视化、异常检测等多种场景下[^3]。
阅读全文