,一tensorflow2.0 - 自定义layer

最近在用tensorflow2.0搭建一个简单的神经网络,虽然结构简单但是由于对自定义有要求,官方提供的layer和model不能满足要求,因此需要自行对layer、model、loss function进行自定义。由于tensorflow2.0发布不久,国内相关文章较少,我便决定写上这一系列文章。本文讨论tensorflow2.0中如何自定义layer。

(一)tensorflow2.0 - 自定义layer

(二)tensorflow2.0 - 自定义Model

(三)tensorflow2.0 - 自定义loss function(损失函数)

(四)tensorflow2.0 - 实战稀疏自动编码器SAE


本文不讨论tensorflow1和2在版本上自定义layer的区别,只讲述2.0版本下如何自定义layer。

本文架构上不做长篇大论,直接根据代码来解释如何自定义模型。

首先引入相应的库函数

import tensorflow as tf
from tensorflow.keras import *

然后自定义Layer类,这里命名为SAELayer,继承自tensorflow.keras.layers.Layer,由于上面引入的库函数为from tensorflow.keras import *,所以写起来就比较清爽,可以直接简写为layers.Layer,之后的都如此,写法上tensorflow.keras都省略了,就不做多解释。

需要注意,Layer类中涉及到了三个重要的方法,分别是__init__()build()call(),关于他们的关系与作用请看我的另一篇文章(tensorflow2.0中Layer的__init__(),build(), call()函数)。这里只简单说明,__init__()函数在创建Layer对象时调用,build在第一次调用call前调用(只调用一次),往后使用Layer的方法都是使用call()的方法。

需要注意build()方法的参数,该方法是被自动调用的,所以其参数是固定的(当然改形参名称没关系),但是不能添加或者删除参数。而call()方法的官方定义为Layer.call(inputs, **kwargs),因此它至少需要一个input作为参数(输入该层的数据),其他参数可以按需自定义

下例为进行一个简单的sigmoid(w*x + b)的功能的自定义层,当然这是一次对一批数据进行操作,所以需要用矩阵(张量)的方式来思考。

class SAELayer(layers.Layer):
        # 初始化num_outputs,即当前层输出元素的个数
    def __init__(self, num_outputs):
        super(SAELayer, self).__init__()
        self.num_outputs = num_outputs
        
        # 在第一次调用该Layer的call方法前(自动)调用该函数,可以知道输入数据的shape
        # 根据输入数据的shape可以初始化权值、bias的矩阵
    def build(self, input_shape):
        self.kernel = self.add_variable("kernel",
                                        shape=[int(input_shape[-1]),
                                               self.num_outputs])
        self.bias = self.add_variable("bias",
                                      shape=[self.num_outputs])
    def call(self, input):
        output = tf.matmul(input, self.kernel) + self.bias
        # sigmoid激活函数
        output = tf.nn.sigmoid(output)
        return output

到此Layer就定义好了,大家可以根据需要对其各部分进行修改,比如在build()中增删参数、在call()中更改计算方式、激活函数等等。

Layer定义好了,如何使用呢?

大可以按照正常使用其他Layer的方式来调用,如果想看具体实例,可以看下一篇文章,里面将Layer放入了一个简单的自定义Model中进行使用。

(二)tensorflow2.0 - 自定义Model


参考文献: