Keras 多输入多输出实验:融合(Merge)层

官方文档虽然给出了多输入多输出的例子([英文][译文]),但作为使用者,我对 Keras 的多输入多输出仍存在以下疑惑:

1. 网络层能否跨层连接(即把较早层的输出直接与较后层的输出合并),从而搭建 Deep Residual Learning(ResNet)中的残差结构?

2. 用 merge 层连接网络时,能否自定义 merge 层?

merge 的子类网络层有 Add、Subtract、Multiply、Average、Maximum、Minimum、Concatenate、Dot 这八个。

merge 的源代码可以在 GitHub 上查看(Keras 仓库中的 keras/layers/merge.py)。

先分析 merge 的父类 `_Merge` 的代码:

  1 class _Merge(Layer):
  2     """Generic merge layer for elementwise merge functions.
  3     Used to implement `Sum`, `Average`, etc.
  4     # Arguments
  5         **kwargs: standard layer keyword arguments.
  6     """
  7 
  8     def __init__(self, **kwargs):
  9         super(_Merge, self).__init__(**kwargs)
 10         self.supports_masking = True
 11 
 12     def _merge_function(self, inputs):
 13         raise NotImplementedError
 14 
 15     def _compute_elemwise_op_output_shape(self, shape1, shape2):
 16         """Computes the shape of the resultant of an elementwise operation.
 17         # Arguments
 18             shape1: tuple or None. Shape of the first tensor
 19             shape2: tuple or None. Shape of the second tensor
 20         # Returns
 21             expected output shape when an element-wise operation is
 22             carried out on 2 tensors with shapes shape1 and shape2.
 23             tuple or None.
 24         # Raises
 25             ValueError: if shape1 and shape2 are not compatible for
 26                 element-wise operations.
 27         """
 28         if None in [shape1, shape2]:
 29             return None
 30         elif len(shape1) < len(shape2):
 31             return self._compute_elemwise_op_output_shape(shape2, shape1)
 32         elif len(shape2) == 0:
 33             return shape1
 34         output_shape = list(shape1[:-len(shape2)])
 35         for i, j in zip(shape1[-len(shape2):], shape2):
 36             if i is None or j is None:
 37                 output_shape.append(None)
 38             elif i == 1:
 39                 output_shape.append(j)
 40             elif j == 1:
 41                 output_shape.append(i)
 42             else:
 43                 if i != j:
 44                     raise ValueError('Operands could not be broadcast '
 45                                      'together with shapes ' +
 46                                      str(shape1) + ' ' + str(shape2))
 47                 output_shape.append(i)
 48         return tuple(output_shape)
 49 
 50     def build(self, input_shape):
 51         # Used purely for shape validation.
 52         if not isinstance(input_shape, list):
 53             raise ValueError('A merge layer should be called '
 54                              'on a list of inputs.')
 55         if len(input_shape) < 2:
 56             raise ValueError('A merge layer should be called '
 57                              'on a list of at least 2 inputs. '
 58                              'Got ' + str(len(input_shape)) + ' inputs.')
 59         batch_sizes = [s[0] for s in input_shape if s is not None]
 60         batch_sizes = set(batch_sizes)
 61         batch_sizes -= set([None])
 62         if len(batch_sizes) > 1:
 63             raise ValueError('Can not merge tensors with different '
 64                              'batch sizes. Got tensors with shapes : ' +
 65                              str(input_shape))
 66         if input_shape[0] is None:
 67             output_shape = None
 68         else:
 69             output_shape = input_shape[0][1:]
 70         for i in range(1, len(input_shape)):
 71             if input_shape[i] is None:
 72                 shape = None
 73             else:
 74                 shape = input_shape[i][1:]
 75             output_shape = self._compute_elemwise_op_output_shape(output_shape, shape)
 76         # If the inputs have different ranks, we have to reshape them
 77         # to make them broadcastable.
 78         if None not in input_shape and len(set(map(len, input_shape))) == 1:
 79             self._reshape_required = False
 80         else:
 81             self._reshape_required = True
 82 
 83     def call(self, inputs):
 84 #返回函数
 85         if self._reshape_required:
 86             reshaped_inputs = []
 87             input_ndims = list(map(K.ndim, inputs))
 88             if None not in input_ndims:
 89                 # If ranks of all inputs are available,
 90                 # we simply expand each of them at axis=1
 91                 # until all of them have the same rank.
 92                 max_ndim = max(input_ndims)
 93                 for x in inputs:
 94                     x_ndim = K.ndim(x)
 95                     for _ in range(max_ndim - x_ndim):
 96                         x = K.expand_dims(x, 1)
 97                     reshaped_inputs.append(x)
 98                 return self._merge_function(reshaped_inputs)
 99             else:
100                 # Transpose all inputs so that batch size is the last dimension.
101                 # (batch_size, dim1, dim2, ... ) -> (dim1, dim2, ... , batch_size)
102                 transposed = False
103                 for x in inputs:
104                     x_ndim = K.ndim(x)
105                     if x_ndim is None:
106                         x_shape = K.shape(x)
107                         batch_size = x_shape[0]
108                         new_shape = K.concatenate([x_shape[1:], K.expand_dims(batch_size)])
109                         x_transposed = K.reshape(x, K.stack([batch_size, K.prod(x_shape[1:])]))
110                         x_transposed = K.permute_dimensions(x_transposed, (1, 0))
111                         x_transposed = K.reshape(x_transposed, new_shape)
112                         reshaped_inputs.append(x_transposed)
113                         transposed = True
114                     elif x_ndim > 1:
115                         dims = list(range(1, x_ndim)) + [0]
116                         reshaped_inputs.append(K.permute_dimensions(x, dims))
117                         transposed = True
118                     else:
119                         # We don't transpose inputs if they are 1D vectors or scalars.
120                         reshaped_inputs.append(x)
121                 y = self._merge_function(reshaped_inputs)
122                 y_ndim = K.ndim(y)
123                 if transposed:
124                     # If inputs have been transposed, we have to transpose the output too.
125                     if y_ndim is None:
126                         y_shape = K.shape(y)
127                         y_ndim = K.shape(y_shape)[0]
128                         batch_size = y_shape[y_ndim - 1]
129                         new_shape = K.concatenate([K.expand_dims(batch_size), y_shape[:y_ndim - 1]])
130                         y = K.reshape(y, (-1, batch_size))
131                         y = K.permute_dimensions(y, (1, 0))
132                         y = K.reshape(y, new_shape)
133                     elif y_ndim > 1:
134                         dims = [y_ndim - 1] + list(range(y_ndim - 1))
135                         y = K.permute_dimensions(y, dims)
136                 return y
137         else:
138             return self._merge_function(inputs)
139 
140     def compute_output_shape(self, input_shape):
141 #返回值的shape设置
142         if input_shape[0] is None:
143             output_shape = None
144         else:
145             output_shape = input_shape[0][1:]
146         for i in range(1, len(input_shape)):
147             if input_shape[i] is None:
148                 shape = None
149             else:
150                 shape = input_shape[i][1:]
151             output_shape = self._compute_elemwise_op_output_shape(output_shape, shape)
152         batch_sizes = [s[0] for s in input_shape if s is not None]
153         batch_sizes = set(batch_sizes)
154         batch_sizes -= set([None])
155         if len(batch_sizes) == 1:
156             output_shape = (list(batch_sizes)[0],) + output_shape
157         else:
158             output_shape = (None,) + output_shape
159         return output_shape
160 
161     def compute_mask(self, inputs, mask=None):
162         if mask is None:
163             return None
164         if not isinstance(mask, list):
165             raise ValueError('`mask` should be a list.')
166         if not isinstance(inputs, list):
167             raise ValueError('`inputs` should be a list.')
168         if len(mask) != len(inputs):
169             raise ValueError('The lists `inputs` and `mask` '
170                              'should have the same length.')
171         if all([m is None for m in mask]):
172             return None
173         masks = [K.expand_dims(m, 0) for m in mask if m is not None]
174         return K.all(K.concatenate(masks, axis=0), axis=0, keepdims=False)

merge 模块中提供的函数式接口(add、subtract 等小写函数)其实就是直接实例化对应的子类层,再把输入列表传给该层调用:

def add(inputs, **kwargs):
    """Sum a list of tensors with a freshly created `Add` layer.

    Functional shortcut for `Add(**kwargs)(inputs)`.

    # Arguments
        inputs: list of at least 2 input tensors.
        **kwargs: standard layer keyword arguments, forwarded to `Add`.

    # Returns
        A tensor holding the element-wise sum of `inputs`.

    # Examples
    ```python
        import keras
        input1 = keras.layers.Input(shape=(16,))
        x1 = keras.layers.Dense(8, activation='relu')(input1)
        input2 = keras.layers.Input(shape=(32,))
        x2 = keras.layers.Dense(8, activation='relu')(input2)
        added = keras.layers.add([x1, x2])
        out = keras.layers.Dense(4)(added)
        model = keras.models.Model(inputs=[input1, input2], outputs=out)
    ```
    """
    layer = Add(**kwargs)
    return layer(inputs)


def subtract(inputs, **kwargs):
    """Compute `inputs[0] - inputs[1]` with a freshly created `Subtract` layer.

    Functional shortcut for `Subtract(**kwargs)(inputs)`.

    # Arguments
        inputs: list of exactly 2 input tensors.
        **kwargs: standard layer keyword arguments, forwarded to `Subtract`.

    # Returns
        A tensor holding the difference of the two inputs.

    # Examples
    ```python
        import keras
        input1 = keras.layers.Input(shape=(16,))
        x1 = keras.layers.Dense(8, activation='relu')(input1)
        input2 = keras.layers.Input(shape=(32,))
        x2 = keras.layers.Dense(8, activation='relu')(input2)
        subtracted = keras.layers.subtract([x1, x2])
        out = keras.layers.Dense(4)(subtracted)
        model = keras.models.Model(inputs=[input1, input2], outputs=out)
    ```
    """
    layer = Subtract(**kwargs)
    return layer(inputs)


def multiply(inputs, **kwargs):
    """Multiply a list of tensors element-wise with a new `Multiply` layer.

    Functional shortcut for `Multiply(**kwargs)(inputs)`.

    # Arguments
        inputs: list of at least 2 input tensors.
        **kwargs: standard layer keyword arguments, forwarded to `Multiply`.

    # Returns
        A tensor holding the element-wise product of `inputs`.
    """
    layer = Multiply(**kwargs)
    return layer(inputs)


def average(inputs, **kwargs):
    """Average a list of tensors element-wise with a new `Average` layer.

    Functional shortcut for `Average(**kwargs)(inputs)`.

    # Arguments
        inputs: list of at least 2 input tensors.
        **kwargs: standard layer keyword arguments, forwarded to `Average`.

    # Returns
        A tensor holding the element-wise mean of `inputs`.
    """
    layer = Average(**kwargs)
    return layer(inputs)


def maximum(inputs, **kwargs):
    """Take the element-wise maximum of tensors with a new `Maximum` layer.

    Functional shortcut for `Maximum(**kwargs)(inputs)`.

    # Arguments
        inputs: list of at least 2 input tensors.
        **kwargs: standard layer keyword arguments, forwarded to `Maximum`.

    # Returns
        A tensor holding the element-wise maximum of `inputs`.
    """
    layer = Maximum(**kwargs)
    return layer(inputs)


def minimum(inputs, **kwargs):
    """Take the element-wise minimum of tensors with a new `Minimum` layer.

    Functional shortcut for `Minimum(**kwargs)(inputs)`.

    # Arguments
        inputs: list of at least 2 input tensors.
        **kwargs: standard layer keyword arguments, forwarded to `Minimum`.

    # Returns
        A tensor holding the element-wise minimum of `inputs`.
    """
    layer = Minimum(**kwargs)
    return layer(inputs)


def concatenate(inputs, axis=-1, **kwargs):
    """Join a list of tensors along one axis with a new `Concatenate` layer.

    Functional shortcut for `Concatenate(axis=axis, **kwargs)(inputs)`.

    # Arguments
        inputs: list of at least 2 input tensors.
        axis: axis along which the inputs are joined (default: the last).
        **kwargs: standard layer keyword arguments, forwarded to the layer.

    # Returns
        A tensor: the inputs concatenated along `axis`.
    """
    layer = Concatenate(axis=axis, **kwargs)
    return layer(inputs)


def dot(inputs, axes, normalize=False, **kwargs):
    """Functional interface to the `Dot` layer.

    Equivalent to `Dot(axes=axes, normalize=normalize, **kwargs)(inputs)`.

    # Arguments
        inputs: A list of exactly 2 input tensors (the `Dot` layer rejects
            any other number of inputs).
        axes: Integer or tuple of integers,
            axis or axes along which to take the dot product.
        normalize: Whether to L2-normalize samples along the
            dot product axis before taking the dot product.
            If set to True, then the output of the dot product
            is the cosine proximity between the two samples.
        **kwargs: Standard layer keyword arguments.
    # Returns
        A tensor, the dot product of the samples from the inputs.
    """
    return Dot(axes=axes, normalize=normalize, **kwargs)(inputs)

简单的子类层只需重写 `_merge_function`,其余方法都继承自父类 `_Merge`。

Add层:

class Add(_Merge):
    """Layer that adds a list of inputs.

    It takes as input a list of tensors,
    all of the same shape, and returns
    a single tensor (also of the same shape).
    Any number of inputs (at least 2) can be added.

    # Examples
    ```python
        import keras
        input1 = keras.layers.Input(shape=(16,))
        x1 = keras.layers.Dense(8, activation='relu')(input1)
        input2 = keras.layers.Input(shape=(32,))
        x2 = keras.layers.Dense(8, activation='relu')(input2)
        added = keras.layers.Add()([x1, x2])  # equivalent to added = keras.layers.add([x1, x2])
        out = keras.layers.Dense(4)(added)
        model = keras.models.Model(inputs=[input1, input2], outputs=out)
    ```
    """

    def _merge_function(self, inputs):
        # Accumulate every input onto the first one. Binding a new result on
        # each step (instead of `output += x`) guarantees that `inputs[0]` is
        # never modified in place, even for array types that implement an
        # in-place `__iadd__` (e.g. NumPy arrays).
        output = inputs[0]
        for x in inputs[1:]:
            output = output + x
        return output

Subtract层:

class Subtract(_Merge):
    """Layer that subtracts two inputs.

    It takes as input a list of tensors of size 2,
    both of the same (or broadcast-compatible) shape, and returns a single
    tensor, (inputs[0] - inputs[1]), also of the same shape.

    # Examples
    ```python
        import keras
        input1 = keras.layers.Input(shape=(16,))
        x1 = keras.layers.Dense(8, activation='relu')(input1)
        input2 = keras.layers.Input(shape=(32,))
        x2 = keras.layers.Dense(8, activation='relu')(input2)
        # Equivalent to subtracted = keras.layers.subtract([x1, x2])
        subtracted = keras.layers.Subtract()([x1, x2])
        out = keras.layers.Dense(4)(subtracted)
        model = keras.models.Model(inputs=[input1, input2], outputs=out)
    ```
    """

    def build(self, input_shape):
        # The parent validates the list, the batch sizes and the broadcast
        # compatibility of the shapes; only the input count is specific here.
        super(Subtract, self).build(input_shape)
        if len(input_shape) != 2:
            raise ValueError('`Subtract` layer should be called '
                             'on exactly 2 inputs')

    def _merge_function(self, inputs):
        # Exactly two inputs: the first minus the second. Shapes are not
        # compared via `_keras_shape` here: shape compatibility is already
        # checked in `build`, and the tensors passed in (e.g. those produced
        # by `K.expand_dims` when input ranks differ) are not guaranteed to
        # carry a `_keras_shape` attribute.
        if len(inputs) != 2:
            raise ValueError('`Subtract` layer should be called '
                             'on exactly 2 inputs')
        return inputs[0] - inputs[1]

Multiply层:

class Multiply(_Merge):
    """Layer that multiplies (element-wise) a list of inputs.

    It takes as input a list of tensors,
    all of the same shape, and returns
    a single tensor (also of the same shape).
    Any number of inputs (at least 2) can be multiplied.
    """

    def _merge_function(self, inputs):
        # Multiply every input into the first one. `output = output * x`
        # (rather than `output *= x`) never modifies `inputs[0]` in place,
        # even for array types that implement an in-place `__imul__`.
        output = inputs[0]
        for x in inputs[1:]:
            output = output * x
        return output

Average 层:对所有输入逐元素求平均值。

Maximum 层:对所有输入逐元素取最大值。

Minimum 层:对所有输入逐元素取最小值。

Concatenate层:

 1 class Concatenate(_Merge):
 2 #由于连接层的复杂性,所以需要自定义,weghts大小,和该层的各个属性。
 3 #根据需要的坐标系,连接网络层
 4     """Layer that concatenates a list of inputs.
 5     It takes as input a list of tensors,
 6     all of the same shape expect for the concatenation axis,
 7     and returns a single tensor, the concatenation of all inputs.
 8     # Arguments
 9         axis: Axis along which to concatenate.
10         **kwargs: standard layer keyword arguments.
11     """
12 
13     def __init__(self, axis=-1, **kwargs):
14         super(Concatenate, self).__init__(**kwargs)
15         self.axis = axis
16         self.supports_masking = True
17 
18     def build(self, input_shape):
19         # Used purely for shape validation.
20         if not isinstance(input_shape, list):
21             raise ValueError('`Concatenate` layer should be called '
22                              'on a list of inputs')
23         if all([shape is None for shape in input_shape]):
24             return
25         reduced_inputs_shapes = [list(shape) for shape in input_shape]
26         shape_set = set()
27         for i in range(len(reduced_inputs_shapes)):
28             del reduced_inputs_shapes[i][self.axis]
29             shape_set.add(tuple(reduced_inputs_shapes[i]))
30         if len(shape_set) > 1:
31             raise ValueError('`Concatenate` layer requires '
32                              'inputs with matching shapes '
33                              'except for the concat axis. '
34                              'Got inputs shapes: %s' % (input_shape))
35 #tensorflow代码实现返回
36     def call(self, inputs):
37         if not isinstance(inputs, list):
38             raise ValueError('A `Concatenate` layer should be called '
39                              'on a list of inputs.')
40         return K.concatenate(inputs, axis=self.axis)
41 #设置该层输出值的shape大小
42     def compute_output_shape(self, input_shape):
43         if not isinstance(input_shape, list):
44             raise ValueError('A `Concatenate` layer should be called '
45                              'on a list of inputs.')
46         input_shapes = input_shape
47         output_shape = list(input_shapes[0])
48         for shape in input_shapes[1:]:
49             if output_shape[self.axis] is None or shape[self.axis] is None:
50                 output_shape[self.axis] = None
51                 break
52             output_shape[self.axis] += shape[self.axis]
53         return tuple(output_shape)
54 #有无mask元素(屏蔽元素)
55     def compute_mask(self, inputs, mask=None):
56         if mask is None:
57             return None
58         if not isinstance(mask, list):
59             raise ValueError('`mask` should be a list.')
60         if not isinstance(inputs, list):
61             raise ValueError('`inputs` should be a list.')
62         if len(mask) != len(inputs):
63             raise ValueError('The lists `inputs` and `mask` '
64                              'should have the same length.')
65         if all([m is None for m in mask]):
66             return None
67         # Make a list of masks while making sure
68         # the dimensionality of each mask
69         # is the same as the corresponding input.
70         masks = []
71         for input_i, mask_i in zip(inputs, mask):
72             if mask_i is None:
73                 # Input is unmasked. Append all 1s to masks,
74                 # but cast it to bool first
75                 masks.append(K.cast(K.ones_like(input_i), 'bool'))
76             elif K.ndim(mask_i) < K.ndim(input_i):
77                 # Mask is smaller than the input, expand it
78                 masks.append(K.expand_dims(mask_i))
79             else:
80                 masks.append(mask_i)
81         concatenated = K.concatenate(masks, axis=self.axis)
82         return K.all(concatenated, axis=-1, keepdims=False)
83 
84     def get_config(self):
85         config = {
86             'axis': self.axis,
87         }
88 #super申明使用父类设置
89         base_config = super(Concatenate, self).get_config()
90         return dict(list(base_config.items()) + list(config.items()))

Dot 层:逐样本计算两个输入张量的点积(内积,而不是向量积/叉积),输入必须恰好为 2 个:

  1 class Dot(_Merge):
  2     """Layer that computes a dot product between samples in two tensors.
  3     E.g. if applied to two tensors `a` and `b` of shape `(batch_size, n)`,
  4     the output will be a tensor of shape `(batch_size, 1)`
  5     where each entry `i` will be the dot product between
  6     `a[i]` and `b[i]`.
  7     # Arguments
  8         axes: Integer or tuple of integers,
  9             axis or axes along which to take the dot product.
 10         normalize: Whether to L2-normalize samples along the
 11             dot product axis before taking the dot product.
 12             If set to True, then the output of the dot product
 13             is the cosine proximity between the two samples.
 14         **kwargs: Standard layer keyword arguments.
 15     """
 16 
 17     def __init__(self, axes, normalize=False, **kwargs):
 18         super(Dot, self).__init__(**kwargs)
 19         if not isinstance(axes, int):
 20             if not isinstance(axes, (list, tuple)):
 21                 raise TypeError('Invalid type for `axes` - '
 22                                 'should be a list or an int.')
 23             if len(axes) != 2:
 24                 raise ValueError('Invalid format for `axes` - '
 25                                  'should contain two elements.')
 26             if not isinstance(axes[0], int) or not isinstance(axes[1], int):
 27                 raise ValueError('Invalid format for `axes` - '
 28                                  'list elements should be "int".')
 29         self.axes = axes
 30         self.normalize = normalize
 31         self.supports_masking = True
 32 
 33     def build(self, input_shape):
 34         # Used purely for shape validation.
 35         if not isinstance(input_shape, list) or len(input_shape) != 2:
 36             raise ValueError('A `Dot` layer should be called '
 37                              'on a list of 2 inputs.')
 38         shape1 = input_shape[0]
 39         shape2 = input_shape[1]
 40         if shape1 is None or shape2 is None:
 41             return
 42         if isinstance(self.axes, int):
 43             if self.axes < 0:
 44                 axes = [self.axes % len(shape1), self.axes % len(shape2)]
 45             else:
 46                 axes = [self.axes] * 2
 47         else:
 48             axes = self.axes
 49         if shape1[axes[0]] != shape2[axes[1]]:
 50             raise ValueError(
 51                 'Dimension incompatibility '
 52                 '%s != %s. ' % (shape1[axes[0]], shape2[axes[1]]) +
 53                 'Layer shapes: %s, %s' % (shape1, shape2))
 54 #实现向量积,操作,根据axis,进行操作,具体操作语句为k.batch_dot(x1,x2)
 55     def call(self, inputs):
 56         x1 = inputs[0]
 57         x2 = inputs[1]
 58         if isinstance(self.axes, int):
 59             if self.axes < 0:
 60                 axes = [self.axes % K.ndim(x1), self.axes % K.ndim(x2)]
 61             else:
 62                 axes = [self.axes] * 2
 63         else:
 64             axes = []
 65             for i in range(len(self.axes)):
 66                 if self.axes[i] < 0:
 67                     axes.append(self.axes[i] % K.ndim(inputs[i]))
 68                 else:
 69                     axes.append(self.axes[i])
 70         if self.normalize:
 71             x1 = K.l2_normalize(x1, axis=axes[0])
 72             x2 = K.l2_normalize(x2, axis=axes[1])
 73         output = K.batch_dot(x1, x2, axes)
 74         return output
 75 
 76     def compute_output_shape(self, input_shape):
 77         if not isinstance(input_shape, list) or len(input_shape) != 2:
 78             raise ValueError('A `Dot` layer should be called '
 79                              'on a list of 2 inputs.')
 80         shape1 = list(input_shape[0])
 81         shape2 = list(input_shape[1])
 82         if isinstance(self.axes, int):
 83             if self.axes < 0:
 84                 axes = [self.axes % len(shape1), self.axes % len(shape2)]
 85             else:
 86                 axes = [self.axes] * 2
 87         else:
 88             axes = self.axes
 89         shape1.pop(axes[0])
 90         shape2.pop(axes[1])
 91         shape2.pop(0)
 92         output_shape = shape1 + shape2
 93         if len(output_shape) == 1:
 94             output_shape += [1]
 95         return tuple(output_shape)
 96 
 97     def compute_mask(self, inputs, mask=None):
 98         return None
 99 
100     def get_config(self):
101         config = {
102             'axes': self.axes,
103             'normalize': self.normalize,
104         }
105         base_config = super(Dot, self).get_config()
106         return dict(list(base_config.items()) + list(config.items()))

了解了各个融合层的实现原理之后,就可以仿照它们自定义融合层: