class MyDense(keras.layers.Layer): def __init__(self, units, activation=None, **kwargs): super().__init__(**kwargs) self.units = units self.activation = keras.activations.get(activation) def build(self, batch_input_shape): self.kernel = self.add_weight( name="kernel", shape=[batch_input_shape[-1], self.units], initializer="glorot_normal") self.bias = self.add_weight( name="bias", shape=[self.units], initializer="zeros") super().build(batch_input_shape) # must be at the end def call(self, X): return self.activation(X @ self.kernel + self.bias) def compute_output_shape(self, batch_input_shape): return tf.TensorShape(batch_input_shape.as_list()[:-1] + [self.units]) def get_config(self): base_config = super().get_config() return {**base_config, "units": self.units, "activation": keras.activations.serialize(self.activation)}
时间: 2024-04-06 15:33:02 浏览: 150
tf.keras.layers.Layer自定义层
这段代码定义了一个自定义的Keras层,名为MyDense。该层继承了Keras的Layer类,重写了__init__、build、call等方法。
在__init__方法中,定义了该层的参数units(输出维度)和激活函数activation,同时调用了父类的__init__方法。
在build方法中,定义了该层的权重kernel和偏置bias,权重的初始化方式为glorot_normal,偏置的初始化方式为zeros。同时调用了父类的build方法。
在call方法中,实现了该层的前向传播,即输入X与权重kernel相乘并加上偏置bias,再通过激活函数activation进行激活得到输出。
在compute_output_shape方法中,返回该层的输出shape。
在get_config方法中,返回该层的配置信息,包括units和activation等参数。
阅读全文