模型訓練技巧——CutMix

論文:https://arxiv.org/pdf/1905.04899v2.pdf

官方代碼:https://github.com/clovaai/CutMix-PyTorch

1. 論文核心

      

Caption

       簡單來講,就是從A圖中隨機截取一個矩形區域,用該矩形區域的像素替換掉B圖中對應的矩形區域,從而形成一張新的組合圖片。同時,把標籤按照一定的比例(矩形區域所佔整張圖的面積)進行線性組合計算損失。

        論文中的表達形式如下:

Caption

         將圖片和標籤進行了線性組合。

 

2. 代碼實現 

def cutmix_criterion(criterion, pred, y_a, y_b, lam):
    return lam * criterion(pred, y_a) + (1 - lam) * criterion(pred, y_b)


def cutmix_data(x, y, alpha=1.0, use_cuda=True):
    '''Returns mixed inputs, pairs of targets, and lambda'''
    assert alpha > 0
    lam = np.random.beta(alpha, alpha)
    batch_size = x.size()[0]

    if use_cuda:
        index = torch.randperm(batch_size).cuda()
    else:
        index = torch.randperm(batch_size)

    y_a, y_b = y, y[index]
    bbx1, bby1, bbx2, bby2 = rand_bbox(x.size(), lam)
    x[:, :, bbx1:bbx2, bby1:bby2] = x[index, :, bbx1:bbx2, bby1:bby2]

    # adjust lambda to exactly match pixel ratio
    lam = 1 - ((bbx2 - bbx1) * (bby2 - bby1) / (x.size()[-1] * x.size()[-2]))

    return x, y_a, y_b, lam


def rand_bbox(size, lam):
    W = size[2]
    H = size[3]
    cut_rat = np.sqrt(1. - lam)
    cut_w = np.int(W * cut_rat)
    cut_h = np.int(H * cut_rat)

    # uniform
    cx = np.random.randint(W)
    cy = np.random.randint(H)

    bbx1 = np.clip(cx - cut_w // 2, 0, W)
    bby1 = np.clip(cy - cut_h // 2, 0, H)
    bbx2 = np.clip(cx + cut_w // 2, 0, W)
    bby2 = np.clip(cy + cut_h // 2, 0, H)

    return bbx1, bby1, bbx2, bby2


# 在train函數中做以下修改, 其他地方不做任何修改
    for (inputs, targets) in tqdm(trainloader):
        inputs, targets = inputs.to(device), targets.to(device)

        r = np.random.rand(1)
        if r < 0.5: # 做cutmix的概率爲0.5
            inputs, targets_a, targets_b, lam = cutmix_data(inputs, targets)
            inputs, targets_a, targets_b = map(Variable, (inputs, targets_a, targets_b))
            outputs = net(inputs)
            loss = cutmix_criterion(criterion, outputs, targets_a.long(), targets_b.long(), lam)
        else:
            outputs = net(inputs)
            loss = criterion(outputs, targets.long())

             官方代碼都寫在train函數裏,博主覺得函數過長,於是把核心功能cutmix封裝成函數,看起來更簡潔。

3. 作者實驗

Caption

            從作者的實驗數據來看,CutMix的效果比Mixup和Cutout都好,並且在圖像識別和目標檢測任務上都漲點明顯。

 

end

博主應用該trick正在訓練模型,明天測試結果。

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章