Help me pick out all of the pseudocode in the document; I only want the pseudocode.
Below is all of the pseudocode found in the document (in practice these are TensorFlow/Keras code snippets for defining, training, and testing AlexNet):
### AlexNet_v1 Model Definition
```python
from tensorflow.keras import layers, models


def AlexNet_v1(im_height=224, im_width=224, num_classes=1000):
    # Functional-API definition of the AlexNet backbone and classifier.
    input_image = layers.Input(shape=(im_height, im_width, 3), dtype="float32")
    # pad 224x224 to 227x227 so the 11x11 / stride-4 convolution yields 55x55 feature maps
    x = layers.ZeroPadding2D(((1, 2), (1, 2)))(input_image)
    x = layers.Conv2D(48, kernel_size=11, strides=4, activation="relu")(x)
    x = layers.MaxPool2D(pool_size=3, strides=2)(x)
    x = layers.Conv2D(128, kernel_size=5, padding="same", activation="relu")(x)
    x = layers.MaxPool2D(pool_size=3, strides=2)(x)
    x = layers.Conv2D(192, kernel_size=3, padding="same", activation="relu")(x)
    x = layers.Conv2D(192, kernel_size=3, padding="same", activation="relu")(x)
    x = layers.Conv2D(128, kernel_size=3, padding="same", activation="relu")(x)
    x = layers.MaxPool2D(pool_size=3, strides=2)(x)

    # classifier head: two dropout-regularized fully connected layers plus softmax output
    x = layers.Flatten()(x)
    x = layers.Dropout(0.2)(x)
    x = layers.Dense(2048, activation="relu")(x)
    x = layers.Dropout(0.2)(x)
    x = layers.Dense(2048, activation="relu")(x)
    x = layers.Dense(num_classes)(x)
    predict = layers.Softmax()(x)

    model = models.Model(inputs=input_image, outputs=predict)
    return model
```
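For reference, a minimal usage sketch (not part of the original document; `model` is a hypothetical module name for the definitions above). Because the functional API fixes the input shape at construction time, the model can be inspected directly:
```python
from model import AlexNet_v1  # hypothetical module holding the definition above

# Instantiate for the 5-class flower data set and print the layer-by-layer summary.
model = AlexNet_v1(im_height=224, im_width=224, num_classes=5)
model.summary()  # no build() call needed: the functional model already knows its input shape
```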
### AlexNet_v2 Model Definition
```python
from tensorflow.keras import layers, Model, Sequential


class AlexNet_v2(Model):
    # Subclassed definition of the same network, split into a convolutional
    # feature extractor and a fully connected classifier.
    def __init__(self, num_classes=1000):
        super(AlexNet_v2, self).__init__()
        self.features = Sequential([
            layers.ZeroPadding2D(((1, 2), (1, 2))),
            layers.Conv2D(48, kernel_size=11, strides=4, activation="relu"),
            layers.MaxPool2D(pool_size=3, strides=2),
            layers.Conv2D(128, kernel_size=5, padding="same", activation="relu"),
            layers.MaxPool2D(pool_size=3, strides=2),
            layers.Conv2D(192, kernel_size=3, padding="same", activation="relu"),
            layers.Conv2D(192, kernel_size=3, padding="same", activation="relu"),
            layers.Conv2D(128, kernel_size=3, padding="same", activation="relu"),
            layers.MaxPool2D(pool_size=3, strides=2)
        ])
        self.flatten = layers.Flatten()
        self.classifier = Sequential([
            layers.Dropout(0.2),
            layers.Dense(1024, activation="relu"),
            layers.Dropout(0.2),
            layers.Dense(128, activation="relu"),
            layers.Dense(num_classes),
            layers.Softmax()
        ])

    def call(self, inputs, **kwargs):
        x = self.features(inputs)
        x = self.flatten(x)
        x = self.classifier(x)
        return x
```
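By contrast, the subclassed model does not know its input shape until it is built or called on data, so `build()` must be invoked before `summary()`. A minimal sketch, again assuming the class above lives in a hypothetical `model.py` (the build shape mirrors the training script below):
```python
from model import AlexNet_v2  # hypothetical module holding the definition above

model = AlexNet_v2(num_classes=5)
# A subclassed Model has no weights until it is built; build() traces call()
# once with the given input shape so summary() has something to report.
model.build((32, 224, 224, 3))
model.summary()
```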
### Training Code
```python
import os
import json

import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow.keras.preprocessing.image import ImageDataGenerator

from model import AlexNet_v2  # assumption: the model definitions above live in model.py


def main():
    # locate the flower data set: <project root>/data_set/flower_data/{train,val}
    data_root = os.path.abspath(os.path.join(os.getcwd(), "../"))
    image_path = os.path.join(data_root, "data_set", "flower_data")
    train_dir = os.path.join(image_path, "train")
    validation_dir = os.path.join(image_path, "val")
    assert os.path.exists(train_dir), "cannot find {}".format(train_dir)
    assert os.path.exists(validation_dir), "cannot find {}".format(validation_dir)

    if not os.path.exists("save_weights"):
        os.makedirs("save_weights")

    im_height = 224
    im_width = 224
    batch_size = 32
    epochs = 33

    # data generators: rescale pixels to [0, 1]; random horizontal flips for training only
    train_image_generator = ImageDataGenerator(rescale=1. / 255, horizontal_flip=True)
    validation_image_generator = ImageDataGenerator(rescale=1. / 255)

    train_data_gen = train_image_generator.flow_from_directory(
        directory=train_dir,
        batch_size=batch_size,
        shuffle=True,
        target_size=(im_height, im_width),
        class_mode='categorical'
    )
    total_train = train_data_gen.n

    # save the {index: class name} mapping for later use by the prediction script
    class_indices = train_data_gen.class_indices
    inverse_dict = dict((val, key) for key, val in class_indices.items())
    json_str = json.dumps(inverse_dict, indent=4)
    with open('class_indices.json', 'w') as json_file:
        json_file.write(json_str)

    val_data_gen = validation_image_generator.flow_from_directory(
        directory=validation_dir,
        batch_size=batch_size,
        shuffle=False,
        target_size=(im_height, im_width),
        class_mode='categorical'
    )
    total_val = val_data_gen.n
    print("using {} images for training, {} images for validation.".format(total_train, total_val))

    model = AlexNet_v2(num_classes=5)
    model.build((batch_size, 224, 224, 3))  # subclassed model must be built before summary()
    model.summary()

    # softmax is applied inside the model, so from_logits=False
    model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.0005),
                  loss=tf.keras.losses.CategoricalCrossentropy(from_logits=False),
                  metrics=["accuracy"])

    # keep only the weights of the epoch with the lowest validation loss
    callbacks = [tf.keras.callbacks.ModelCheckpoint(filepath='./save_weights/myAlex.h5',
                                                    save_best_only=True,
                                                    save_weights_only=True,
                                                    monitor='val_loss')]

    history = model.fit(x=train_data_gen,
                        steps_per_epoch=total_train // batch_size,
                        epochs=epochs,
                        validation_data=val_data_gen,
                        validation_steps=total_val // batch_size,
                        callbacks=callbacks)

    # plot the loss and accuracy curves recorded during training
    history_dict = history.history
    train_loss = history_dict["loss"]
    train_accuracy = history_dict["accuracy"]
    val_loss = history_dict["val_loss"]
    val_accuracy = history_dict["val_accuracy"]

    plt.figure()
    plt.plot(range(epochs), train_loss, label='train_loss')
    plt.plot(range(epochs), val_loss, label='val_loss')
    plt.legend()
    plt.xlabel('epochs')
    plt.ylabel('loss')

    plt.figure()
    plt.plot(range(epochs), train_accuracy, label='train_accuracy')
    plt.plot(range(epochs), val_accuracy, label='val_accuracy')
    plt.legend()
    plt.xlabel('epochs')
    plt.ylabel('accuracy')
    plt.show()
```
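One detail the training script relies on: `flow_from_directory` derives the class labels from the immediate subdirectories of `train` and `val`, so the flower data set must be arranged as one folder per class. A small sketch (not from the original document; paths follow the training script) that prints the classes it would detect:
```python
import os

# Expected layout, one subdirectory per class:
#   data_set/flower_data/train/<class_name>/*.jpg
#   data_set/flower_data/val/<class_name>/*.jpg
train_dir = os.path.abspath(os.path.join(os.getcwd(), "..", "data_set", "flower_data", "train"))
if os.path.exists(train_dir):
    classes = sorted(d for d in os.listdir(train_dir)
                     if os.path.isdir(os.path.join(train_dir, d)))
    print("found {} classes: {}".format(len(classes), classes))
else:
    print("cannot find {}".format(train_dir))
```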
### Prediction Code
```python
import os
import json

import numpy as np
import matplotlib.pyplot as plt
from PIL import Image

from model import AlexNet_v2  # assumption: the model definitions above live in model.py


def main():
    im_height = 224
    im_width = 224

    # load and preprocess the image the same way as during training
    img_path = "./1.png"
    assert os.path.exists(img_path), "file: '{}' does not exist.".format(img_path)
    img = Image.open(img_path)
    img = img.convert("RGB")  # drop a possible alpha channel so the array is HxWx3
    img = img.resize((im_width, im_height))
    plt.imshow(img)
    img = np.array(img) / 255.
    img = np.expand_dims(img, 0)  # add the batch dimension

    # read the {index: class name} mapping written by the training script
    json_path = './class_indices.json'
    assert os.path.exists(json_path), "file: '{}' does not exist.".format(json_path)
    with open(json_path, "r") as f:
        class_indict = json.load(f)

    # rebuild the model and restore the trained weights
    model = AlexNet_v2(num_classes=5)
    weights_path = "./save_weights/myAlex.h5"
    assert os.path.exists(weights_path), "file: '{}' does not exist.".format(weights_path)
    model.build((32, 224, 224, 3))
    model.load_weights(weights_path)

    # the model outputs softmax probabilities; report the most likely class
    result = np.squeeze(model.predict(img))
    predict_class = np.argmax(result)
    print_res = "Types of flowers: {}   Probability: {:.3f}".format(
        class_indict[str(predict_class)], result[predict_class])
    plt.title(print_res)
    for i in range(len(result)):
        print("Types of flowers: {:10}   Probability: {:.3f}".format(class_indict[str(i)], result[i]))
    plt.show()
```