import tensorflow as tf
from im_dataset import train_image, train_label, test_image, test_label
from AlexNet8 import AlexNet8
from baseline import baseline
from InceptionNet import Inception10
from Resnet18 import ResNet18
import os
import matplotlib.pyplot as plt
import argparse
import numpy as np

parse = argparse.ArgumentParser(description="CVAE model for generation of metamaterial")
hyperparameter_set = parse.add_argument_group(title='HyperParameter Setting')
dim_set = parse.add_argument_group(title='Dim setting')
# --num_epochs and --learning_rate used to be parsed but ignored (fit/compile
# hard-coded 500 and 1e-3). They are now wired in; their defaults equal the
# previously hard-coded values so running without flags behaves as before.
hyperparameter_set.add_argument("--num_epochs", type=int, default=500, help="Number of train epochs")
hyperparameter_set.add_argument("--learning_rate", type=float, default=1e-3, help="learning rate")
hyperparameter_set.add_argument("--image_size", type=int, default=16 * 16, help="vector size of image")
hyperparameter_set.add_argument("--batch_size", type=int, default=16, help="batch size of database")
dim_set.add_argument("--z_dim", type=int, default=20, help="dim of latent variable")
dim_set.add_argument("--feature_dim", type=int, default=32, help="dim of feature vector")
dim_set.add_argument("--phase_curve_dim", type=int, default=41, help="dim of phase curve vector")
dim_set.add_argument("--image_dim", type=int, default=16, help="image size: [image_dim,image_dim,1]")
args = parse.parse_args()


def preprocess(x, y):
    """Load one (image path, label) pair for the tf.data pipeline.

    The PNG at path ``x`` is decoded as a single-channel image, scaled to
    [0, 1], shifted to [-0.5, 0.5], and tiled 2x2 so its height and width
    double (presumably 16x16 -> 32x32 to match the model input built below;
    confirm against --image_dim and the dataset).

    Args:
        x: scalar string tensor holding the image file path.
        y: label tensor; cast to float32 for the MSE regression loss.

    Returns:
        (image, label): image of shape [2*H, 2*W, 1] float32, label float32.
    """
    x = tf.io.read_file(x)
    x = tf.image.decode_png(x, channels=1)  # -> [H, W, 1], uint8
    x = tf.cast(x, dtype=tf.float32) / 255.
    # Center BEFORE tiling. Previously the shift was computed after the tiled
    # tensor had been built and then discarded, so it never took effect.
    # NOTE(review): weights trained with the old (uncentered) inputs will see
    # a shifted input distribution when resumed from checkpoint.
    x = x - 0.5
    x = tf.concat([x, x], 0)  # tile along height
    x = tf.concat([x, x], 1)  # tile along width
    y = tf.cast(y, dtype=tf.float32)
    return x, y


train_db = tf.data.Dataset.from_tensor_slices((train_image, train_label))
train_db = train_db.shuffle(100).map(preprocess).batch(args.batch_size)
test_db = tf.data.Dataset.from_tensor_slices((test_image, test_label))
test_db = test_db.map(preprocess).batch(args.batch_size)

model = ResNet18([2, 2, 2, 2])
# Leave the batch dimension unspecified: the final batch of an epoch can be
# smaller than --batch_size.
model.build(input_shape=(None, 32, 32, 1))
model.compile(
    # `lr` is a deprecated alias that newer Keras optimizers reject.
    optimizer=tf.keras.optimizers.Adam(learning_rate=args.learning_rate),
    loss=tf.keras.losses.MSE,
    metrics=['MSE'],
)

# NOTE(review): the directory says "InceptionNet" although ResNet18 is trained
# here; weights saved by a different architecture will fail to load. The path
# is kept unchanged so existing checkpoints are still picked up.
checkpoint_save_path = "./checkpoint/InceptionNet_im_3/checkpoint.ckpt"
if os.path.exists(checkpoint_save_path + '.index'):
    print('------------------load the model---------------------')
    model.load_weights(checkpoint_save_path)

# Keeps only the weights with the best monitored value (val_loss by default).
cp_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_save_path,
    save_weights_only=True,
    save_best_only=True,
)

history = model.fit(
    train_db,
    epochs=args.num_epochs,
    validation_data=test_db,
    validation_freq=1,
    callbacks=[cp_callback],
)
model.summary()

# 'loss' is the MSE objective (plus any regularization losses the model adds).
train_mse = history.history['loss']
val_mse = history.history['val_loss']
plt.plot(train_mse, label='Training MSE')
plt.plot(val_mse, label='Validation MSE')
plt.title('Training and Validation MSE')
plt.legend()
plt.show()
时间: 2023-06-17 22:03:28 浏览: 117
这段代码使用 TensorFlow 构建了一个 ResNet18 模型,以均方误差(MSE)作为损失函数,对图像进行回归任务(标签被转换为 float32 连续值,而不是分类任务)。其中使用了 argparse 库来设置超参数,使用了 tf.data.Dataset 来进行数据处理和加载,使用了 tf.keras.callbacks.ModelCheckpoint 来保存验证损失最优的模型权重。同时还使用了 matplotlib.pyplot 库来绘制训练和验证的 MSE(均方误差)曲线。
相关问题
# Imports restored one per line; the first statement was truncated to
# "mport socket", which is a SyntaxError.
import socket
from collections import OrderedDict
import numpy as np
import tensorflow as tf
import utils.io.image
import utils.io.landmark
import utils.io.text
import utils.np_image
import utils.sitk_image
import utils.sitk_np
from dataset import Dataset
from datasets.pyro_dataset import PyroClientDataset
from network import Unet
from spine_localization_postprocessing import bb, bb_iou
# NOTE(review): the `experimental` mixed-precision namespace was deprecated in
# TF 2.4 in favour of tf.keras.mixed_precision; confirm the pinned TF version
# still provides it.
from tensorflow.keras.mixed_precision import experimental as mixed_precision
from tensorflow_train_v2.dataset.dataset_iterator import DatasetIterator
from tensorflow_train_v2.train_loop import MainLoopBase
from tensorflow_train_v2.utils.data_format import get_batch_channel_image_size
from tensorflow_train_v2.utils.loss_metric_logger import LossMetricLogger
from tensorflow_train_v2.utils.output_folder_handler import OutputFolderHandler
from tqdm import tqdm
这段代码看起来是一个Python脚本的导入部分,导入了一些需要的模块和库,包括socket、collections、numpy、tensorflow等。它还引用了一些自定义的模块,如utils.io、dataset、network等,并导入了Unet、bb、bb_iou等函数和类(这段代码本身只包含导入语句,并未展示它们的具体用法)。最后还导入了tqdm库,该库通常用于显示进度条。这段代码似乎与图像处理、数据集处理和模型训练相关。请问你有什么具体的问题或需要进一步的解释吗?
from datasets.pyro_dataset import PyroClientDataset from network import Unet from spine_localization_postprocessing import bb, bb_iou from tensorflow.keras.mixed_precision import experimental as mixed_precision from tensorflow_train_v2.dataset.dataset_iterator import DatasetIterator from tensorflow_train_v2.train_loop import MainLoopBase from tensorflow_train_v2.utils.data_format import get_batch_channel_image_size from tensorflow_train_v2.utils.loss_metric_logger import LossMetricLogger from tensorflow_train_v2.utils.output_folder_handler import OutputFolderHandler from tqdm import tqdm
这段代码看起来是一个深度学习项目的一部分。它导入了一些项目自定义的模块和类,包括PyroClientDataset、Unet,以及spine_localization_postprocessing中的bb和bb_iou;还导入了TensorFlow自带的mixed_precision(混合精度)模块。它还从tensorflow_train_v2包中导入了DatasetIterator、MainLoopBase、LossMetricLogger等类,这个包并不是TensorFlow官方库的一部分,看起来是项目自带的训练框架。代码中还导入了tqdm库,该库通常用于显示进度条。
这段代码的具体功能和用途可能需要查看其他部分的代码才能确定。
阅读全文