imgs = loader_train.__iter__().next()[0].view(batch_size, 784).numpy(),怎么修改

时间: 2024-04-15 20:27:38 浏览: 207
ZIP

imgs.zip_imgs.com_imgs图片_web 图片展示_展示_左右滑动模

要修改这行代码,你可以使用以下方式: ```python imgs = next(iter(loader_train))[0].view(batch_size, 784).numpy() ``` 这里使用`next(iter(loader_train))`来替代`loader_train.__iter__().next()`。这两种方法是等效的,都是获取`loader_train`中的下一个batch数据。然后使用`.view(batch_size, 784)`来改变数据的形状,并最后使用`.numpy()`将数据转换为NumPy数组。
阅读全文

相关推荐

# 定义数据集读取器 def load_data(mode='train'): # 数据文件 datafile = './data/data116648/mnist.json.gz' print('loading mnist dataset from {} ......'.format(datafile)) data = json.load(gzip.open(datafile)) train_set, val_set, eval_set = data # 数据集相关参数,图片高度IMG_ROWS, 图片宽度IMG_COLS IMG_ROWS = 28 IMG_COLS = 28 if mode == 'train': imgs = train_set[0] labels = train_set[1] elif mode == 'valid': imgs = val_set[0] labels = val_set[1] elif mode == 'eval': imgs = eval_set[0] labels = eval_set[1] imgs_length = len(imgs) assert len(imgs) == len(labels), \ "length of train_imgs({}) should be the same as train_labels({})".format( len(imgs), len(labels)) index_list = list(range(imgs_length)) # 读入数据时用到的batchsize BATCHSIZE = 100 # 定义数据生成器 def data_generator(): if mode == 'train': random.shuffle(index_list) imgs_list = [] labels_list = [] for i in index_list: img = np.reshape(imgs[i], [1, IMG_ROWS, IMG_COLS]).astype('float32') img_trans=-img #转变颜色 label = np.reshape(labels[i], [1]).astype('int64') label_trans=label imgs_list.append(img) imgs_list.append(img_trans) labels_list.append(label) labels_list.append(label_trans) if len(imgs_list) == BATCHSIZE: yield np.array(imgs_list), np.array(labels_list) imgs_list = [] labels_list = [] # 如果剩余数据的数目小于BATCHSIZE, # 则剩余数据一起构成一个大小为len(imgs_list)的mini-batch if len(imgs_list) > 0: yield np.array(imgs_list), np.array(labels_list) return data_generator

下面代码在tensorflow中出现了init() missing 1 required positional argument: 'cell'报错: class Model(): def init(self): self.img_seq_shape=(10,128,128,3) self.img_shape=(128,128,3) self.train_img=dataset # self.test_img=dataset_T patch = int(128 / 2 ** 4) self.disc_patch = (patch, patch, 1) self.optimizer=tf.keras.optimizers.Adam(learning_rate=0.001) self.build_generator=self.build_generator() self.build_discriminator=self.build_discriminator() self.build_discriminator.compile(loss='binary_crossentropy', optimizer=self.optimizer, metrics=['accuracy']) self.build_generator.compile(loss='binary_crossentropy', optimizer=self.optimizer) img_seq_A = Input(shape=(10,128,128,3)) #输入图片 img_B = Input(shape=self.img_shape) #目标图片 fake_B = self.build_generator(img_seq_A) #生成的伪目标图片 self.build_discriminator.trainable = False valid = self.build_discriminator([img_seq_A, fake_B]) self.combined = tf.keras.models.Model([img_seq_A, img_B], [valid, fake_B]) self.combined.compile(loss=['binary_crossentropy', 'mse'], loss_weights=[1, 100], optimizer=self.optimizer,metrics=['accuracy']) def build_generator(self): def res_net(inputs, filters): x = inputs net = conv2d(x, filters // 2, (1, 1), 1) net = conv2d(net, filters, (3, 3), 1) net = net + x # net=tf.keras.layers.LeakyReLU(0.2)(net) return net def conv2d(inputs, filters, kernel_size, strides): x = tf.keras.layers.Conv2D(filters, kernel_size, strides, 'same')(inputs) x = tf.keras.layers.BatchNormalization()(x) x = tf.keras.layers.LeakyReLU(alpha=0.2)(x) return x d0 = tf.keras.layers.Input(shape=(10, 128, 128, 3)) out= ConvRNN2D(filters=32, kernel_size=3,padding='same')(d0) out=tf.keras.layers.Conv2D(3,1,1,'same')(out) return keras.Model(inputs=d0, outputs=out) def build_discriminator(self): def d_layer(layer_input, filters, f_size=4, bn=True): d = tf.keras.layers.Conv2D(filters, kernel_size=f_size, strides=2, padding='same')(layer_input) if bn: d = tf.keras.layers.BatchNormalization(momentum=0.8)(d) d = tf.keras.layers.LeakyReLU(alpha=0.2)(d) return d img_A = tf.keras.layers.Input(shape=(10, 128, 128, 3)) img_B = tf.keras.layers.Input(shape=(128, 128, 3)) df = 32 lstm_out = ConvRNN2D(filters=df, kernel_size=4, padding="same")(img_A) lstm_out = tf.keras.layers.LeakyReLU(alpha=0.2)(lstm_out) combined_imgs = tf.keras.layers.Concatenate(axis=-1)([lstm_out, img_B]) d1 = d_layer(combined_imgs, df)#64 d2 = d_layer(d1, df * 2)#32 d3 = d_layer(d2, df * 4)#16 d4 = d_layer(d3, df * 8)#8 validity = tf.keras.layers.Conv2D(1, kernel_size=4, strides=1, padding='same')(d4) return tf.keras.Model([img_A, img_B], validity)

import torch, os, cv2 from model.model import parsingNet from utils.common import merge_config from utils.dist_utils import dist_print import torch import scipy.special, tqdm import numpy as np import torchvision.transforms as transforms from data.dataset import LaneTestDataset from data.constant import culane_row_anchor, tusimple_row_anchor if __name__ == "__main__": torch.backends.cudnn.benchmark = True args, cfg = merge_config() dist_print('start testing...') assert cfg.backbone in ['18','34','50','101','152','50next','101next','50wide','101wide'] if cfg.dataset == 'CULane': cls_num_per_lane = 18 elif cfg.dataset == 'Tusimple': cls_num_per_lane = 56 else: raise NotImplementedError net = parsingNet(pretrained = False, backbone=cfg.backbone,cls_dim = (cfg.griding_num+1,cls_num_per_lane,4), use_aux=False).cuda() # we dont need auxiliary segmentation in testing state_dict = torch.load(cfg.test_model, map_location='cpu')['model'] compatible_state_dict = {} for k, v in state_dict.items(): if 'module.' in k: compatible_state_dict[k[7:]] = v else: compatible_state_dict[k] = v net.load_state_dict(compatible_state_dict, strict=False) net.eval() img_transforms = transforms.Compose([ transforms.Resize((288, 800)), transforms.ToTensor(), transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)), ]) if cfg.dataset == 'CULane': splits = ['test0_normal.txt', 'test1_crowd.txt', 'test2_hlight.txt', 'test3_shadow.txt', 'test4_noline.txt', 'test5_arrow.txt', 'test6_curve.txt', 'test7_cross.txt', 'test8_night.txt'] datasets = [LaneTestDataset(cfg.data_root,os.path.join(cfg.data_root, 'list/test_split/'+split),img_transform = img_transforms) for split in splits] img_w, img_h = 1640, 590 row_anchor = culane_row_anchor elif cfg.dataset == 'Tusimple': splits = ['test.txt'] datasets = [LaneTestDataset(cfg.data_root,os.path.join(cfg.data_root, split),img_transform = img_transforms) for split in splits] img_w, img_h = 1280, 720 row_anchor = tusimple_row_anchor else: raise NotImplementedError for split, dataset in zip(splits, datasets): loader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle = False, num_workers=1) fourcc = cv2.VideoWriter_fourcc(*'MJPG') print(split[:-3]+'avi') vout = cv2.VideoWriter(split[:-3]+'avi', fourcc , 30.0, (img_w, img_h)) for i, data in enumerate(tqdm.tqdm(loader)): imgs, names = data imgs = imgs.cuda() with torch.no_grad(): out = net(imgs) col_sample = np.linspace(0, 800 - 1, cfg.griding_num) col_sample_w = col_sample[1] - col_sample[0] out_j = out[0].data.cpu().numpy() out_j = out_j[:, ::-1, :] prob = scipy.special.softmax(out_j[:-1, :, :], axis=0) idx = np.arange(cfg.griding_num) + 1 idx = idx.reshape(-1, 1, 1) loc = np.sum(prob * idx, axis=0) out_j = np.argmax(out_j, axis=0) loc[out_j == cfg.griding_num] = 0 out_j = loc # import pdb; pdb.set_trace() vis = cv2.imread(os.path.join(cfg.data_root,names[0])) for i in range(out_j.shape[1]): if np.sum(out_j[:, i] != 0) > 2: for k in range(out_j.shape[0]): if out_j[k, i] > 0: ppp = (int(out_j[k, i] * col_sample_w * img_w / 800) - 1, int(img_h * (row_anchor[cls_num_per_lane-1-k]/288)) - 1 ) cv2.circle(vis,ppp,5,(0,255,0),-1) vout.write(vis) vout.release()

from pdb import set_trace as st import os import numpy as np import cv2 import argparse parser = argparse.ArgumentParser('create image pairs') parser.add_argument('--fold_A', dest='fold_A', help='input directory for image A', type=str, default='../dataset/50kshoes_edges') parser.add_argument('--fold_B', dest='fold_B', help='input directory for image B', type=str, default='../dataset/50kshoes_jpg') parser.add_argument('--fold_AB', dest='fold_AB', help='output directory', type=str, default='../dataset/test_AB') parser.add_argument('--num_imgs', dest='num_imgs', help='number of images',type=int, default=1000000) parser.add_argument('--use_AB', dest='use_AB', help='if true: (0001_A, 0001_B) to (0001_AB)',action='store_true') args = parser.parse_args() for arg in vars(args): print('[%s] = ' % arg, getattr(args, arg)) splits = os.listdir(args.fold_A) for sp in splits: img_fold_A = os.path.join(args.fold_A, sp) img_fold_B = os.path.join(args.fold_B, sp) img_list = os.listdir(img_fold_A) if args.use_AB: img_list = [img_path for img_path in img_list if '_A.' in img_path] num_imgs = min(args.num_imgs, len(img_list)) print('split = %s, use %d/%d images' % (sp, num_imgs, len(img_list))) img_fold_AB = os.path.join(args.fold_AB, sp) if not os.path.isdir(img_fold_AB): os.makedirs(img_fold_AB) print('split = %s, number of images = %d' % (sp, num_imgs)) for n in range(num_imgs): name_A = img_list[n] path_A = os.path.join(img_fold_A, name_A) if args.use_AB: name_B = name_A.replace('_A.', '_B.') else: name_B = name_A path_B = os.path.join(img_fold_B, name_B) if os.path.isfile(path_A) and os.path.isfile(path_B): name_AB = name_A if args.use_AB: name_AB = name_AB.replace('_A.', '.') # remove _A path_AB = os.path.join(img_fold_AB, name_AB) im_A = cv2.imread(path_A, cv2.IMREAD_COLOR) im_B = cv2.imread(path_B, cv2.IMREAD_COLOR) im_AB = np.concatenate([im_A, im_B], 1) cv2.imwrite(path_AB, im_AB),解释上述代码,并告诉我怎么设置文件夹格式

最新推荐

recommend-type

基于pytorch的UNet_demo实现及训练自己的数据集.docx

- `-b`表示批大小(batch size) - `-l`表示学习率 - `-s`表示缩放比例 - `-v`表示验证集比例 - **训练自己的数据集**: 调整数据集命名以保持一一对应,确保输入是3通道图像,输出是单通道掩模。如果遇到`...
recommend-type

变电站缺陷检测数据集,标注为VOC格式

变电站缺陷检测数据集,标注为VOC格式 表计读数有错--------bjdsyc: 657 个文件 表计外壳破损--------bj_wkps: 481 个文件 异物鸟巢--------------yw_nc: 834 个文件 箱门闭合异常--------xmbhyc: 368 个文件 盖板破损--------------gbps: 568 个文件 异物挂空悬浮物-----yw_gkxfw: 679 个文件 呼吸器硅胶变色-----hxq_gjbs: 1140 个文件 表计表盘模糊--------bj_bpmh: 828 个文件 绝缘子破裂-----------jyz_pl: 389 个文件 表计表盘破损--------bj_bpps: 694 个文件 渗漏油地面油污-----sly_dmyw: 721 个文件 未穿安全帽-----------wcaqm: 467 个文件 未穿工装--------------wcgz: 661 个文件 吸烟--------------------xy: 578 个文件
recommend-type

CIS110班级页面时钟设计与HTML实现

资源摘要信息:"clock-for-cis110:班级页面" HTML知识点: 1. HTML基础结构:HTML页面通常以<!DOCTYPE html>声明开始,紧接着是<html>标签作为根元素,包含<head>和<body>两个主要部分。在<head>部分中,一般会设置页面的元数据如标题<title>、字符集<charset>、引入外部CSS和JavaScript文件等。而<body>部分则包含页面的所有可见内容。 2. HTML文档标题<title>:标题标签用于定义页面的标题,它会显示在浏览器的标签页上,并且对于搜索引擎优化来说很重要。例如,在"clock-for-cis110:班级页面"的项目中,<title>标签的内容应该与项目相关,比如“CIS110班级时钟”。 3. HTML元素和标签:HTML文档由各种元素组成,每个元素由一个开始标签、内容和一个结束标签构成。例如,<h1>CIS110班级时钟</h1>中的<h1>是一个标题标签,用于定义最大级别的标题。 4. CSS样式应用:在HTML文档中,通常通过<link>标签在<head>部分引入外部CSS文件,这些CSS文件定义了HTML元素的样式,如字体大小、颜色、布局等。在"CIS110班级时钟"项目中,CSS将用于美化时钟的外观,例如调整时钟背景颜色、数字显示样式、时钟边框样式等。 5. JavaScript交互:为了实现动态功能,如实时显示时间的时钟,通常会在HTML文档中嵌入JavaScript代码或引入外部JavaScript文件。JavaScript可以处理时间的获取、显示以及更新等逻辑。在"CIS110班级时钟"项目中,JavaScript将用于创建时钟功能,比如让时钟能够动起来,每秒更新一次显示的时间。 6. HTML文档头部内容:在<head>部分,除了<title>外,还可以包含<meta>标签来定义页面的元数据,如字符集<meta charset="UTF-8">,这有助于确保页面在不同浏览器中的正确显示。另外,还可以添加<link rel="stylesheet" href="style.css">来引入CSS文件。 7. HTML文档主体内容:<body>部分包含了页面的所有可见元素,比如标题、段落、图片、链接以及其他各种HTML标签。在"CIS110班级时钟"项目中,主体部分将包含时钟显示区域,可能会有一个用来展示当前时间的<div>容器,以及可能的按钮、设置选项等交互元素。 通过以上知识点的介绍,我们可以了解到"CIS110班级时钟"项目的HTML页面设计需要包含哪些基本元素和技术。这些技术涉及到了文档的结构化、内容的样式定义、用户交互的设计,以及脚本编程的实现。在实际开发过程中,开发者需要结合这些知识点,进行编码以完成项目的搭建和功能实现。
recommend-type

管理建模和仿真的文件

管理Boualem Benatallah引用此版本:布阿利姆·贝纳塔拉。管理建模和仿真。约瑟夫-傅立叶大学-格勒诺布尔第一大学,1996年。法语。NNT:电话:00345357HAL ID:电话:00345357https://theses.hal.science/tel-003453572008年12月9日提交HAL是一个多学科的开放存取档案馆,用于存放和传播科学研究论文,无论它们是否被公开。论文可以来自法国或国外的教学和研究机构,也可以来自公共或私人研究中心。L’archive ouverte pluridisciplinaire
recommend-type

【Python沉浸式音频体验】:虚拟现实中的音频处理技巧

![【Python沉浸式音频体验】:虚拟现实中的音频处理技巧](https://www.thetechinfinite.com/wp-content/uploads/2020/07/thetechinfinite-22-1024x576.jpg) # 1. 虚拟现实中的音频处理概述 虚拟现实技术已经不再是科幻小说中的概念,而是逐渐走入了我们的生活。在这个沉浸式的世界里,除了视觉效果外,音频处理也扮演了至关重要的角色。本章将为读者提供一个虚拟现实音频处理的概览,从基础理论到实际应用,从简单的音频增强到复杂的交互设计,我们将逐步深入探讨如何在虚拟环境中实现高质量的音频体验。 虚拟现实中的音频处
recommend-type

在单片机编程中,如何正确使用if-else语句进行条件判断?请结合实际应用场景给出示例。

单片机编程中,if-else语句是基本的控制结构,用于基于条件执行不同的代码段。这在处理输入信号、状态监测、决策制定等场景中至关重要。为了帮助你更好地理解和运用这一语句,推荐参考这份资源:《单片机C语言常用语句详解ppt课件.ppt》。这份PPT课件详细讲解了单片机C语言编程中常用语句的用法和案例,直接关联到你的问题。 参考资源链接:[单片机C语言常用语句详解ppt课件.ppt](https://wenku.csdn.net/doc/5r92v3nz85?spm=1055.2569.3001.10343) 在实际应用中,if-else语句通常用于根据传感器的读数或某个标志位的状态来控制设备
recommend-type

WEB进销存管理系统wbjxc v3.0:提升企业销售与服务效率

资源摘要信息:"WEB进销存管理系统wbjxc v3.0" 知识点一:WEB进销存管理系统概念 WEB进销存管理系统是一种基于Web技术的库存管理和销售管理系统,它能够通过互联网进行数据的收集、处理和存储。该系统可以帮助企业管理商品的进货、销售、库存等信息,通过实时数据更新,确保库存信息准确,提高销售管理效率。 知识点二:产品录入、销售、退回、统计、客户管理模块 该系统包括五个基本功能模块,分别是产品录入、销售管理、退货处理、销售统计和客户信息管理。 1. 产品录入模块:负责将新产品信息加入系统,包括产品名称、价格、规格、供应商等基本信息的录入。 2. 销售管理模块:记录每一次销售活动的详细信息,包括销售商品、销售数量、销售单价、客户信息等。 3. 退回管理模块:处理商品的退货操作,记录退货商品、退货数量、退货原因等。 4. 销售统计模块:对销售数据进行汇总和分析,提供销售报表,帮助分析销售趋势和预测未来销售。 5. 客户信息管理模块:存储客户的基本信息,包括客户的联系方式、购买历史记录、信用等级等,以便于更好地服务客户和管理客户关系。 知识点三:多级别管理安全机制 "多级别管理"意味着该系统能够根据不同职位或权限的员工提供不同层级的数据访问和操作权限。这样的机制能够保护数据的安全,避免敏感信息被非授权访问或篡改。系统管理员可以设定不同的角色,如管理员、销售员、仓库管理员等,每个角色都有预设的权限,来执行特定的操作。 知识点四:操作提示及双击与单击的区别 在系统操作指南中提到需要留意单击与双击操作的区别,这通常是因为不同操作会导致不同的系统反应或功能触发。例如,在某些情况下单击可能用于打开菜单或选项,而双击可能用于立即确认或执行某个命令。用户需要根据系统的提示,正确使用单击或双击,以确保操作的准确性和系统的顺畅运行。 知识点五:Asp源码 Asp是Active Server Pages的缩写,是一种服务器端脚本环境,用于创建动态交互式网页。当Asp代码被服务器执行后,结果以HTML格式发送到客户端浏览器。使用Asp编写的应用程序可以跨平台运行在Windows系列服务器上,兼容大多数浏览器。因此,Asp源码的提及表明wbjxc v3.0系统可能使用了Asp语言进行开发,并提供了相应的源代码文件,便于开发者进行定制、维护或二次开发。 知识点六:WEB进销存系统的应用场景 WEB进销存管理系统适用于各种规模的企业,尤其适合中大型企业以及具有多个销售渠道和分销商的公司。通过互联网的特性,该系统可以方便地实现远程办公、实时数据分析以及多部门协同工作,极大地提高了工作效率和业务响应速度。 知识点七:WEB进销存系统的开发工具和语言 虽然具体的技术栈没有明确提及,但鉴于ASP源码的使用,可以推测开发wbjxc v3.0系统可能涉及的技术和工具包括但不限于:HTML、CSS、JavaScript、VBScript(Asp脚本语言的一种),以及可能的数据库技术如Microsoft SQL Server或Access数据库等。这些技术组合起来为系统提供了前端展示、后端逻辑处理以及数据存储等完整的解决方案。 知识点八:WEB进销存系统的更新和版本迭代 标题中提到的"v3.0"表明wbjxc是一个具有版本迭代的产品,随着技术进步和用户需求的变化,系统会不断更新升级以满足新的要求。版本号的递增也说明系统经过了多次更新和改进,逐渐完善功能和用户体验。用户在升级时应关注新版本带来的功能变更以及可能需要进行的数据迁移和操作习惯调整。
recommend-type

"互动学习:行动中的多样性与论文攻读经历"

多样性她- 事实上SCI NCES你的时间表ECOLEDO C Tora SC和NCESPOUR l’Ingén学习互动,互动学习以行动为中心的强化学习学会互动,互动学习,以行动为中心的强化学习计算机科学博士论文于2021年9月28日在Villeneuve d'Asq公开支持马修·瑟林评审团主席法布里斯·勒菲弗尔阿维尼翁大学教授论文指导奥利维尔·皮耶昆谷歌研究教授:智囊团论文联合主任菲利普·普雷教授,大学。里尔/CRISTAL/因里亚报告员奥利维耶·西格德索邦大学报告员卢多维奇·德诺耶教授,Facebook /索邦大学审查员越南圣迈IMT Atlantic高级讲师邀请弗洛里安·斯特鲁布博士,Deepmind对于那些及时看到自己错误的人...3谢谢你首先,我要感谢我的两位博士生导师Olivier和Philippe。奥利维尔,"站在巨人的肩膀上"这句话对你来说完全有意义了。从科学上讲,你知道在这篇论文的(许多)错误中,你是我可以依
recommend-type

Python虚拟现实网络编程:多人互动体验的设计与实现

![Python虚拟现实网络编程:多人互动体验的设计与实现](https://img-blog.csdnimg.cn/4eac4f0588334db2bfd8d056df8c263a.png) # 1. Python虚拟现实网络编程概述 在当今数字化时代,Python作为一门充满活力的编程语言,以其简洁明了的语法和强大的社区支持,在网络编程和虚拟现实(VR)应用开发领域中占据着重要的地位。Python的虚拟现实网络编程不仅结合了网络技术与VR的交互特性,还为开发者提供了一个高效、灵活的编程环境,使得实现多人互动体验变得更加可行。 ## 1.1 虚拟现实技术与网络编程的融合 虚拟现实技术通
recommend-type

在单片机编程中,如何正确使用if-else语句进行条件判断?请结合实际应用场景给出示例。

在单片机C语言编程中,if-else语句是进行条件判断的基础,它能够根据不同的条件执行不同的代码块。为了更深入地了解这一常用语句,建议参考《单片机C语言常用语句详解ppt课件.ppt》。这份课件能够提供直观的演示和具体的编码示例,帮助你更好地理解如何在单片机编程中应用if-else语句。 参考资源链接:[单片机C语言常用语句详解ppt课件.ppt](https://wenku.csdn.net/doc/5r92v3nz85?spm=1055.2569.3001.10343) 以下是一个使用if-else语句进行条件判断的实际应用场景示例:假设我们要根据一个温度传感器的输入值来控制一