keras网络可视化代码
时间: 2023-09-22 16:15:27 浏览: 61
以下是使用Keras自带的可视化工具绘制模型图的代码示例:
``` python
from keras.utils.vis_utils import plot_model
from keras.models import Model
from keras.layers import Input, Dense, Conv2D, MaxPooling2D, Flatten
# 定义模型
inputs = Input(shape=(28, 28, 1))
x = Conv2D(filters=32, kernel_size=(3, 3), activation='relu')(inputs)
x = MaxPooling2D(pool_size=(2, 2))(x)
x = Conv2D(filters=64, kernel_size=(3, 3), activation='relu')(x)
x = MaxPooling2D(pool_size=(2, 2))(x)
x = Flatten()(x)
x = Dense(units=128, activation='relu')(x)
predictions = Dense(units=10, activation='softmax')(x)
model = Model(inputs=inputs, outputs=predictions)
# 绘制模型图
plot_model(model, to_file='model.png', show_shapes=True, show_layer_names=True)
```
在终端执行该代码后,会在当前目录下生成一个名为“model.png”的模型图像。其中,`show_shapes=True`表示将每个层的输入和输出形状显示在图中,`show_layer_names=True`表示将每个层的名称显示在图中。