优化这段代码:import tensorflow as tf import numpy as np # 加载数据集 mnist = tf.keras.datasets.mnist (x_train, y_train), (x_test, y_test) = mnist.load_data() x_train, x_test = x_train / 255.0, x_test / 255.0 # 定义模型 model = tf.keras.models.Sequential([ tf.keras.layers.Flatten(input_shape=(28, 28)), tf.keras.layers.Dense(128, activation='relu'), tf.keras.layers.Dropout(0.2), tf.keras.layers.Dense(10, activation='softmax')]) # 编译模型 model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy']) # 训练模型 model.fit(x_train, y_train, epochs=5) # 生成对抗样本 epsilon = 0.1 x_adv = tf.Variable(x_test[:1000], dtype=tf.float32) y_true = tf.Variable(y_test[:1000], dtype=tf.int64) with tf.GradientTape() as tape: tape.watch(x_adv) y_pred = model(x_adv) loss = tf.keras.losses.sparse_categorical_crossentropy(y_true, y_pred) grad = tape.gradient(loss, x_adv) normed_grad = tf.sign(grad) x_adv = x_adv + epsilon * normed_grad # 评估模型在对抗样本上的性能 model.evaluate(x_adv, y_true)
时间: 2024-03-07 08:48:39 浏览: 146
下面是优化后的代码:
```python
import tensorflow as tf
import numpy as np
# 加载数据集
mnist = tf.keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0
# 定义模型
model = tf.keras.models.Sequential([
tf.keras.layers.Flatten(input_shape=(28, 28)),
tf.keras.layers.Dense(128, activation='relu'),
tf.keras.layers.Dropout(0.2),
tf.keras.layers.Dense(10, activation='softmax')
])
# 编译模型
model.compile(optimizer='adam',
loss='sparse_categorical_crossentropy',
metrics=['accuracy'])
# 训练模型
model.fit(x_train, y_train, epochs=5)
# 生成对抗样本
epsilon = 0.1
x_adv = x_test[:1000]
y_true = y_test[:1000]
# 使用tf.GradientTape计算梯度
def get_grads(model, x, y):
with tf.GradientTape() as tape:
tape.watch(x)
y_pred = model(x)
loss = tf.keras.losses.sparse_categorical_crossentropy(y, y_pred)
return tape.gradient(loss, x)
# 对抗样本生成函数
def generate_adv_example(model, x, y, epsilon):
grad = tf.sign(get_grads(model, x, y))
return x + epsilon * grad
x_adv = generate_adv_example(model, x_adv, y_true, epsilon)
# 评估模型在对抗样本上的性能
model.evaluate(x_adv, y_true)
```
优化后的代码主要做了以下几个方面的改进:
1. 对抗样本生成过程封装成了函数 `generate_adv_example()`,提高了代码的复用性。
2. 使用了 `get_grads()` 函数来计算梯度,避免了重复计算梯度的问题。
3. 去掉了 `tf.Variable` 的使用,避免了不必要的内存开销。
4. 代码风格更加规范,易于阅读。
阅读全文