使用mnist制作moving mnist

时间: 2023-12-05 17:50:36 浏览: 27
Moving MNIST是一个标准的序列预测数据集,它由一系列MNIST图像组成,这些图像随着时间的推移发生变化。每个时间步骤包含两个MNIST图像,一个前景图像和一个背景图像,前景图像在背景图像上运动。 以下是使用Python和Pytorch制作Moving MNIST的示例代码: ``` python import torch from torchvision import datasets, transforms import numpy as np def generate_moving_mnist(batch_size, seq_length, num_digits=2, image_size=64, digit_size=28, max_speed=1.0): mnist_transforms = transforms.Compose([ transforms.Scale(digit_size), transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,)) ]) mnist_dataset = datasets.MNIST('data', train=True, download=True, transform=mnist_transforms) mnist_data = mnist_dataset.train_data.numpy() # Generate random sequences sequences = np.zeros((batch_size, seq_length, num_digits, image_size, image_size), dtype=np.float32) for i in range(batch_size): for j in range(num_digits): # Choose a random digit from the MNIST dataset digit_idx = np.random.randint(0, len(mnist_data)) digit = mnist_data[digit_idx] # Choose a random position and velocity for the digit x = np.random.randint(0, image_size - digit_size) y = np.random.randint(0, image_size - digit_size) dx = np.random.uniform(-max_speed, max_speed) dy = np.random.uniform(-max_speed, max_speed) # Generate the sequence for this digit for t in range(seq_length): # Compute the current position of the digit x += dx y += dy # If the digit goes off the edge of the image, bounce it back if x < 0: x = 0 dx = -dx elif x + digit_size > image_size: x = image_size - digit_size dx = -dx if y < 0: y = 0 dy = -dy elif y + digit_size > image_size: y = image_size - digit_size dy = -dy # Add the digit to the sequence at this time step sequences[i, t, j, x:x+digit_size, y:y+digit_size] = digit.numpy() # Normalize the sequences sequences -= np.mean(sequences) sequences /= np.std(sequences) # Convert to torch tensor sequences = torch.from_numpy(sequences) return sequences ``` 上述代码使用了MNIST数据集中的样本图像,并根据随机速度和位置生成Moving MNIST序列。在此之后,对生成的序列进行标准化并转换为PyTorch张量。 可以通过调用以下代码来使用上述函数生成Moving MNIST序列: ``` python batch_size = 16 seq_length = 20 num_digits = 2 image_size = 64 digit_size = 28 max_speed = 1.0 sequences = generate_moving_mnist(batch_size, seq_length, num_digits, image_size, digit_size, max_speed) print(sequences.shape) ``` 在上面的代码中,我们生成了一个大小为16x20x2x64x64的Moving MNIST序列。每个序列都包含20个时间步长和2个MNIST数字,这些数字在序列的时间步长中随机移动。

相关推荐

import time import tensorflow.compat.v1 as tf tf.disable_v2_behavior() from tensorflow.examples.tutorials.mnist import input_data import mnist_inference import mnist_train tf.compat.v1.reset_default_graph() EVAL_INTERVAL_SECS = 10 def evaluate(mnist): with tf.Graph().as_default() as g: #定义输入与输出的格式 x = tf.compat.v1.placeholder(tf.float32, [None, mnist_inference.INPUT_NODE], name='x-input') y_ = tf.compat.v1.placeholder(tf.float32, [None, mnist_inference.OUTPUT_NODE], name='y-input') validate_feed = {x: mnist.validation.images, y_: mnist.validation.labels} #直接调用封装好的函数来计算前向传播的结果 y = mnist_inference.inference(x, None) #计算正确率 correcgt_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1)) accuracy = tf.reduce_mean(tf.cast(correcgt_prediction, tf.float32)) #通过变量重命名的方式加载模型 variable_averages = tf.train.ExponentialMovingAverage(0.99) variable_to_restore = variable_averages.variables_to_restore() saver = tf.train.Saver(variable_to_restore) #每隔10秒调用一次计算正确率的过程以检测训练过程中正确率的变化 while True: with tf.compat.v1.Session() as sess: ckpt = tf.train.get_checkpoint_state(minist_train.MODEL_SAVE_PATH) if ckpt and ckpt.model_checkpoint_path: #load the model saver.restore(sess, ckpt.model_checkpoint_path) global_step = ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1] accuracy_score = sess.run(accuracy, feed_dict=validate_feed) print("After %s training steps, validation accuracy = %g" % (global_step, accuracy_score)) else: print('No checkpoint file found') return time.sleep(EVAL_INTERVAL_SECS) def main(argv=None): mnist = input_data.read_data_sets(r"D:\Anaconda123\Lib\site-packages\tensorboard\mnist", one_hot=True) evaluate(mnist) if __name__ == '__main__': tf.compat.v1.app.run()对代码进行改进

最新推荐

recommend-type

Pytorch使用MNIST数据集实现CGAN和生成指定的数字方式

今天小编就为大家分享一篇Pytorch使用MNIST数据集实现CGAN和生成指定的数字方式,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
recommend-type

pytorch实现mnist分类的示例讲解

今天小编就为大家分享一篇pytorch实现mnist分类的示例讲解,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
recommend-type

深度学习的mnist实验报告

关于深度学习mnist文件的实验报告, 具体包括 xehll安装过程和mnist文件的执行过程。
recommend-type

使用tensorflow实现VGG网络,训练mnist数据集方式

主要介绍了使用tensorflow实现VGG网络,训练mnist数据集方式,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
recommend-type

用Pytorch训练CNN(数据集MNIST,使用GPU的方法)

今天小编就为大家分享一篇用Pytorch训练CNN(数据集MNIST,使用GPU的方法),具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
recommend-type

zigbee-cluster-library-specification

最新的zigbee-cluster-library-specification说明文档。
recommend-type

管理建模和仿真的文件

管理Boualem Benatallah引用此版本:布阿利姆·贝纳塔拉。管理建模和仿真。约瑟夫-傅立叶大学-格勒诺布尔第一大学,1996年。法语。NNT:电话:00345357HAL ID:电话:00345357https://theses.hal.science/tel-003453572008年12月9日提交HAL是一个多学科的开放存取档案馆,用于存放和传播科学研究论文,无论它们是否被公开。论文可以来自法国或国外的教学和研究机构,也可以来自公共或私人研究中心。L’archive ouverte pluridisciplinaire
recommend-type

实现实时数据湖架构:Kafka与Hive集成

![实现实时数据湖架构:Kafka与Hive集成](https://img-blog.csdnimg.cn/img_convert/10eb2e6972b3b6086286fc64c0b3ee41.jpeg) # 1. 实时数据湖架构概述** 实时数据湖是一种现代数据管理架构,它允许企业以低延迟的方式收集、存储和处理大量数据。与传统数据仓库不同,实时数据湖不依赖于预先定义的模式,而是采用灵活的架构,可以处理各种数据类型和格式。这种架构为企业提供了以下优势: - **实时洞察:**实时数据湖允许企业访问最新的数据,从而做出更明智的决策。 - **数据民主化:**实时数据湖使各种利益相关者都可
recommend-type

用 Python 画一个可以动的爱心

好的,我可以帮助您使用Python来绘制一个动态的爱心。您可以使用turtle库来实现。以下是实现代码: ```python import turtle import math # 设置画布和画笔 canvas = turtle.Screen() canvas.bgcolor("black") pencil = turtle.Turtle() pencil.speed(0) pencil.color("red", "pink") pencil.pensize(3) # 定义爱心函数 def draw_love(heart_size, x_offset=0, y_offset=0):
recommend-type

JSBSim Reference Manual

JSBSim参考手册,其中包含JSBSim简介,JSBSim配置文件xml的编写语法,编程手册以及一些应用实例等。其中有部分内容还没有写完,估计有生之年很难看到完整版了,但是内容还是很有参考价值的。