S_inputs = Input(shape=(11,), dtype='int32') #(None,600) O_seq = Embedding(5000, 128)(S_inputs) #(None,600,128) cnn1 = Conv1D(256, 3, padding='same', strides=1, activation='relu')(O_seq) cnn1 = MaxPooling1D(pool_size=3)(cnn1) cnn = cnn1 O_seq = GlobalAveragePooling1D()(cnn) #(None,128) print(O_seq.shape) O_seq = Dropout(0.9)(O_seq) outputs = Dense(1, activation='tanh',kernel_regularizer = tf.keras.regularizers.L2())(O_seq) model = Model(inputs=S_inputs, outputs=outputs) opt = SGD(learning_rate=0.1, decay=0.00001) loss = 'categorical_crossentropy' model.compile(loss=loss, optimizer=opt, metrics=['categorical_accuracy']) print('Train...') h = model.fit(Xtrain, ytrain,batch_size=batch_size,validation_split = 0.2,epochs=5) plt.plot(h.history["loss"], label="train_loss") plt.plot(h.history["val_loss"], label="test_loss") plt.legend() plt.show()给这段代码加注释
时间: 2023-06-19 21:05:59 浏览: 93
Python库 | switch_inputs-0.7.0-py3-none-any.whl
# 导入模块
from tensorflow.keras.layers import Input, Embedding, Conv1D, MaxPooling1D, GlobalAveragePooling1D, Dropout, Dense
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import SGD
import tensorflow as tf
import matplotlib.pyplot as plt
# Model: integer token ids -> embedding -> 1-D convolution -> pooling -> one output unit.

# Input layer: each sample is a sequence of 11 integer token ids.
S_inputs = Input(shape=(11,), dtype='int32')  # (None, 11)
# Embedding: maps each id in [0, 5000) to a 128-dimensional vector.
O_seq = Embedding(5000, 128)(S_inputs)  # (None, 11, 128)
# 1-D convolution: 256 filters of width 3; 'same' padding keeps the sequence length.
cnn1 = Conv1D(256, 3, padding='same', strides=1, activation='relu')(O_seq)  # (None, 11, 256)
# Max pooling over windows of 3 (stride defaults to pool_size, 'valid' padding):
# floor((11 - 3) / 3) + 1 = 3 time steps remain.
cnn1 = MaxPooling1D(pool_size=3)(cnn1)  # (None, 3, 256)
cnn = cnn1
# Average over the time axis -> one 256-dim feature vector per sample.
O_seq = GlobalAveragePooling1D()(cnn)  # (None, 256)
# Dropout: in Keras the rate is the fraction of units DROPPED during training.
# NOTE(review): 0.9 discards 90% of the features; if this was meant as a
# TF1-style keep_prob, the equivalent Keras rate would be 0.1 -- confirm.
O_seq = Dropout(0.9)(O_seq)
# Output layer: a single unit.
# Sigmoid + binary cross-entropy replaces the original tanh + categorical
# cross-entropy. Categorical cross-entropy normalises predictions over the last
# axis, so with one unit the loss is constant and categorical_accuracy is
# always 1 -- the model could not learn anything.
# NOTE(review): assumes ytrain holds 0/1 labels. If ytrain is one-hot over K
# classes, use Dense(K, activation='softmax') with 'categorical_crossentropy'.
outputs = Dense(
    1,
    activation='sigmoid',
    kernel_regularizer=tf.keras.regularizers.L2(),  # default L2 factor is 0.01
)(O_seq)  # (None, 1)

# Build and compile the model.
model = Model(inputs=S_inputs, outputs=outputs)
# The `decay` argument of SGD is not supported by the TF >= 2.11 / Keras 3
# optimizers (it raises or is ignored, depending on version). InverseTimeDecay
# with decay_steps=1 reproduces the legacy schedule
# lr = 0.1 / (1 + 1e-5 * iteration) on every TF 2.x version.
lr_schedule = tf.keras.optimizers.schedules.InverseTimeDecay(
    initial_learning_rate=0.1, decay_steps=1, decay_rate=0.00001)
opt = SGD(learning_rate=lr_schedule)
loss = 'binary_crossentropy'
model.compile(loss=loss, optimizer=opt, metrics=['accuracy'])
# Print the layer-by-layer architecture and parameter counts.
model.summary()

# Train the model.
print('Train...')
# Xtrain, ytrain and batch_size are not defined here; presumably they are set
# earlier in the original program -- TODO confirm.
# validation_split=0.2 holds out the last 20% of the training data for validation.
h = model.fit(Xtrain, ytrain, batch_size=batch_size, validation_split=0.2, epochs=5)

# Plot per-epoch training loss against validation loss.
plt.plot(h.history["loss"], label="train_loss")
# Labelled "val_loss": this curve comes from validation_split, not a test set.
plt.plot(h.history["val_loss"], label="val_loss")
plt.legend()
plt.show()
阅读全文