使用cnn对桌面已有的名字为sleep的文件中的Excel传感器数据进行睡眠分期实验,使用pytorch并写出完整代码并标明注释

时间: 2025-03-17 08:18:51 浏览: 16

嗯,用户想用CNN和PyTorch对名为“sleep”的Excel文件中的传感器数据进行睡眠分期实验。首先,我需要确认数据结构和内容。假设Excel文件里有时间序列的传感器数据,比如加速度、心率等,以及对应的睡眠阶段标签,比如清醒、REM、浅睡、深睡等。

接下来,得考虑数据预处理。可能需要读取Excel文件,检查是否有缺失值,进行标准化或归一化。然后,将数据转换成适合CNN输入的格式,比如时间序列分段,可能用滑动窗口方法生成样本。

然后是数据集划分,分为训练集、验证集和测试集。需要确保数据分布均衡,特别是不同睡眠阶段的样本数量可能不均衡,可能需要采用分层抽样。

模型方面,使用CNN处理时间序列数据,可能需要一维卷积层。结构可能包括多个卷积层、池化层,最后接全连接层进行分类。激活函数选ReLU,输出层用Softmax。损失函数用交叉熵,优化器用Adam。

训练过程中要监控准确率和损失,可能需要早停法防止过拟合。还要考虑批处理大小、学习率等超参数调整。

代码部分需要包括数据加载、预处理、数据集创建、模型定义、训练循环和评估。注释要清晰,说明每一步的作用。可能还需要数据可视化,比如损失曲线和混淆矩阵。

需要注意的是,实际应用中可能需要调整网络结构,比如卷积核大小、层数,或者加入BatchNorm层和Dropout来提升性能。另外,数据增强方法也可能有用,比如添加噪声或时间扭曲,以增加模型的泛化能力。

可能遇到的挑战包括数据量不足、类别不平衡、特征提取是否充分。这时候可以考虑数据增强、调整类别权重,或者尝试其他模型如LSTM结合CNN。此外,确保数据预处理步骤正确,比如标准化方法是否适合传感器数据。

最后,测试模型时要看混淆矩阵,分析各个睡眠阶段的分类效果,找出模型可能存在的弱点,进一步优化。好的!我将按照步骤为您构建基于CNN的睡眠分期实验代码。以下是详细实现(代码包含详细注释):

"""
基于CNN的睡眠分期实验代码
环境要求:Python 3.8+ / PyTorch 1.12+ / pandas / scikit-learn
"""

# 〇、环境准备
import torch
import torch.nn as nn
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt
import numpy as np

# 一、数据预处理
# 假设sleep.xlsx结构:包含多列传感器数据和最后一列'sleep_stage'(0-4分类标签)
def load_data(file_path):
    """
    加载并预处理Excel数据
    参数:file_path - Excel文件路径
    返回:处理后的特征和标签
    """
    # 读取数据
    df = pd.read_excel(file_path)
    
    # 处理缺失值(根据实际情况选择填充方式)
    df = df.fillna(method='ffill')  # 前向填充
    
    # 分离特征和标签
    features = df.iloc[:, :-1].values  # 假设最后一列是标签
    labels = df.iloc[:, -1].values
    
    # 标准化特征
    scaler = StandardScaler()
    features = scaler.fit_transform(features)
    
    return features, labels

# 二、数据集构建
class SleepDataset(Dataset):
    """自定义睡眠数据集类"""
    def __init__(self, features, labels, window_size=30):
        """
        参数:
            features: 标准化后的特征数组
            labels: 对应标签数组
            window_size: 滑动窗口大小(时间步长)
        """
        self.features = []
        self.labels = []
        
        # 使用滑动窗口生成序列样本
        for i in range(len(features) - window_size):
            self.features.append(features[i:i+window_size])
            self.labels.append(labels[i+window_size])  # 预测当前窗口最后一个时间点的状态
            
    def __len__(self):
        return len(self.features)
    
    def __getitem__(self, idx):
        return (
            torch.FloatTensor(self.features[idx]).permute(1, 0),  # 转换为(channel, time_step)
            torch.LongTensor([self.labels[idx]])
        )

# 三、CNN模型定义
class SleepCNN(nn.Module):
    """睡眠分期CNN模型"""
    def __init__(self, input_channels, num_classes):
        super().__init__()
        
        self.cnn = nn.Sequential(
            # 输入形状:(batch_size, input_channels, time_steps)
            nn.Conv1d(input_channels, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool1d(2),
            
            nn.Conv1d(64, 128, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool1d(2),
            
            nn.Conv1d(128, 256, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.AdaptiveAvgPool1d(1)  # 全局平均池化
        )
        
        self.classifier = nn.Sequential(
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(128, num_classes)
        )
        
    def forward(self, x):
        x = self.cnn(x)
        x = x.view(x.size(0), -1)  # 展平
        return self.classifier(x)

# 四、训练配置
# 参数设置
BATCH_SIZE = 64
EPOCHS = 100
LEARNING_RATE = 0.001
WINDOW_SIZE = 30  # 30个时间点为一个样本

# 加载数据
features, labels = load_data("sleep.xlsx")

# 划分数据集
X_train, X_temp, y_train, y_temp = train_test_split(features, labels, test_size=0.3, stratify=labels)
X_val, X_test, y_val, y_test = train_test_split(X_temp, y_temp, test_size=0.5, stratify=y_temp)

# 创建数据集
train_dataset = SleepDataset(X_train, y_train, WINDOW_SIZE)
val_dataset = SleepDataset(X_val, y_val, WINDOW_SIZE)
test_dataset = SleepDataset(X_test, y_test, WINDOW_SIZE)

# 创建数据加载器
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE)

# 初始化模型
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = SleepCNN(input_channels=X_train.shape[1], num_classes=len(np.unique(labels))).to(device)

# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

# 五、训练循环
best_val_acc = 0
train_losses = []
val_accuracies = []

for epoch in range(EPOCHS):
    # 训练阶段
    model.train()
    running_loss = 0.0
    for inputs, labels in train_loader:
        inputs = inputs.to(device)
        labels = labels.to(device).squeeze()
        
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        
        running_loss += loss.item()
    
    # 验证阶段
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, labels in val_loader:
            inputs = inputs.to(device)
            labels = labels.to(device).squeeze()
            
            outputs = model(inputs)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    
    # 记录指标
    epoch_loss = running_loss / len(train_loader)
    val_acc = correct / total
    train_losses.append(epoch_loss)
    val_accuracies.append(val_acc)
    
    # 保存最佳模型
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        torch.save(model.state_dict(), "best_model.pth")
    
    print(f"Epoch {epoch+1}/{EPOCHS} | Loss: {epoch_loss:.4f} | Val Acc: {val_acc:.4f}")

# 六、结果可视化
# 绘制训练曲线
plt.figure(figsize=(12,5))
plt.subplot(1,2,1)
plt.plot(train_losses, label='Training Loss')
plt.title("Training Loss Curve")
plt.xlabel("Epoch")
plt.ylabel("Loss")

plt.subplot(1,2,2)
plt.plot(val_accuracies, label='Validation Accuracy')
plt.title("Validation Accuracy Curve")
plt.xlabel("Epoch")
plt.ylabel("Accuracy")
plt.show()

# 七、测试评估
model.load_state_dict(torch.load("best_model.pth"))
model.eval()
test_correct = 0
test_total = 0
all_preds = []
all_labels = []

with torch.no_grad():
    for inputs, labels in test_loader:
        inputs = inputs.to(device)
        labels = labels.to(device).squeeze()
        
        outputs = model(inputs)
        _, predicted = torch.max(outputs.data, 1)
        test_total += labels.size(0)
        test_correct += (predicted == labels).sum().item()
        
        all_preds.extend(predicted.cpu().numpy())
        all_labels.extend(labels.cpu().numpy())

print(f"\nTest Accuracy: {test_correct/test_total:.4f}")

# 绘制混淆矩阵
from sklearn.metrics import confusion_matrix
import seaborn as sns

cm = confusion_matrix(all_labels, all_preds)
plt.figure(figsize=(8,6))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues')
plt.title("Confusion Matrix")
plt.xlabel("Predicted")
plt.ylabel("True")
plt.show()

关键改进点说明:

  1. 数据预处理:加入滑动窗口处理,更适合时序数据分析
  2. 模型结构:使用深层CNN+全局池化,避免全连接层过拟合
  3. 评估机制:加入早停法(通过保存最佳模型实现)和混淆矩阵可视化
  4. 设备兼容:自动检测GPU加速

使用注意事项:

  1. 需要根据实际数据调整:
    • 修改load_data()中的特征列索引
    • 调整window_size参数匹配采样频率
    • 修改num_classes参数匹配实际分期数
  2. 建议预处理步骤:
    • 检查数据平衡性,必要时进行过采样/欠采样
    • 添加更复杂的数据增强(如添加噪声、时间扭曲)
  3. 性能提升方向:
    • 加入BatchNorm层
    • 尝试LSTM-CNN混合架构
    • 调整超参数(学习率、窗口大小等)

实际应用时,请确保:

  1. Excel数据格式正确
  2. 已安装所有依赖库
  3. 有足够的GPU显存(如需处理大数据集)
  4. 根据实际需求调整分类阈值或损失函数权重
向AI提问 loading 发送消息图标
大学生入口

大家在看

recommend-type

图像的均方误差的matlab代码-alexandrelab_celltrackingcode:alexandrelab_celltrackin

图像的均方误差的matlab代码亚历山大实验室的细胞追踪 通讯员: 首席研究员:Gladys Alexandre- 实验室经理:Elena Ganusov- 代码作者:Mustafa Elmas() Lam Vo-(个人:),Tanmoy Mukherjee() 引文 作者:Mustafa Elmas 日期:08/01/2017 隶属:田纳西大学-诺克斯维尔 目的: 分析细菌运动视频并找到I)细胞速度(微米/秒)II)细胞反转频率(/ s)III)均方根位移(MSD) 将录制的视频分割成一定数量的帧 将帧转换为二进制帧 通过MATLAB内置函数regiongroup计算质心,长轴和短轴的长度和角度。 根据Crocker和Grier的MATLAB版本的单元跟踪算法,在连续视频帧中离散时间确定的粒子坐标的加扰列表的加扰列表中,构造n维轨迹。 低于10微米/秒且短于1 s的轨迹被排除在分析之外。 这样可以确保我们将分析主要限制在焦平面周围狭窄区域内的轨迹上。 计算速度,反转频率,加速度,角加速度,速度自相关,均方根位移 先决条件: MATLAB版本R2019a – MATLAB版本很重要,因
recommend-type

PRBS7码型.TXT

鉴于很多朋友咨询我Verilog-A语言实现PRBS7码型的代码,今天有空把他上传上来,和大家分享讨论一起学习
recommend-type

swftest.zip

MFC加载指定的flash.ocx, 跑页游, 与系统注册的ocx不是一个, 但是貌似是不成功的, 请高人帮我看一看, 请高人帮我改正并传我一份工程
recommend-type

Keysight IO程序套件,2021版本

keysight IO程序套件(ACCELERATE INSTRUMENT CONNECTION AND CONTROL WITH IO LIBRARIES SUITE);IO Libraries Suite ,版本:2021
recommend-type

blind beamforming.rar

盲波束形成算法matlab程序(含恒模CMA、高阶累积量CUM、循环累积量CYC、二阶累积量MRE)

最新推荐

recommend-type

在Pytorch中使用Mask R-CNN进行实例分割操作

【PyTorch中使用Mask R-CNN进行实例分割】 实例分割是计算机视觉领域的一个关键任务,它旨在识别图像中每个像素所属的对象实例。不同于语义分割,实例分割不仅标识像素的类别,还能区分同一类的不同实例。Mask R-...
recommend-type

用Pytorch训练CNN(数据集MNIST,使用GPU的方法)

在本文中,我们将探讨如何使用PyTorch训练一个卷积神经网络(CNN)模型,针对MNIST数据集,并利用GPU加速计算。MNIST是一个包含手写数字图像的数据集,常用于入门级的深度学习项目。PyTorch是一个灵活且用户友好的...
recommend-type

pytorch实现对输入超过三通道的数据进行训练

在`__getitem__`方法中,根据索引从对应文件中加载数据,将其转换为Tensor,并返回数据和对应的标签。最后,`__len__`方法返回数据集的总样本数。 数据加载部分,使用`DataLoader`来创建一个数据迭代器,指定批量...
recommend-type

Pytorch使用MNIST数据集实现CGAN和生成指定的数字方式

在本教程中,我们将探讨如何使用PyTorch框架来实现条件生成对抗网络(CGAN)并利用MNIST数据集生成指定数字的图像。CGAN是一种扩展了基础生成对抗网络(GAN)的概念,它允许在生成过程中加入额外的条件信息,如类...
recommend-type

pytorch下使用LSTM神经网络写诗实例

在本文中,我们将探讨如何使用PyTorch实现一个基于LSTM(Long Short-Term Memory)神经网络的诗歌生成系统。LSTM是一种递归神经网络(RNN)变体,特别适合处理序列数据,如文本,因为它能有效地捕获长期依赖性。 ...
recommend-type

iOS开发中的HTTP请求方法演示

在iOS开发中,进行HTTP请求以从服务器获取数据是常见的任务。在本知识点梳理中,我们将详细探讨如何利用HTTP向服务器请求数据,涵盖同步GET请求、同步POST请求、异步GET请求以及异步POST请求,并将通过示例代码来加深理解。 ### 同步GET请求 同步GET请求是指客户端在发起请求后将阻塞当前线程直到服务器响应返回,期间用户界面无法进行交互。这种做法不推荐在主线程中使用,因为会造成UI卡顿。下面是一个使用`URLSession`进行同步GET请求的示例代码。 ```swift import Foundation func syncGETRequest() { guard let url = URL(string: "http://www.example.com/api/data") else { return } var request = URLRequest(url: url) request.httpMethod = "GET" let task = URLSession.shared.dataTask(with: request) { data, response, error in if let error = error { print("Error: \(error)") return } if let httpResponse = response as? HTTPURLResponse, (200...299).contains(httpResponse.statusCode) { guard let mimeType = httpResponse.mimeType, mimeType == "application/json" else { print("Invalid content-type") return } guard let data = data else { print("No data") return } do { let json = try JSONSerialization.jsonObject(with: data, options: []) print("Data received: \(json)") } catch { print("JSONSerialization failed: \(error)") } } else { print("HTTP Error: \(response?.description ?? "No response")") } } task.resume() } // 调用函数 syncGETRequest() ``` ### 同步POST请求 同步POST请求与GET类似,但是在请求方法、请求体以及可能的参数设置上有所不同。下面是一个同步POST请求的示例代码。 ```swift import Foundation func syncPOSTRequest() { guard let url = URL(string: "http://www.example.com/api/data") else { return } var request = URLRequest(url: url) request.httpMethod = "POST" let postData = "key1=value1&key2=value2" request.httpBody = postData.data(using: .utf8) let task = URLSession.shared.dataTask(with: request) { data, response, error in // 同GET请求处理方式类似... } task.resume() } // 调用函数 syncPOSTRequest() ``` ### 异步GET请求 异步请求不会阻塞主线程,因此可以提升用户体验。在iOS开发中,可以使用`URLSession`来发起异步请求。 ```swift import Foundation func asyncGETRequest() { guard let url = URL(string: "http://www.example.com/api/data") else { return } var request = URLRequest(url: url) request.httpMethod = "GET" URLSession.shared.dataTask(with: request) { data, response, error in // 同步GET请求处理方式类似... }.resume() } // 调用函数 asyncGETRequest() ``` ### 异步POST请求 异步POST请求的代码结构与GET请求类似,区别主要在于HTTP方法和请求体的设置。 ```swift import Foundation func asyncPOSTRequest() { guard let url = URL(string: "http://www.example.com/api/data") else { return } var request = URLRequest(url: url) request.httpMethod = "POST" let postData = "key1=value1&key2=value2" request.httpBody = postData.data(using: .utf8) URLSession.shared.dataTask(with: request) { data, response, error in // 同步GET请求处理方式类似... }.resume() } // 调用函数 asyncPOSTRequest() ``` ### 注意事项 - **网络权限**:在实际项目中,需要在`Info.plist`中添加相应的网络权限。 - **错误处理**:示例代码中展示了基本的错误处理流程,但在实际开发中应详细处理各种可能的网络错误。 - **线程安全**:如果在主线程之外的线程更新UI,需要确保线程安全。 - **请求超时**:在网络请求中设置合理的超时时间以避免长时间无响应。 - **状态码处理**:服务器响应的状态码需要适当处理,如200表示成功,4xx表示客户端错误,5xx表示服务器端错误。 ### 结语 通过上述示例代码,新手开发者可以快速理解如何在iOS项目中使用HTTP请求与服务器进行数据交互。无论是同步还是异步请求,重要的是要根据实际应用场景合理选择请求方式,并严格遵循最佳实践,确保应用的性能和用户体验。
recommend-type

【精准测试】:确保分层数据流图准确性的完整测试方法

# 摘要 分层数据流图(DFD)作为软件工程中描述系统功能和数据流动的重要工具,其测试方法论的完善是确保系统稳定性的关键。本文系统性地介绍了分层DFD的基础知识、测试策略与实践、自动化与优化方法,以及实际案例分析。文章详细阐述了测试的理论基础,包括定义、目的、分类和方法,并深入探讨了静态与动态测试方法以及测试用
recommend-type

错误: 找不到或无法加载主类 org.springblade.Application 原因: java.lang.ClassNotFoundException: org.springblade.Application

<think>好的,我现在需要帮助用户解决Java程序中出现的ClassNotFoundException,特别是关于org.springblade.Application类的找不到的问题。用户提到这是SpringBoot的主类未找到的情况。首先,我需要回忆一下这个异常的常见原因和解决方法。 根据之前的引用内容,ClassNotFoundException通常与依赖配置或类路径有关。例如,引用[2]指出这可能是因为依赖配置不正确或类路径设置问题,而引用[3]提到版本不统一也可能导致此类问题。此外,主类未找到还可能是因为打包配置不正确,比如在Maven或Gradle中没有正确指定主类,或者在构
recommend-type

个人作品集展示:HTML文件夹压缩处理

根据给定的文件信息,我们可以推断出以下IT知识内容。 ### 知识点一:HTML文件夹的作用与结构 HTML文件夹通常用于存放网站的所有相关文件,包括HTML文件、CSS样式表、JavaScript脚本、图像文件以及其他资源文件。这个文件夹的结构应该清晰且有组织,以便于开发和维护。HTML文件是网页内容的骨架,它通过标签(Tag)来定义内容的布局和结构。 #### HTML标签的基本概念 HTML标签是构成网页的基石,它们是一些用尖括号包围的词,如`<html>`, `<head>`, `<title>`, `<body>`等。这些标签告诉浏览器如何显示网页上的信息。例如,`<img>`标签用于嵌入图像,而`<a>`标签用于创建超链接。HTML5是最新版本的HTML,它引入了更多的语义化标签,比如`<article>`, `<section>`, `<nav>`, `<header>`, `<footer>`等,这有助于提供更丰富的网页结构信息。 #### 知识点二:使用HTML构建投资组合(portfolio) “portfolio”一词在IT行业中常常指的是个人或公司的作品集。这通常包括了一个人或组织在特定领域的工作样本和成就展示。使用HTML创建“portfolio”通常会涉及到以下几个方面: - 设计布局:决定页面的结构,如导航栏、内容区域、页脚等。 - 网页内容的填充:使用HTML标签编写内容,可能包括文本、图片、视频和链接。 - 网站响应式设计:确保网站在不同设备上都能有良好的浏览体验,这可能涉及到使用CSS媒体查询和弹性布局。 - CSS样式的应用:为HTML元素添加样式,使网页看起来更加美观。 - JavaScript交互:添加动态功能,如图片画廊、滑动效果或导航菜单。 #### 知识点三:GitHub Pages与网站托管 标题中出现的"gh-pages"表明涉及的是GitHub Pages。GitHub Pages是GitHub提供的一个静态网站托管服务。用户可以使用GitHub Pages托管他们的个人、组织或者项目的页面。它允许用户直接从GitHub仓库部署和发布网站。 #### 知识点四:项目命名与管理 在压缩包子文件的文件名称列表中,出现了"portfolio-gh-pages",这说明项目有一个特定的命名规范。文件夹或项目名称应该简洁明了,能够反映项目内容或者用途。在IT项目管理中,良好的命名习惯有助于团队成员更快地理解项目的性质,同时也方便版本控制和代码维护。 #### 总结 在信息技术领域,使用HTML构建一个投资组合网站是一个常见的任务。它不仅可以展示个人或公司的技能和作品,还可以作为与潜在客户或雇主交流的平台。理解HTML标签的使用、网页设计的基本原则、响应式设计以及网站托管服务,对于制作一个专业且吸引人的投资组合至关重要。此外,良好的项目命名和文件管理习惯也是IT专业人士应该具备的基本技能之一。
recommend-type

【版本控制】:分层数据流图的高效维护与变更管理

# 摘要 本文系统地探讨了版本控制和分层数据流图设计的重要性和应用实践。第一章强调版本控制的基础知识和其在软件开发生命周期中的关键作用。第二章详细介绍了分层数据流图的设计原理,包括基本概念、设计方法和表示技巧,以及如何通过这些图解高效地管理和沟通软件设计。第三章探讨了版本控制系统的选择与配置,比较了不同类型系统的特点,并提供了配置主流系统的实际案例。第四章重点讨论分层数据流图的变更管理流程,阐述
手机看
程序员都在用的中文IT技术交流社区

程序员都在用的中文IT技术交流社区

专业的中文 IT 技术社区,与千万技术人共成长

专业的中文 IT 技术社区,与千万技术人共成长

关注【CSDN】视频号,行业资讯、技术分享精彩不断,直播好礼送不停!

关注【CSDN】视频号,行业资讯、技术分享精彩不断,直播好礼送不停!

客服 返回
顶部