CRNN torch
Date: 2024-12-27 11:23:12
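### Implementing a CRNN Model with PyTorch
CRNN (Convolutional Recurrent Neural Network) is an architecture for sequence recognition tasks that performs well in areas such as optical character recognition (OCR). The following shows how to build such a network with PyTorch.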
#### Building the Convolutional Layers
First define the CNN part, which extracts features from the image:
```python
import torch.nn as nn


class CRNN_CNN(nn.Module):
    """Convolutional feature extractor: image -> sequence of feature vectors."""

    def __init__(self, img_channel, output_size):
        # output_size: height of the final feature map; forward() requires 1.
        super(CRNN_CNN, self).__init__()
        self.cnn = nn.Sequential(
            nn.Conv2d(img_channel, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(True),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
            nn.ReLU(True),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(True),
            nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1),
            nn.ReLU(True),
            # Halve the height only; stride 1 along the width keeps the sequence long.
            nn.MaxPool2d((2, 2), (2, 1), (0, 1)),
            nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(True),
            nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1),
            nn.ReLU(True),
            nn.MaxPool2d((2, 2), (2, 1), (0, 1)),
            # Pool the height down to output_size and leave the width unchanged.
            nn.AdaptiveAvgPool2d(output_size=(output_size, None)),
        )

    def forward(self, input):
        conv = self.cnn(input)
        b, c, h, w = conv.size()
        assert h == 1, "the height of conv must be 1"
        conv = conv.squeeze(2)  # drop the height axis: [b, c, 1, w] -> [b, c, w]
        conv = conv.permute(2, 0, 1)  # [w, b, c] = [seq_len, batch_size, channels]
        return conv
```
This code builds a multi-layer CNN that takes an input image and outputs a sequence of feature vectors[^3].
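To see how the pooling layers shape that sequence, the sketch below traces the feature-map height and width through the four max-pooling layers using the standard output-size formula `floor((n + 2p - k) / s) + 1`. The helper names `pool_out` and `trace_cnn_shape` and the 32x100 input size are illustrative assumptions, not part of the original code.

```python
def pool_out(size, kernel, stride, padding=0):
    """Output length of one spatial axis after a pooling layer (dilation 1)."""
    return (size + 2 * padding - kernel) // stride + 1


def trace_cnn_shape(height, width):
    """Follow (height, width) through the four max-pooling layers of the CNN.

    The 3x3 convolutions use stride 1 and padding 1, so they leave the
    spatial size unchanged and are skipped here.
    """
    # Each entry: (kernel_h, kernel_w), (stride_h, stride_w), (pad_h, pad_w)
    pools = [
        ((2, 2), (2, 2), (0, 0)),
        ((2, 2), (2, 2), (0, 0)),
        ((2, 2), (2, 1), (0, 1)),
        ((2, 2), (2, 1), (0, 1)),
    ]
    shapes = [(height, width)]
    for (kh, kw), (sh, sw), (ph, pw) in pools:
        height = pool_out(height, kh, sh, ph)
        width = pool_out(width, kw, sw, pw)
        shapes.append((height, width))
    return shapes


if __name__ == "__main__":
    # Hypothetical 32x100 input image (height x width).
    for h, w in trace_cnn_shape(32, 100):
        print(h, w)
```

For a 32-pixel-high input, the height is 2 after the last max-pooling layer, so the final pooling step has to reduce it to 1 for the `assert h == 1` check in `forward` to pass. The width after pooling is the length of the sequence handed to the RNN.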
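#### Adding the Recurrent Neural Network (RNN) Component
Next comes the RNN layer, which usually uses bidirectional LSTM or GRU units to process the feature sequence extracted by the CNN: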
```python
class BidirectionalLSTM(nn.Module):
    def __init__(self, nIn, nHidden, nOut):
        super(BidirectionalLSTM, self).__init__()
        self.rnn = nn.LSTM(nIn, nHidden, bidirectional=True)
        # Forward and backward states are concatenated, hence nHidden * 2.
        self.embedding = nn.Linear(nHidden * 2, nOut)

    def forward(self, input):
        recurrent, _ = self.rnn(input)  # [T, b, nHidden * 2]
        T, b, h = recurrent.size()
        t_rec = recurrent.view(T * b, h)  # merge time and batch axes
        output = self.embedding(t_rec)  # [T * b, nOut]
        output = output.view(T, b, -1)  # restore [T, b, nOut]
        return output
```
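This part implements a bidirectional LSTM, which reads the sequence in both directions and can therefore capture dependencies along the time dimension better.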
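The `nHidden * 2` input size of the linear layer comes from `nn.LSTM` concatenating the forward and backward hidden states along the last axis when `bidirectional=True`. The following minimal sketch checks this on random data; the tensor sizes are arbitrary values chosen for illustration.

```python
import torch
import torch.nn as nn

# Illustrative sizes only; they are not taken from the article.
seq_len, batch, n_in, n_hidden = 6, 2, 8, 5

lstm = nn.LSTM(n_in, n_hidden, bidirectional=True)
features = torch.randn(seq_len, batch, n_in)  # [seq_len, batch, input_size]
output, _ = lstm(features)

# The last axis holds the forward and backward states side by side.
print(tuple(output.shape))  # (6, 2, 10)
```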
#### The Complete CRNN Class
Finally, combine the two parts above into the complete CRNN model class:
```python
class CRNN(nn.Module):
    def __init__(self, imgH, nc, nclass, nh, leakyRelu=False):
        # imgH: input height, nc: input channels, nclass: number of classes,
        # nh: LSTM hidden size. leakyRelu is accepted but not used by this CNN.
        super(CRNN, self).__init__()
        assert imgH % 16 == 0, 'imgH has to be a multiple of 16'
        # imgH // 16 - 1 evaluates to 1 for imgH = 32.
        self.cnn = CRNN_CNN(nc, imgH // 16 - 1)
        self.rnn = nn.Sequential(
            BidirectionalLSTM(512, nh, nh),
            BidirectionalLSTM(nh, nh, nclass))

    def forward(self, input):
        # conv features: [seq_len, batch_size, 512]
        conv = self.cnn(input)
        # rnn features: [seq_len, batch_size, nclass]
        logits = self.rnn(conv)
        return logits
```
This code completes the CRNN: it combines the CNN and RNN components described above and sets their parameters.