resnet18一维
时间: 2025-01-09 14:53:38 浏览: 1
### 使用 PyTorch 和 TensorFlow 实现 ResNet18 对一维数据的应用
#### 一、PyTorch 中的一维 ResNet18 实现
对于处理一维信号的数据,如时间序列或音频信号,可以通过修改标准的二维卷积操作来适应一维输入。以下是使用 PyTorch 构建适用于一维数据的 ResNet18 的方法:
```python
import torch.nn as nn
from torchvision import models
class ResNet18_1D(nn.Module):
def __init__(input_channels=1, num_classes=10):
super().__init__()
# 加载预训练权重并调整首层卷积核大小以匹配一维输入
base_model = models.resnet18(pretrained=True)
self.conv1 = nn.Conv1d(
input_channels,
64,
kernel_size=7,
stride=2,
padding=3,
bias=False
)
# 替换原始模型的第一层卷积层
layers = list(base_model.children())[1:]
self.feature_extractor = nn.Sequential(*layers)
# 修改全连接层以适配新的类别数量
fc_in_features = base_model.fc.in_features
self.fc = nn.Linear(fc_in_features, num_classes)
def forward(x):
x = x.unsqueeze(-1).transpose(1, 2) # 调整维度顺序 (batch, channels, length)
x = self.conv1(x)
x = self.feature_extractor(x.squeeze(dim=-1))
x = x.view(x.size(0), -1)
out = self.fc(x)
return out
```
此代码片段展示了如何创建一个继承自 `nn.Module` 类的新类,并利用预训练好的 ResNet18 来初始化大部分网络结构,仅改变最开始的部分使其能够接受一维张量作为输入。
#### 二、TensorFlow/Keras 中的一维 ResNet18 实现
同样,在 TensorFlow 中也可以构建类似的架构用于处理一维数据。这里给出了一种方式来定义适合于一维输入的 ResNet18 模型:
```python
import tensorflow as tf
from tensorflow.keras.layers import Conv1D, BatchNormalization, ReLU, GlobalAveragePooling1D, Dense
from tensorflow.keras.models import Model
def residual_block_1d(filters, downsample=False, strides=(1,), name=None):
identity = inputs
if isinstance(strides, int):
strides = [strides]
conv_layer = Conv1D(
filters=filters,
kernel_size=3,
strides=strides[0],
padding='same',
use_bias=False,
name=name + '_conv'
)
bn_layer = BatchNormalization(name=name + '_bn')
shortcut_conv = None
if downsample or inputs.shape[-1] != filters:
shortcut_conv = Conv1D(
filters=filters,
kernel_size=1,
strides=strides[0],
padding='valid',
use_bias=False,
name=name + '_shortcut_conv'
)
shortcut_bn = BatchNormalization(name=name + '_shortcut_bn')
net = conv_layer(inputs)
net = bn_layer(net)
net = ReLU()(net)
net = Conv1D(
filters=filters,
kernel_size=3,
strides=1,
padding='same',
use_bias=False,
name=name + '_conv2'
)(net)
net = BatchNormalization(name=name + '_bn2')(net)
if shortcut_conv is not None:
identity = shortcut_conv(identity)
identity = shortcut_bn(identity)
output = tf.add(net, identity)
return ReLU()(output)
def build_resnet18_1d(input_shape, classes_num):
inputs = tf.keras.Input(shape=input_shape, name="inputs")
x = Conv1D(
filters=64,
kernel_size=7,
strides=2,
padding='same',
use_bias=False,
name='initial_conv'
)(inputs)
x = BatchNormalization(name='initial_bn')(x)
x = ReLU()(x)
x = MaxPool1D(pool_size=3, strides=2)(x)
block_filters = [64]*2 + [128]*2 + [256]*2 + [512]*2
for i, f in enumerate(block_filters):
downsampling = True if i % 2 == 0 and i > 0 else False
x = residual_block_1d(f, downsampling, name=f'block_{i//2}_{i%2}')(x)
x = GlobalAveragePooling1D()(x)
outputs = Dense(classes_num, activation='softmax', name='predictions')(x)
model = Model(inputs=inputs, outputs=outputs, name="ResNet18_1D")
return model
```
上述代码中定义了一个辅助函数 `residual_block_1d()` 来简化残差块的设计过程,并最终组合成完整的 ResNet18 结构[^1]。
阅读全文