The OpenNMT parameter max_generator_batches

This post is about one onmt parameter:

【max_generator_batches】

 

The default value of this parameter is 32:

(The snippet below shows the default setting in opts.py, the options file of the OpenNMT open-source code.)

The help text says:

max_generator_batches is the maximum number of words in a sequence for which the generator is run in parallel. Higher is faster, but uses more memory. Set it to 0 to disable sharding.

I was a bit puzzled the first time I read this. After double-checking the code, my working interpretation is: when computing the loss for an output sequence, the generator is not applied to the whole sequence in one go; instead the sequence is handled in slices of up to 32 words each (see the sketch just after the snippet below).

    group.add('--max_generator_batches', '-max_generator_batches',
              type=int, default=32,
              help="Maximum batches of words in a sequence to run "
                   "the generator on in parallel. Higher is faster, but "
                   "uses more memory. Set to 0 to disable.")

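Concretely, a minimal sketch with assumed toy shapes (not from the post): the decoder output for one batch has shape [tgt_len, batch, hidden], and sharding slices it along the target-length axis into pieces of at most 32 time steps before the generator and loss are applied to each piece in turn.

import torch

# Assumed toy shapes: 100 target words, batch of 8, hidden size 512.
decoder_output = torch.randn(100, 8, 512)

for shard in torch.split(decoder_output, 32):  # 32 = max_generator_batches
    print(shard.shape)
# prints torch.Size([32, 8, 512]) three times, then torch.Size([4, 8, 512])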
 

The code behind this parameter:

In train.py:

【shard_size is assigned opt.max_generator_batches = 32 (and forced to 0 when model_dtype is not 'fp32')】

train_loss = onmt.utils.loss.build_loss_compute(model, tgt_field, opt)

shard_size = opt.max_generator_batches if opt.model_dtype == 'fp32' else 0
loss, batch_stats = self.train_loss( 
                        batch,
                        outputs,
                        attns,
                        normalization=normalization,
                        shard_size=self.shard_size,
                        trunc_start=j,
                        trunc_size=trunc_size) 
# When shard_size != 0, loss is None: the backward pass already happened inside the loss function.  ***!!!
# When shard_size == 0, loss has a value and is backpropagated by the code below.  ***!!!
if loss is not None:
    self.optim.backward(loss) 

In loss.py:

The following is build_loss_compute, which constructs the train_loss object used above:

def build_loss_compute(model, tgt_field, opt, train=True):
    # ... (some code omitted here)
    if opt.copy_attn: 
        # my run takes this branch
        compute = onmt.modules.CopyGeneratorLossCompute(
            criterion, loss_gen, tgt_field.vocab, opt.copy_loss_by_seqlength,
            lambda_coverage=opt.lambda_coverage
        )  # (this is the __init__ call of CopyGeneratorLossCompute)
    else:
        compute = NMTLossCompute(
            criterion, loss_gen, lambda_coverage=opt.lambda_coverage)
    compute.to(device)
    return compute

In copy_generator.py:

The loss is returned by the _compute_loss method in the code below.

Its inheritance chain is CopyGeneratorLossCompute -> NMTLossCompute -> LossComputeBase.

class CopyGeneratorLossCompute(NMTLossCompute):
    """Copy Generator Loss Computation."""
    def _compute_loss(self, batch, output, target, copy_attn, align,
                      std_attn=None, coverage_attn=None):
        """Compute the loss.
        The args must match :func:`self._make_shard_state()`.
        Args:
            batch: the current batch.
            output: the predict output from the model.
            target: the validate target to compare output with.
            copy_attn: the copy attention value.
            align: the align info.
        """
        scores = self.generator(self._bottle(output), self._bottle(copy_attn), batch.src_map)
        loss = self.criterion(scores, align, target)
        # this block does not depend on the loss value computed above
        # and is used only for stats
        scores_data = collapse_copy_scores(
            self._unbottle(scores.clone(), batch.batch_size),
            batch, self.tgt_vocab, None)
        scores_data = self._bottle(scores_data)
        # Correct target copy token instead of <unk>
        # tgt[i] = align[i] + len(tgt_vocab)
        # for i such that tgt[i] == 0 and align[i] != 0
        target_data = target.clone()
        unk = self.criterion.unk_index
        correct_mask = (target_data == unk) & (align != unk)
        offset_align = align[correct_mask] + len(self.tgt_vocab)
        target_data[correct_mask] += offset_align

        # Compute sum of perplexities for stats
        stats = self._stats(loss.sum().clone(), scores_data, target_data)
        loss = loss.sum()
        return loss, stats
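
The <unk>-correction block above (tgt[i] = align[i] + len(tgt_vocab)) can be made concrete with a toy run using made-up numbers (assumed unk index 0 and a target vocabulary of size 5); this is only an illustration, not OpenNMT code:

import torch

unk = 0
tgt_vocab_len = 5  # copied source words are scored at index len(tgt_vocab) + align[i]

target = torch.tensor([3, 0, 0, 2])  # positions 1 and 2 are <unk>
align = torch.tensor([0, 4, 0, 1])   # position 1 was copied from source word 4

target_data = target.clone()
correct_mask = (target_data == unk) & (align != unk)  # only position 1
offset_align = align[correct_mask] + tgt_vocab_len    # 4 + 5 = 9
target_data[correct_mask] += offset_align             # 0 -> 9

print(target_data)  # tensor([3, 9, 0, 2]): the copied token now points at its extended-vocab score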

But what calls this function?

Next, look at the base class LossComputeBase in loss.py.

It is easy to see that the _compute_loss method above is invoked from the base class's __call__ method.

That is, the main shard-related logic lives here:

class LossComputeBase(nn.Module):
    """
    Handles sharding next step predictions and accumulating multiple loss computations
    Users can implement their own loss computation strategy by making subclass of this one.  
    Users need to implement the _compute_loss() and _make_shard_state() methods.
    """
    def __call__(self, batch, output, attns, normalization=1.0, shard_size=0, trunc_start=0, trunc_size=None):  # shard_size defaults to 0
        """Compute the forward loss, possibly in shards in which case this
        method also runs the backward pass and returns ``None`` as the loss value.
 
        Note sharding is an exact efficiency trick to relieve memory required for the generation buffers. 
        Truncation is an approximate efficiency trick to relieve the memory required in the RNN buffers.
        Args:
          batch (batch) : batch of labeled examples
          output (:obj:`FloatTensor`) :output of decoder `[tgt_len x batch x hidden]`
          attns (dict) : `[tgt_len x batch x src_len]`
          shard_size (int) : maximum number of examples in a shard
        Returns:
            A tuple with the loss and a :obj:`onmt.utils.Statistics` instance.
        """
        if trunc_size is None:
            trunc_size = batch.tgt.size(0) - trunc_start
        trunc_range = (trunc_start, trunc_start + trunc_size) 
        shard_state = self._make_shard_state(batch, output, trunc_range, attns)

        if shard_size == 0:  # when 0, the loss value is computed and returned
            loss, stats = self._compute_loss(batch, **shard_state)
            return loss / float(normalization), stats

        # shard_size != 0: each shard's loss is backpropagated right here, and no loss value is returned
        batch_stats = onmt.utils.Statistics() 
        for shard in shards(shard_state, shard_size):
            loss, stats = self._compute_loss(batch, **shard)
            loss.div(float(normalization)).backward() # backward 2/2 -------------- loss 2/2 --------------
            batch_stats.update(stats)  
        return None, batch_stats
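
The docstring's remark about the generation buffers can be made concrete with a rough back-of-the-envelope calculation (the numbers below are illustrative assumptions, not from the post): the generator projects every decoder state onto the full target vocabulary, so the logits buffer dominates memory, and sharding caps how much of it is alive at once.

# Illustrative numbers: vocab 50k, batch 64, tgt_len 100, fp32 (4 bytes per value).
vocab, batch, tgt_len, shard_size, bytes_per_val = 50_000, 64, 100, 32, 4

full_logits = tgt_len * batch * vocab * bytes_per_val     # all time steps at once
shard_logits = shard_size * batch * vocab * bytes_per_val  # one shard at a time

print(f"{full_logits / 2**30:.2f} GiB")   # ~1.19 GiB
print(f"{shard_logits / 2**30:.2f} GiB")  # ~0.38 GiB peak for that buffer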

From LossComputeBase.__call__ above we can see the two paths.

When shard_size == 0, there is only a single loss.backward().

Looking back at the first train.py snippet in this post:

if loss is not None:
    self.optim.backward(loss) 

When shard_size == 0, the only loss.backward() happens in train.py, which is the form we are used to: 【first compute the loss value with a function, then call loss.backward()】.

 

When shard_size != 0, the code goes through the shards function and ends up running backward twice.

def shards(state, shard_size, eval_only=False):
    """
    Args:
        shard_size: The maximum size of the shards yielded by the model.
        eval_only: If True, only yield the state, nothing else.
              Otherwise, yield shards.
    Yields:
        Each yielded shard is a dict.
    Side effect:
        After the last shard, this function does back-propagation.
    """
    if eval_only:
        yield filter_shard_state(state)
    else:
        non_none = dict(filter_shard_state(state, shard_size))  # non_none: the subdict of `state` whose values are not None
        # non_none is a dict of sequences of tensor-likes, but we need a
        # sequence of dicts of tensors. First, unzip the dict into a sequence
        # of keys and a sequence of tensor-like sequences.
        keys, values = zip(*((k, [v_chunk for v_chunk in v_split])
                             for k, (_, v_split) in non_none.items()))
        # Now, yield a dictionary for each shard. The keys are always the same.
        # values is a sequence of length #keys
        # where each element is a sequence of length #shards.
        # We want to iterate over shards, not keys: therefore, the values need to
        # be re-zipped by shard, so that each shard can be paired with the keys.
        for shard_tensors in zip(*values):
            yield dict(zip(keys, shard_tensors))
        # Assumed backprop'd
        variables = []
        for k, (v, v_split) in non_none.items():
            if isinstance(v, torch.Tensor) and state[k].requires_grad: 
                variables.extend(zip(torch.split(state[k], shard_size), [v_chunk.grad for v_chunk in v_split]))
        inputs, grads = zip(*variables) # inputs : tuple 
        torch.autograd.backward(inputs, grads) # backward 1/2 -------------- loss 1/2 -------------- 

Inside the shards function, yield hands each slice of the data back to __call__, and after the last shard has been consumed, shards itself calls torch.autograd.backward(inputs, grads) over every tensor in state that requires grad.

In __call__, meanwhile, the loss computed on each shard gets its own loss.div(float(normalization)).backward().

So when shard_size != 0, the backward pass is indeed performed in two stages.

The shards function relies on the helper below, which iterates over the (k, v) pairs of the input state and splits each tensor value v into chunks:

def filter_shard_state(state, shard_size=None):
    for k, v in state.items():
        if shard_size is None:
            yield k, v

        if v is not None:
            v_split = []
            if isinstance(v, torch.Tensor):
                for v_chunk in torch.split(v, shard_size):
                    v_chunk = v_chunk.data.clone()
                    v_chunk.requires_grad = v.requires_grad
                    v_split.append(v_chunk)
            yield k, (v, v_split)
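
As a small usage illustration (assumed toy shapes, not OpenNMT's own code), this is what filter_shard_state yields for such a state dict:

import torch

state = {
    "output": torch.randn(10, 2, 8, requires_grad=True),  # [tgt_len, batch, hidden]
    "target": torch.randint(0, 100, (10, 2)),
}

filtered = dict(filter_shard_state(state, shard_size=4))
v, v_split = filtered["output"]
print([chunk.shape[0] for chunk in v_split])         # [4, 4, 2]
print(v_split[0].requires_grad, v_split[0].grad_fn)  # True None: detached leaf copies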

Why are two backward passes needed?

Because of this line: v_chunk = v_chunk.data.clone()

.data: returns the Tensor wrapped inside a Variable, detached from the computation graph; differentiating through it silently gives the wrong result (0) instead of raising an error.

.clone(): makes an identical copy of a Tensor that stays in the computation graph (it is not detached).

filter_shard_state clones the .data of the loss function's inputs, so the gradients from loss.backward() only propagate back as far as those cloned inputs; they stop there instead of flowing through the whole computation graph.

torch.autograd.backward then uses the gradients computed by those earlier backward calls, together with the corresponding inputs of the full computation graph, to compute the gradients of the remaining variables.
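
Putting the two stages together, here is a minimal self-contained sketch of the same trick, with toy tensors standing in for the decoder output and the generator (an assumed illustration, not OpenNMT code):

import torch

w = torch.randn(4, 4, requires_grad=True)  # stands in for the model parameters
x = torch.randn(10, 4)
output = x @ w                             # stands in for the decoder output

shard_size = 4
v_split = []
for v_chunk in torch.split(output, shard_size):
    v_chunk = v_chunk.data.clone()         # detached copy: the graph is cut here
    v_chunk.requires_grad = True           # but keep it a leaf that records .grad
    v_split.append(v_chunk)

# Stage 1: per-shard "generator + loss" backward, which stops at the chunk inputs.
for v_chunk in v_split:
    loss = (v_chunk ** 2).sum()            # stands in for generator + criterion
    loss.backward()                        # fills v_chunk.grad; w.grad is still None

# Stage 2: feed the accumulated chunk gradients back through the full graph.
inputs = torch.split(output, shard_size)
grads = [v_chunk.grad for v_chunk in v_split]
torch.autograd.backward(inputs, grads)     # now w.grad matches d((output**2).sum())/dw

print(w.grad.shape)  # torch.Size([4, 4])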

 

Why, after v_chunk = v_chunk.data.clone(), do we still need v_chunk.requires_grad = v.requires_grad? Presumably because .data.clone() returns a tensor with requires_grad=False: restoring the flag turns each chunk into a leaf that records a .grad during the per-shard backward, and that .grad is exactly what shards later passes to torch.autograd.backward.
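
A quick check of that behaviour on a toy tensor (an assumed example):

import torch

v = torch.randn(3, requires_grad=True)
chunk = v.data.clone()
print(chunk.requires_grad)             # False: .data drops the flag

chunk.requires_grad = v.requires_grad  # restore it; chunk is a leaf again
(chunk * 2).sum().backward()
print(chunk.grad)                      # tensor([2., 2., 2.]): what shards() later feeds to autograd
print(v.grad)                          # None: the clone cut the graph, v itself is untouched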

 
