base_model = tf.keras.applications.MobileNet(weights = "imagenet", include_top = False, input_shape = input_shape) base_model.trainable = False inputs = keras.Input(shape = input_shape) x = base_model(inputs, training = False) x = tf.keras.layers.GlobalAveragePooling2D()(x) x = tf.keras.layers.Dropout(0.2)(x) x = tf.keras.layers.Dense(len(categories), activation="softmax")(x) model = keras.Model(inputs = inputs, outputs = x, name="LeafDisease_MobileNet") weight_path = os.path.join(base_dir, 'checkpoints', 'my_checkpoint') model.load_weights(weight_path) img = plt.imread(img_path) img = img / 255. img = cv2.resize(img, (224, 224)) img = img.reshape(-1, 224, 224, 3) img.astype('float32') result = model.predict(img) cate_result = categories[np.argmax(result, axis=1)[0]] return cate_result这段代码中的权重文件是在哪里加载的
时间: 2023-11-09 21:06:41 浏览: 395
tf.keras_分类模块_CNN-深度可分离.ipynb_nose8eu_keras_CNN_tf.keras_分离卷积_
这段代码中的权重文件是在 `weight_path = os.path.join(base_dir, 'checkpoints', 'my_checkpoint')` 这一行加载的。在这里,`base_dir` 是指程序的根目录,`checkpoints` 是指存放权重文件的目录,`my_checkpoint` 是指具体的权重文件名称。`model.load_weights(weight_path)` 用于加载权重文件到模型中。
阅读全文