Code Series: A Walkthrough of keras.layers.Dot()

The Dot class lives in Keras's merge layers; see the Keras Chinese documentation: https://keras-cn.readthedocs.io/en/latest/layers/merge/

The merge layers provide a family of layer objects and functions for fusing two layers or tensors. Names starting with an uppercase letter are Layer classes; names starting with a lowercase letter are functions that operate on tensors, and each lowercase function internally calls the corresponding uppercase layer class. The Dot class source code is here:

https://github.com/keras-team/keras/blob/61052bc1f1c141c5dba9f83a4af14322ec4e6d7c/keras/layers/merge.py#L494

First, note one thing: inside a Keras Model, tensor operations do not consider the batch dimension; everything is defined per sample. For example, an input tensor is declared as input = Input(shape=(H, W, C)) with no batch_size.
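A quick illustration of this (a minimal sketch, assuming Keras with the TensorFlow backend; the shape values are made up): Input() only declares the per-sample shape, and Keras prepends an unspecified batch dimension itself.

```
from keras import backend as K
from keras.layers import Input

x = Input(shape=(32, 32, 3))   # per-sample shape (H, W, C); no batch_size here
print(K.int_shape(x))          # (None, 32, 32, 3) -- the batch dim is None
```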

First, look at the class definition: the Dot layer computes the dot product between samples in two tensors. So for two batched tensors a and b of shape (batch_size, n), the output of the Dot layer has shape (batch_size, 1); that is, the dot product is taken between corresponding samples in the batch (the two length-n vectors a[i] and b[i]).
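To make that concrete, here is a minimal sketch (the tensor names and values are invented for illustration) that feeds two (batch_size, n) inputs through a Dot layer:

```
import numpy as np
from keras.layers import Input, Dot
from keras.models import Model

n = 4
a = Input(shape=(n,))            # per-sample shape; the batch is implicit
b = Input(shape=(n,))
out = Dot(axes=1)([a, b])        # dot product along the feature axis

model = Model(inputs=[a, b], outputs=out)

x1 = np.arange(8, dtype='float32').reshape(2, 4)  # batch of 2 samples
x2 = np.ones((2, 4), dtype='float32')
print(model.predict([x1, x2]))   # shape (2, 1): [[6.], [22.]]
```

Each output entry is a[i] dotted with b[i]: 0+1+2+3 = 6 for the first sample and 4+5+6+7 = 22 for the second.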

class Dot(_Merge):
    """Layer that computes a dot product between samples in two tensors.

    E.g. if applied to a list of two tensors `a` and `b` of shape `(batch_size, n)`,
    the output will be a tensor of shape `(batch_size, 1)`
    where each entry `i` will be the dot product between
    `a[i]` and `b[i]`.

Now the constructor: the argument axes is the axis (or axes) along which the dot product is taken. It can be a single integer (both samples are multiplied along the same axis) or a tuple/list of integers (the two samples are multiplied along different axes). normalize controls whether the samples are L2-normalized along axes before the dot product; if normalization is applied (i.e. each sample's vector along that axis is scaled to a unit vector, so its squared elements sum to 1), the result is the cosine similarity between the two samples (see the sketch after the constructor code below). The constructor validates axes: it must be an int, or a tuple/list of exactly two ints.

    # Arguments
        axes: Integer or tuple of integers,
            axis or axes along which to take the dot product.
        normalize: Whether to L2-normalize samples along the
            dot product axis before taking the dot product.
            If set to True, then the output of the dot product
            is the cosine proximity between the two samples.
        **kwargs: Standard layer keyword arguments.
    """

    def __init__(self, axes, normalize=False, **kwargs):
        super(Dot, self).__init__(**kwargs)
        if not isinstance(axes, int):
            if not isinstance(axes, (list, tuple)):
                raise TypeError('Invalid type for `axes` - '
                                'should be a list or an int.')
            if len(axes) != 2:
                raise ValueError('Invalid format for `axes` - '
                                 'should contain two elements.')
            if not isinstance(axes[0], int) or not isinstance(axes[1], int):
                raise ValueError('Invalid format for `axes` - '
                                 'list elements should be "int".')
        self.axes = axes
        self.normalize = normalize
        self.supports_masking = True
        self._reshape_required = False
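To see normalize=True in action, here is a minimal sketch (with made-up vectors): after L2-normalizing each sample along the dot axis, the output is the cosine similarity of each pair of samples.

```
import numpy as np
from keras.layers import Input, Dot
from keras.models import Model

a = Input(shape=(3,))
b = Input(shape=(3,))
cos = Dot(axes=1, normalize=True)([a, b])   # cosine similarity per pair
model = Model(inputs=[a, b], outputs=cos)

x1 = np.array([[1., 0., 0.]])
x2 = np.array([[1., 1., 0.]])
print(model.predict([x1, x2]))   # ~[[0.7071]] = cos(45 degrees)

# Dot(axes='1')  # would raise TypeError('Invalid type for `axes` ...')
```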

Next, build(), which is used purely for shape validation: it first ensures the input is a list of exactly two tensors, then converts axes to a list of length 2 and checks that the two tensors have the same length along the axes to be dotted (otherwise the dot product is impossible; a demo of the error follows the code). The code:

    def build(self, input_shape):
        # Used purely for shape validation.
        if not isinstance(input_shape, list) or len(input_shape) != 2:
            raise ValueError('A `Dot` layer should be called '
                             'on a list of 2 inputs.')
        shape1 = input_shape[0]
        shape2 = input_shape[1]
        if shape1 is None or shape2 is None:
            return
        if isinstance(self.axes, int):
            if self.axes < 0:
                axes = [self.axes % len(shape1), self.axes % len(shape2)]
            else:
                axes = [self.axes] * 2
        else:
            axes = self.axes
        if shape1[axes[0]] != shape2[axes[1]]:
            raise ValueError(
                'Dimension incompatibility '
                '%s != %s. ' % (shape1[axes[0]], shape2[axes[1]]) +
                'Layer shapes: %s, %s' % (shape1, shape2))
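A minimal sketch (hypothetical shapes) of the validation above: calling the layer on two inputs whose lengths differ along the dot axes triggers the ValueError from build().

```
from keras.layers import Input, Dot

Dot(axes=1)([Input(shape=(5,)), Input(shape=(5,))])      # fine: 5 == 5

try:
    Dot(axes=1)([Input(shape=(5,)), Input(shape=(6,))])  # 5 != 6
except ValueError as e:
    print(e)   # Dimension incompatibility 5 != 6. Layer shapes: ...
```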

Then _merge_function(); its key line is output = K.batch_dot(x1, x2, axes). Let's look at batch_dot() in keras/backend/tensorflow_backend.py, focusing on the docstring: a batchwise dot product.

def batch_dot(x, y, axes=None):
    """Batchwise dot product.

    `batch_dot` is used to compute dot product of `x` and `y` when
    `x` and `y` are data in batches, i.e. in a shape of
    `(batch_size, :)`.
    `batch_dot` results in a tensor or variable with less dimensions
    than the input. If the number of dimensions is reduced to 1,
    we use `expand_dims` to make sure that ndim is at least 2.

    # Arguments
        x: Keras tensor or variable with `ndim >= 2`.
        y: Keras tensor or variable with `ndim >= 2`.
        axes: int or tuple(int, int). Target dimensions to be reduced.

    # Returns
        A tensor with shape equal to the concatenation of `x`'s shape
        (less the dimension that was summed over) and `y`'s shape
        (less the batch dimension and the dimension that was summed over).
        If the final rank is 1, we reshape it to `(batch_size, 1)`.

    # Examples
        Assume `x = [[1, 2], [3, 4]]` and `y = [[5, 6], [7, 8]]`
        `batch_dot(x, y, axes=1) = [[17], [53]]` which is the main diagonal
        of `x.dot(y.T)`, although we never have to calculate the off-diagonal
        elements.

        Pseudocode:
        ```
        inner_products = []
        for xi, yi in zip(x, y):
            inner_products.append(xi.dot(yi))
        result = stack(inner_products)
        ```

        Shape inference:
        Let `x`'s shape be `(100, 20)` and `y`'s shape be `(100, 30, 20)`.
        If `axes` is (1, 2), to find the output shape of resultant tensor,
            loop through each dimension in `x`'s shape and `y`'s shape:

        * `x.shape[0]` : 100 : append to output shape
        * `x.shape[1]` : 20 : do not append to output shape,
            dimension 1 of `x` has been summed over. (`dot_axes[0]` = 1)
        * `y.shape[0]` : 100 : do not append to output shape,
            always ignore first dimension of `y`
        * `y.shape[1]` : 30 : append to output shape
        * `y.shape[2]` : 20 : do not append to output shape,
            dimension 2 of `y` has been summed over. (`dot_axes[1]` = 2)
        `output_shape` = `(100, 30)`

To summarize the docstring: the inputs are data in batches, the dot product is taken between corresponding samples within the batch, and the output has fewer dimensions than the inputs; if the result would be reduced to 1 dimension, expand_dims is applied so that ndim is at least 2. The input tensors must themselves have at least 2 dimensions. axes is an int or a tuple (the Dot layer already normalized it earlier); it names the dimensions to be summed over, i.e. the dimensions that disappear after the computation. The output shape is the concatenation of the two input shapes, with x's axes dimension removed, and y's batch dimension and axes dimension removed. Take the example below: a.shape = (100, 20), b.shape = (100, 30, 20), axes = (1, 2), so dimension 1 of a is dotted with dimension 2 of b, and the output tensor has output.shape = (100, 30). The computation goes as follows:

First, output.shape starts with a.shape's first element, 100 (so output.shape = (100,)). The second element, 20, lies on axes[0] = 1, so it disappears after the dot product and is not appended; that exhausts tensor a. For b.shape, the first element, 100, is the batch dimension and is always skipped (output.shape is still (100,)); the second element, 30, is concatenated onto output.shape (now (100, 30)); and the third element, 20, lies on axes[1] = 2, so it also disappears and is not appended. The final output.shape is (100, 30).
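The same shape inference can be checked directly against the backend (a minimal sketch reusing the shapes from the example above):

```
import numpy as np
from keras import backend as K

x = K.variable(np.random.random((100, 20)))
y = K.variable(np.random.random((100, 30, 20)))
out = K.batch_dot(x, y, axes=(1, 2))
print(K.int_shape(out))   # (100, 30)
```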

That settles the shape question, but how is the computation actually carried out? From the # Examples section of the docstring:

batch_dot(x, y) behaves like x.dot(y.T), but only the main diagonal is computed. Unlike ordinary matrix multiplication, where a[i] is multiplied with every row of b, here a[i] is dotted only with b[i]. As the pseudocode above shows, x and y are zipped together, the corresponding vectors along axes are taken out and dotted: x[i] is multiplied only with y[i], never with y[j] for j != i.
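A minimal sketch reproducing the docstring's numbers: batch_dot computes only the pairwise products x[i].dot(y[i]), which equal the main diagonal of x.dot(y.T).

```
import numpy as np
from keras import backend as K

x = np.array([[1., 2.], [3., 4.]])
y = np.array([[5., 6.], [7., 8.]])

out = K.eval(K.batch_dot(K.variable(x), K.variable(y), axes=1))
print(out)                                   # [[17.], [53.]]
print(np.diag(x.dot(y.T)).reshape(-1, 1))    # same values, off-diagonals never needed
```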

