def test(checkpoint_dir, style_name, test_dir, if_adjust_brightness, img_size=[256,256]): # tf.reset_default_graph() result_dir = 'results/'+style_name check_folder(result_dir) test_files = glob('{}/*.*'.format(test_dir)) test_real = tf.placeholder(tf.float32, [1, None, None, 3], name='test') with tf.variable_scope("generator", reuse=False): test_generated = generator.G_net(test_real).fake saver = tf.train.Saver() gpu_options = tf.GPUOptions(allow_growth=True) with tf.Session(config=tf.ConfigProto(allow_soft_placement=True, gpu_options=gpu_options)) as sess: # tf.global_variables_initializer().run() # load model ckpt = tf.train.get_checkpoint_state(checkpoint_dir) # checkpoint file information if ckpt and ckpt.model_checkpoint_path: ckpt_name = os.path.basename(ckpt.model_checkpoint_path) # first line saver.restore(sess, os.path.join(checkpoint_dir, ckpt_name)) print(" [*] Success to read {}".format(os.path.join(checkpoint_dir, ckpt_name))) else: print(" [*] Failed to find a checkpoint") return # stats_graph(tf.get_default_graph()) begin = time.time() for sample_file in tqdm(test_files) : # print('Processing image: ' + sample_file) sample_image = np.asarray(load_test_data(sample_file, img_size)) image_path = os.path.join(result_dir,'{0}'.format(os.path.basename(sample_file))) fake_img = sess.run(test_generated, feed_dict = {test_real : sample_image}) if if_adjust_brightness: save_images(fake_img, image_path, sample_file) else: save_images(fake_img, image_path, None) end = time.time() print(f'test-time: {end-begin} s') print(f'one image test time : {(end-begin)/len(test_files)} s'什么意思
时间: 2024-03-31 16:38:00 浏览: 105
这是一个 Python 函数,该函数接受五个参数:`checkpoint_dir`,`style_name`,`test_dir`,`if_adjust_brightness` 和 `img_size`。该函数使用 TensorFlow 实现了一个测试函数,用于测试训练好的 AnimeGAN V2 模型生成的效果。在函数中,首先使用 `glob` 函数获取测试图片的路径列表,然后使用 TensorFlow 定义了一个占位符 `test_real`,用于输入测试图片。接着,使用 AnimeGAN V2 模型生成器 `generator.G_net` 生成虚假图像 `test_generated`。随后,使用 TensorFlow 定义了一个 Saver 对象,用于从检查点文件中恢复模型。最后,使用 TensorFlow Session 对象运行生成器模型,生成对应的虚假图像,并将其保存到指定的目录中。
其中,如果 `if_adjust_brightness` 为真,则调用 `save_images` 函数保存虚假图像,并调用 `load_test_data` 函数加载测试图片并调整亮度;否则,调用 `save_images` 函数保存虚假图像,但不会调用 `load_test_data` 函数调整亮度。函数还输出了测试的总时间和每张图片测试所需的时间。
相关问题
def Hayao(): arg=parse_args(?,gl_path_) test(arg.checkpoint_dir, arg.save_dir, arg.test_dir, arg.if_adjust_brightness) def parse_args(a,b): desc = "AnimeGANv2" parser = argparse.ArgumentParser(description=desc) parser.add_argument('--checkpoint_dir', type=str, default='checkpoint/'+a, help='Directory name to save the checkpoints') parser.add_argument('--test_dir', type=str, default=b, help='Directory name of test photos') parser.add_argument('--save_dir', type=str, default='Shinkai/t', help='what style you want to get') parser.add_argument('--if_adjust_brightness', type=bool, default=True, help='adjust brightness by the real photo') """checking arguments""" return parser.parse_args()优化此代码
我们可以对代码进行如下优化:
1. 将函数 `Hayao` 中的参数 `a` 和 `b` 直接传入,而不是使用 `?` 占位符。
2. 将 `gl_path_` 的定义看不到,根据代码的语义,我们可以将其作为参数传入 `Hayao` 函数中。
3. 将 `arg.if_adjust_brightness` 的类型从 `bool` 改为 `int`,使得输入参数更加灵活。
经过上述优化后,代码可以改写为:
```python
def Hayao(checkpoint_dir, test_dir, save_dir, if_adjust_brightness, gl_path_):
arg = parse_args(checkpoint_dir, test_dir, save_dir, if_adjust_brightness, gl_path_)
test(arg.checkpoint_dir, arg.save_dir, arg.test_dir, arg.if_adjust_brightness)
def parse_args(checkpoint_dir, test_dir, save_dir, if_adjust_brightness, gl_path_):
desc = "AnimeGANv2"
parser = argparse.ArgumentParser(description=desc)
parser.add_argument('--checkpoint_dir', type=str, default='checkpoint/' + checkpoint_dir,
help='Directory name to save the checkpoints')
parser.add_argument('--test_dir', type=str, default=test_dir,
help='Directory name of test photos')
parser.add_argument('--save_dir', type=str, default=save_dir,
help='what style you want to get')
parser.add_argument('--if_adjust_brightness', type=int, default=1,
help='0: not adjust brightness, 1: adjust brightness by the real photo')
"""checking arguments"""
return parser.parse_args()
```
这样就使得代码更加清晰易懂,同时也更加灵活。
def Hayao(): path.set(gl_path_) t='generator_Hayao_weight' arg=parse_args(t,gl_path_) test(arg.checkpoint_dir, arg.save_dir, arg.test_dir, arg.if_adjust_brightness) def parse_args(a,b): desc = "AnimeGANv2" parser = argparse.ArgumentParser(description=desc) parser.add_argument('--checkpoint_dir', type=str, default='checkpoint/'+a, help='Directory name to save the checkpoints') parser.add_argument('--test_dir', type=str, default=b, help='Directory name of test photos') parser.add_argument('--save_dir', type=str, default='Shinkai/t', help='what style you want to get') parser.add_argument('--if_adjust_brightness', type=bool, default=True, help='adjust brightness by the real photo') """checking arguments""" return parser.parse_args()代码报错修改
这段代码报错可能是因为你没有导入`argparse`模块。你需要在代码开头加上`import argparse`语句。同时,你需要将`gl_path_`传入函数`Hayao()`中,否则会出现`NameError`错误。修改后的代码如下:
```
import argparse
def Hayao(gl_path_):
path.set(gl_path_)
t='generator_Hayao_weight'
arg=parse_args(t,gl_path_)
test(arg.checkpoint_dir, arg.save_dir, arg.test_dir, arg.if_adjust_brightness)
def parse_args(a,b):
desc = "AnimeGANv2"
parser = argparse.ArgumentParser(description=desc)
parser.add_argument('--checkpoint_dir', type=str, default='checkpoint/'+a,
help='Directory name to save the checkpoints')
parser.add_argument('--test_dir', type=str, default=b,
help='Directory name of test photos')
parser.add_argument('--save_dir', type=str, default='Shinkai/t',
help='what style you want to get')
parser.add_argument('--if_adjust_brightness', type=bool, default=True,
help='adjust brightness by the real photo')
"""checking arguments"""
return parser.parse_args()
```
调用`Hayao()`函数时,需要将`gl_path_`作为参数传入。
阅读全文