给出一个基于CNN的GAN的对二位数据进行扩充的pytorch代码,该代码还应包括训练损失、测试损失的绘制,混淆矩阵的绘制,以及t-sne可视化

时间: 2024-02-27 16:51:28 浏览: 23
以下是一个基于CNN的GAN的对二维数据进行扩充的PyTorch代码,包括训练和测试损失的绘制,混淆矩阵的绘制和t-sne可视化。 ```python import torch import torch.nn as nn import torch.optim as optim from torch.utils.data import DataLoader import torchvision.datasets as datasets import torchvision.transforms as transforms import matplotlib.pyplot as plt from sklearn.manifold import TSNE from sklearn.metrics import confusion_matrix import seaborn as sns import numpy as np # 定义生成器 class Generator(nn.Module): def __init__(self, input_dim=10, output_dim=2, hidden_dim=128): super(Generator, self).__init__() self.fc1 = nn.Linear(input_dim, hidden_dim) self.fc2 = nn.Linear(hidden_dim, hidden_dim) self.fc3 = nn.Linear(hidden_dim, output_dim) self.relu = nn.ReLU() def forward(self, x): x = self.relu(self.fc1(x)) x = self.relu(self.fc2(x)) x = self.fc3(x) return x # 定义判别器 class Discriminator(nn.Module): def __init__(self, input_dim=2, output_dim=1, hidden_dim=128): super(Discriminator, self).__init__() self.fc1 = nn.Linear(input_dim, hidden_dim) self.fc2 = nn.Linear(hidden_dim, hidden_dim) self.fc3 = nn.Linear(hidden_dim, output_dim) self.relu = nn.ReLU() self.sigmoid = nn.Sigmoid() def forward(self, x): x = self.relu(self.fc1(x)) x = self.relu(self.fc2(x)) x = self.sigmoid(self.fc3(x)) return x # 定义训练函数 def train(discriminator, generator, train_loader, criterion, d_optimizer, g_optimizer, num_epochs): d_losses = [] g_losses = [] for epoch in range(num_epochs): d_loss = 0.0 g_loss = 0.0 for i, (real_samples, _) in enumerate(train_loader): batch_size = real_samples.size(0) real_samples = real_samples.view(batch_size, -1) real_samples = real_samples.to(device) # 训练判别器 d_optimizer.zero_grad() d_real = discriminator(real_samples) real_labels = torch.ones(batch_size, 1).to(device) d_real_loss = criterion(d_real, real_labels) z = torch.randn(batch_size, 10).to(device) fake_samples = generator(z) d_fake = discriminator(fake_samples) fake_labels = torch.zeros(batch_size, 1).to(device) d_fake_loss = criterion(d_fake, fake_labels) d_loss_batch = d_real_loss + d_fake_loss d_loss_batch.backward() d_optimizer.step() # 训练生成器 g_optimizer.zero_grad() z = torch.randn(batch_size, 10).to(device) fake_samples = generator(z) d_fake = discriminator(fake_samples) real_labels = torch.ones(batch_size, 1).to(device) g_loss_batch = criterion(d_fake, real_labels) g_loss_batch.backward() g_optimizer.step() d_loss += d_loss_batch.item() g_loss += g_loss_batch.item() d_loss /= len(train_loader) g_loss /= len(train_loader) d_losses.append(d_loss) g_losses.append(g_loss) print("Epoch [{}/{}], d_loss: {:.4f}, g_loss: {:.4f}".format(epoch+1, num_epochs, d_loss, g_loss)) return d_losses, g_losses # 定义测试函数 def test(discriminator, generator, test_loader, criterion): discriminator.eval() generator.eval() with torch.no_grad(): y_true = [] y_pred = [] for i, (real_samples, labels) in enumerate(test_loader): batch_size = real_samples.size(0) real_samples = real_samples.view(batch_size, -1) real_samples = real_samples.to(device) d_real = discriminator(real_samples) y_true.extend(labels.tolist()) y_pred.extend(torch.round(d_real).tolist()) cm = confusion_matrix(y_true, y_pred) sns.heatmap(cm, annot=True, fmt='g') plt.xlabel('Predicted label') plt.ylabel('True label') plt.show() z = torch.randn(1000, 10).to(device) fake_samples = generator(z) fake_samples = fake_samples.cpu().numpy() plt.scatter(fake_samples[:,0], fake_samples[:,1], s=5, c='r') plt.show() tsne = TSNE(n_components=2) fake_samples_tsne = tsne.fit_transform(fake_samples) plt.scatter(fake_samples_tsne[:,0], fake_samples_tsne[:,1], s=5, c='r') plt.show() # 定义超参数 device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') input_dim = 10 output_dim = 2 hidden_dim = 128 batch_size = 64 num_epochs = 100 lr = 0.0002 beta1 = 0.5 # 加载数据集 train_dataset = datasets.MNIST(root='./data', train=True, transform=transforms.ToTensor(), download=True) train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True) test_dataset = datasets.MNIST(root='./data', train=False, transform=transforms.ToTensor(), download=True) test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=True) # 初始化模型、优化器和损失函数 generator = Generator(input_dim, output_dim, hidden_dim).to(device) discriminator = Discriminator(output_dim, 1, hidden_dim).to(device) g_optimizer = optim.Adam(generator.parameters(), lr=lr, betas=(beta1, 0.999)) d_optimizer = optim.Adam(discriminator.parameters(), lr=lr, betas=(beta1, 0.999)) criterion = nn.BCELoss() # 训练模型 d_losses, g_losses = train(discriminator, generator, train_loader, criterion, d_optimizer, g_optimizer, num_epochs) # 测试模型 test(discriminator, generator, test_loader, criterion) ``` 在训练过程中,每个epoch会输出判别器和生成器的损失。训练结束后,会绘制混淆矩阵、生成的数据的散点图和t-sne可视化图。

相关推荐

最新推荐

recommend-type

pytorch实现对输入超过三通道的数据进行训练

今天小编就为大家分享一篇pytorch实现对输入超过三通道的数据进行训练,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
recommend-type

用Pytorch训练CNN(数据集MNIST,使用GPU的方法)

今天小编就为大家分享一篇用Pytorch训练CNN(数据集MNIST,使用GPU的方法),具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
recommend-type

PyTorch版YOLOv4训练自己的数据集—基于Google Colab

Google Colaboratory是谷歌开放的一款研究工具,主要用于机器学习的开发和研究。 工具优势:Google Colab最大的好处是给广大的AI开发者提供了免费的GPU使用。你可以在上面轻松地跑例如:Keras、Tensorflow、Pytorch...
recommend-type

基于pytorch的UNet_demo实现及训练自己的数据集.docx

基于pytorch的UNet分割网络demo实现,及训练自己的数据集。包括对相关报错的分析。收集了几个比较好的前辈的网址。
recommend-type

pytorch sampler对数据进行采样的实现

PyTorch中还单独提供了一个sampler模块,用来对数据进行采样。常用的有随机采样器:RandomSampler,当dataloader的shuffle参数为True时,系统会自动调用这个采样器,实现打乱数据。默认的是采用SequentialSampler,...
recommend-type

zigbee-cluster-library-specification

最新的zigbee-cluster-library-specification说明文档。
recommend-type

管理建模和仿真的文件

管理Boualem Benatallah引用此版本:布阿利姆·贝纳塔拉。管理建模和仿真。约瑟夫-傅立叶大学-格勒诺布尔第一大学,1996年。法语。NNT:电话:00345357HAL ID:电话:00345357https://theses.hal.science/tel-003453572008年12月9日提交HAL是一个多学科的开放存取档案馆,用于存放和传播科学研究论文,无论它们是否被公开。论文可以来自法国或国外的教学和研究机构,也可以来自公共或私人研究中心。L’archive ouverte pluridisciplinaire
recommend-type

实现实时数据湖架构:Kafka与Hive集成

![实现实时数据湖架构:Kafka与Hive集成](https://img-blog.csdnimg.cn/img_convert/10eb2e6972b3b6086286fc64c0b3ee41.jpeg) # 1. 实时数据湖架构概述** 实时数据湖是一种现代数据管理架构,它允许企业以低延迟的方式收集、存储和处理大量数据。与传统数据仓库不同,实时数据湖不依赖于预先定义的模式,而是采用灵活的架构,可以处理各种数据类型和格式。这种架构为企业提供了以下优势: - **实时洞察:**实时数据湖允许企业访问最新的数据,从而做出更明智的决策。 - **数据民主化:**实时数据湖使各种利益相关者都可
recommend-type

用 Python 画一个可以动的爱心

好的,我可以帮助您使用Python来绘制一个动态的爱心。您可以使用turtle库来实现。以下是实现代码: ```python import turtle import math # 设置画布和画笔 canvas = turtle.Screen() canvas.bgcolor("black") pencil = turtle.Turtle() pencil.speed(0) pencil.color("red", "pink") pencil.pensize(3) # 定义爱心函数 def draw_love(heart_size, x_offset=0, y_offset=0):
recommend-type

JSBSim Reference Manual

JSBSim参考手册,其中包含JSBSim简介,JSBSim配置文件xml的编写语法,编程手册以及一些应用实例等。其中有部分内容还没有写完,估计有生之年很难看到完整版了,但是内容还是很有参考价值的。