WARNING:tensorflow:6 out of the last 6 calls to <function Model.make_predict_function.<locals>.predict_function at 0x0000022972E28048> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop, (2) passing tensors with different shapes, (3) passing Python objects instead of tensors. For (1), please define your @tf.function outside of the loop. For (2), @tf.function has experimental_relax_shapes=True option that relaxes argument shapes that can avoid unnecessary retracing. For (3), please refer to https://www.tensorflow.org/guide/function#controlling_retracing and https://www.tensorflow.org/api_docs/python/tf/function for more details.
时间: 2024-04-17 13:28:57 浏览: 156
这是一个警告信息,提示在过去的6次调用中,函数被重复追踪。这可能是因为在循环中重复创建了@tf.function,传递了不同形状的张量,或者传递了Python对象而不是张量。对于情况(1),请将@tf.function定义在循环外部。对于情况(2),@tf.function有一个实验性的experimental_relax_shapes=True选项,可以放宽参数形状以避免不必要的重追踪。对于情况(3),请参考链接中的文档了解更多细节。
相关问题
WARNING:tensorflow:6 out of the last 6 calls to <function Model.make_predict_function.<locals>.predict_function at 0x000001B6B7A85EE0> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop, (2) passing tensors with different shapes, (3) passing Python objects instead of tensors. For (1), please define your @tf.function outside of the loop. For (2), @tf.function has experimental_relax_shapes=True option that relaxes argument shapes that can avoid unnecessary retracing
这个警告通常是由于在循环中定义 `@tf.function` 导致的。当你在循环中定义 `@tf.function` 时,每次迭代都会重新追踪该函数,这会导致重复的追踪,降低了性能。为了避免这个问题,你可以将 `@tf.function` 定义在循环外面,以便只追踪一次。
如果你确信循环内的计算图和循环外的计算图是相同的,你可以通过将 `experimental_relax_shapes=True` 传递给 `@tf.function` 来避免不必要的追踪。这个选项告诉 TensorFlow 在输入形状稍有不同的情况下仍然使用之前追踪的计算图。
下面是一个示例代码,演示了如何在循环中使用 `@tf.function` 并避免追踪警告:
```
import tensorflow as tf
@tf.function(experimental_relax_shapes=True)
def my_func(x):
# 这里定义你的计算图
return x
for i in range(10):
x = tf.ones((2, 2))
y = my_func(x)
```
在这个例子中,`my_func` 是一个使用 `@tf.function` 定义的函数,它将输入 `x` 作为参数并返回 `x`。`for` 循环中多次调用 `my_func`,但由于使用了 `experimental_relax_shapes=True`,所以不会出现追踪警告。
这段代码def slide2(y): l = y.shape[0] y = tf.reshape(y, [1, -1, 1]) input = keras.Input(shape=(l, 1)) output = Conv1D(filters=1, kernel_size=3, padding='causal', activation='linear', trainable=False, use_bias=False)(input) output = tf.stop_gradient(output) # 将 output 视为常数 model = keras.Model(inputs=input, outputs=output) weights_list = model.get_weights() weights = np.ones(3) / 3 weights_list[0] = weights.reshape((3, 1, 1)) model.set_weights(weights_list) result = model.predict(y) result = tf.reshape(result, [-1, 1]) return result用在了循环内部,出现了WARNING:tensorflow:6 out of the last 6 calls to <function Model.make_predict_function.<locals>.predict_function at 0x000001B6B7A85EE0> triggered tf.function retracing.警告
这个警告通常是由于在循环中定义 `keras.Model.predict` 函数时触发的,原因是每次调用 `predict` 函数时都需要重新追踪计算图,这会导致性能下降。为了避免这个问题,你可以尝试将 `keras.Model.predict` 函数的定义放到循环外面。
具体而言,你可以先定义一个空的模型,然后在循环内部多次调用 `model` 的 `__call__` 方法,这样就可以避免重复追踪计算图。下面是修改后的代码示例:
```
def slide2(y):
l = y.shape[0]
y = tf.reshape(y, [1, -1, 1])
input = keras.Input(shape=(l, 1))
output = Conv1D(filters=1, kernel_size=3, padding='causal', activation='linear', trainable=False, use_bias=False)(input)
output = tf.stop_gradient(output) # 将 output 视为常数
model = keras.Model(inputs=input, outputs=output)
weights_list = model.get_weights()
weights = np.ones(3) / 3
weights_list[0] = weights.reshape((3, 1, 1))
model.set_weights(weights_list)
# 在循环外定义 result
result = np.zeros((y.shape[1], 1))
for i in range(y.shape[1]):
x = y[:, :i+1, :]
y_pred = model(x)
result[i, :] = y_pred[-1, 0]
return result
```
在这个例子中,我们在循环外定义了 `result`,并在循环内部多次调用了 `model` 的 `__call__` 方法,最终得到了和原始代码相同的结果。这样就避免了重复追踪计算图的问题,也就不会再出现 `tf.function retracing` 的警告了。
阅读全文