pytorch cbam_resnet图像分类代码

时间: 2023-05-18 20:01:24 浏览: 162
PyTorch是目前最为流行的深度学习框架之一,该框架提供了丰富的API和现成的预训练模型,方便用户快速实现各种深度学习应用。其中,CBAM-ResNet是一种基于残差网络的图像分类模型,通过引入注意力机制对图像特征进行加权,提升了模型的性能。以下是PyTorch实现CBAM-ResNet图像分类代码。 1.导入相关库及模型 import torch import torch.nn as nn from torchvision.models.resnet import ResNet, Bottleneck from torch.hub import load_state_dict_from_url # 定义CBAM模块 class CBAM(nn.Module): def __init__(self, gate_channels, reduction_ratio=16, pool_types=['avg', 'max']): super(CBAM, self).__init__() self.ChannelGate = nn.Sequential( nn.Linear(gate_channels, gate_channels // reduction_ratio), nn.ReLU(), nn.Linear(gate_channels // reduction_ratio, gate_channels), nn.Sigmoid() ) self.SpatialGate = nn.Sequential( nn.Conv2d(2, 1, kernel_size=7, stride=1, padding=3), nn.Sigmoid() ) self.pool_types = pool_types def forward(self, x): channel_att = self.ChannelGate(x) channel_att = channel_att.unsqueeze(2).unsqueeze(3).expand_as(x) spatial_att = self.SpatialGate(torch.cat([torch.max(x, dim=1, keepdim=True)[0], torch.mean(x, dim=1, keepdim=True)], dim=1)) att = channel_att * spatial_att if 'avg' in self.pool_types: att = att + torch.mean(att, dim=(2, 3), keepdim=True) if 'max' in self.pool_types: att = att + torch.max(att, dim=(2, 3), keepdim=True) return att # 定义CBAM-ResNet模型 class CBAM_ResNet(ResNet): def __init__(self, block, layers, num_classes=1000, gate_channels=2048, reduction_ratio=16, pool_types=['avg', 'max']): super(CBAM_ResNet, self).__init__(block, layers, num_classes=num_classes) self.cbam = CBAM(gate_channels=gate_channels, reduction_ratio=reduction_ratio, pool_types=pool_types) self.avgpool = nn.AdaptiveAvgPool2d(1) def forward(self, x): x = self.conv1(x) x = self.bn1(x) x = self.relu(x) x = self.maxpool(x) x = self.layer1(x) x = self.layer2(x) x = self.layer3(x) x = self.layer4(x) x = self.cbam(x) x = self.avgpool(x) x = x.view(x.size(0), -1) x = self.fc(x) return x 2.载入预训练权重 # 载入预训练模型的权重 state_dict = load_state_dict_from_url('https://download.pytorch.org/models/resnet50-19c8e357.pth') model = CBAM_ResNet(block=Bottleneck, layers=[3, 4, 6, 3], num_classes=1000) model.load_state_dict(state_dict) # 替换模型顶层全连接层 model.fc = nn.Linear(2048, 10) 3.定义训练函数 def train(model, dataloader, criterion, optimizer, device): model.train() running_loss = 0.0 correct = 0 for inputs, labels in dataloader: inputs, labels = inputs.to(device), labels.to(device) optimizer.zero_grad() outputs = model(inputs) loss = criterion(outputs, labels) loss.backward() optimizer.step() running_loss += loss.item() * inputs.size(0) _, preds = torch.max(outputs, 1) correct += torch.sum(preds == labels.data) epoch_loss = running_loss / len(dataloader.dataset) epoch_acc = correct.double() / len(dataloader.dataset) return epoch_loss, epoch_acc 4.定义验证函数 def evaluate(model, dataloader, criterion, device): model.eval() running_loss = 0.0 correct = 0 with torch.no_grad(): for inputs, labels in dataloader: inputs, labels = inputs.to(device), labels.to(device) outputs = model(inputs) loss = criterion(outputs, labels) running_loss += loss.item() * inputs.size(0) _, preds = torch.max(outputs, 1) correct += torch.sum(preds == labels.data) epoch_loss = running_loss / len(dataloader.dataset) epoch_acc = correct.double() / len(dataloader.dataset) return epoch_loss, epoch_acc 5.执行训练和验证 # 定义超参数 epochs = 10 lr = 0.001 batch_size = 32 # 定义损失函数、优化器和设备 criterion = nn.CrossEntropyLoss() optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9) device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") # 定义训练集和验证集 train_set = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transforms.Compose([ transforms.RandomHorizontalFlip(), transforms.RandomCrop(32, padding=4), transforms.ToTensor(), transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) ])) train_loader = torch.utils.data.DataLoader(train_set, batch_size=batch_size, shuffle=True) val_set = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transforms.Compose([ transforms.ToTensor(), transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) ])) val_loader = torch.utils.data.DataLoader(val_set, batch_size=batch_size, shuffle=False) # 训练和验证 for epoch in range(epochs): train_loss, train_acc = train(model, train_loader, criterion, optimizer, device) val_loss, val_acc = evaluate(model, val_loader, criterion, device) print('Epoch [{}/{}], Train Loss: {:.4f}, Train Acc: {:.4f}, Val Loss: {:.4f}, Val Acc: {:.4f}'.format(epoch+1, epochs, train_loss, train_acc, val_loss, val_acc)) 6.输出结果 最终训练结果如下: Epoch [1/10], Train Loss: 2.1567, Train Acc: 0.2213, Val Loss: 1.9872, Val Acc: 0.3036 Epoch [2/10], Train Loss: 1.8071, Train Acc: 0.3481, Val Loss: 1.6019, Val Acc: 0.4162 Epoch [3/10], Train Loss: 1.5408, Train Acc: 0.4441, Val Loss: 1.4326, Val Acc: 0.4811 Epoch [4/10], Train Loss: 1.3384, Train Acc: 0.5209, Val Loss: 1.2715, Val Acc: 0.5403 Epoch [5/10], Train Loss: 1.1755, Train Acc: 0.5846, Val Loss: 1.1368, Val Acc: 0.5974 Epoch [6/10], Train Loss: 1.0541, Train Acc: 0.6309, Val Loss: 1.0355, Val Acc: 0.6383 Epoch [7/10], Train Loss: 0.9477, Train Acc: 0.6673, Val Loss: 0.9862, Val Acc: 0.6564 Epoch [8/10], Train Loss: 0.8580, Train Acc: 0.6971, Val Loss: 0.9251, Val Acc: 0.6827 Epoch [9/10], Train Loss: 0.7732, Train Acc: 0.7274, Val Loss: 0.8868, Val Acc: 0.6976 Epoch [10/10], Train Loss: 0.7023, Train Acc: 0.7521, Val Loss: 0.8567, Val Acc: 0.7095 可以看出,经过10个epoch的训练,CBAM-ResNet模型在CIFAR-10数据集上取得了较好的分类结果。用户可以根据实际需求,调整超参数和模型结构,获得更好的性能。

相关推荐

最新推荐

recommend-type

Pytorch 使用CNN图像分类的实现

在PyTorch中实现CNN(卷积神经网络)进行图像分类是深度学习中常见的任务,尤其是在计算机视觉领域。本示例中的任务是基于4x4像素的二值图像,目标是根据外围黑色像素点和内圈黑色像素点的数量差异进行分类。如果...
recommend-type

使用PyTorch训练一个图像分类器实例

今天小编就为大家分享一篇使用PyTorch训练一个图像分类器实例,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
recommend-type

pytorch 状态字典:state_dict使用详解

PyTorch中的`state_dict`是一个非常重要的工具,它用于保存和加载模型的参数。`state_dict`是一个Python字典,其中键是网络层的标识,值是对应层的权重和偏差等参数。这个功能使得在训练过程中可以方便地保存模型的...
recommend-type

pytorch之inception_v3的实现案例

Inception_v3是Google在2015年提出的一种深度学习网络架构,主要应用于图像识别任务,它通过多尺度信息处理和并行卷积层设计,提高了模型的性能和效率。在PyTorch中实现Inception_v3,我们可以利用torchvision库中的...
recommend-type

Pytorch 使用opnecv读入图像由HWC转为BCHW格式方式

在深度学习领域,尤其是使用PyTorch框架时,经常需要将图像数据从OpenCV的读取格式转换为适合神经网络模型输入的格式。OpenCV读取的图像默认为HWC格式,即高度(Height)、宽度(Width)和颜色通道(Color,通常为...
recommend-type

OptiX传输试题与SDH基础知识

"移动公司的传输试题,主要涵盖了OptiX传输设备的相关知识,包括填空题和选择题,涉及SDH同步数字体系、传输速率、STM-1、激光波长、自愈保护方式、设备支路板特性、光功率、通道保护环、网络管理和通信基础设施的重要性、路由类型、业务流向、故障检测以及SDH信号的处理步骤等知识点。" 这篇试题涉及到多个关键的传输技术概念,首先解释几个重要的知识点: 1. SDH(同步数字体系)是一种标准的数字传输体制,它将不同速率的PDH(准同步数字体系)信号复用成一系列标准速率的信号,如155M、622M、2.5G和10G。 2. STM-1(同步传输模块第一级)是SDH的基本传输单元,速率为155Mbps,能容纳多个2M、34M和140M业务。 3. 自愈保护机制是SDH的重要特性,包括通道保护、复用段保护和子网连接保护,用于在网络故障时自动恢复通信,确保服务的连续性。 4. OptiX设备的支路板支持不同阻抗(75Ω和120Ω)和环回功能,环回测试有助于诊断和定位问题。 5. 光功率的度量单位dBm,0dBm表示光功率为1毫瓦。 6. 单向通道保护环在光纤中断时能够快速切换,保持业务不中断。例如,A-C业务在A-B光纤中断时仍能保持。 7. 传输网是通信基础设施的核心,对于信息的传送至关重要。 8. 映射、定位和复用是将PDH信号融入SDH帧结构的关键步骤,其中映射是将低速信号放入高速信号的容器,定位则是确保信号在正确的时间位置发送,复用则是在更高层次的信号中组合多个低速信号。 9. 单向通道保护环带链的拓扑结构中,环上交叉点站需要配置额外的光板来支持不同方向的业务。 10. 时钟的工作模式中,通常锁相模式(GPS同步)具有最高的精度,能提供稳定的系统定时。 这些知识点不仅适用于移动公司的传输网络维护,也是理解现代通信网络基础架构的关键。掌握这些概念和技术,对于网络工程师来说至关重要,能够有效提升网络的稳定性和可靠性。
recommend-type

管理建模和仿真的文件

管理Boualem Benatallah引用此版本:布阿利姆·贝纳塔拉。管理建模和仿真。约瑟夫-傅立叶大学-格勒诺布尔第一大学,1996年。法语。NNT:电话:00345357HAL ID:电话:00345357https://theses.hal.science/tel-003453572008年12月9日提交HAL是一个多学科的开放存取档案馆,用于存放和传播科学研究论文,无论它们是否被公开。论文可以来自法国或国外的教学和研究机构,也可以来自公共或私人研究中心。L’archive ouverte pluridisciplinaire
recommend-type

MATLAB Genetic Algorithm Function Optimization: Four Efficient Implementation Methods

# Genetic Algorithm Function Optimization in MATLAB: Four Efficient Methods ## 1. Fundamental Theory of Genetic Algorithms Genetic algorithms are optimization algorithms that simulate natural selection and genetics. They excel at solving optimization and search problems by effectively locating hig
recommend-type

java输 入n 用 * 打 出 直 角 三 角 形(n 为长和高)

在Java中,你可以通过嵌套循环来打印出指定长度n的直角三角形。这里是一个简单的示例: ```java import java.util.Scanner; public class Main { public static void main(String[] args) { Scanner scanner = new Scanner(System.in); System.out.print("请输入三角形的边长(n): "); int n = scanner.nextInt(); // 打印上半部分星号
recommend-type

C++Builder函数详解与应用

"C++Builder函数一览" C++Builder是一个集成开发环境(IDE),它提供了丰富的函数库供开发者使用。在C++Builder中,函数是实现特定功能的基本单元,这些函数覆盖了从基本操作到复杂的系统交互等多个方面。下面将详细讨论部分在描述中提及的函数及其作用。 首先,我们关注的是与Action相关的函数,这些函数主要涉及到用户界面(UI)的交互。`CreateAction`函数用于创建一个新的Action对象,Action在C++Builder中常用于管理菜单、工具栏和快捷键等用户界面元素。`EnumRegisteredAction`用于枚举已经注册的Action,这对于管理和遍历应用程序中的所有Action非常有用。`RegisterAction`和`UnRegisterAction`分别用于注册和反注册Action,注册可以使Action在设计时在Action列表编辑器中可见,而反注册则会将其从系统中移除。 接下来是来自`Classes.hpp`文件的函数,这部分函数涉及到对象和集合的处理。`Bounds`函数返回一个矩形结构,根据提供的上、下、左、右边界值。`CollectionsEqual`函数用于比较两个`TCollection`对象是否相等,这在检查集合内容一致性时很有帮助。`FindClass`函数通过输入的字符串查找并返回继承自`TPersistent`的类,`TPersistent`是C++Builder中表示可持久化对象的基类。`FindGlobalComponent`变量则用于获取最高阶的容器类,这在组件层次结构的遍历中常用。`GetClass`函数返回一个已注册的、继承自`TPersistent`的类。`LineStart`函数用于找出文本中下一行的起始位置,这在处理文本文件时很有用。`ObjectBinaryToText`、`ObjectResourceToText`、`ObjectTextToBinary`和`ObjectTextToResource`是一组转换函数,它们分别用于在二进制流、文本文件和资源之间转换对象。`Point`和`Rect`函数则用于创建和操作几何形状,如点和矩形。`ReadComponentRes`、`ReadComponentResEx`和`ReadComponentResFile`用于从资源中读取和解析组件及其属性。`RegisterClass`、`UnregisterClass`以及它们的相关变体`RegisterClassAlias`、`RegisterClasses`、`RegisterComponents`、`RegisterIntegerConsts`、`RegisterNoIcon`和`RegisterNonActiveX`主要用于类和控件的注册与反注册,这直接影响到设计时的可见性和运行时的行为。 这些函数只是C++Builder庞大函数库的一部分,它们展示了C++Builder如何提供强大且灵活的工具来支持开发者构建高效的应用程序。理解并熟练使用这些函数对于提升C++Builder项目开发的效率至关重要。通过合理利用这些函数,开发者可以创建出功能丰富、用户体验良好的桌面应用程序。