Implementing RetinaNet from Scratch (Part 9): Training RetinaNet on VOC with COCO Pretrained Weights

All the code has been uploaded to my GitHub repository: https://github.com/zgcr/pytorch-ImageNet-CIFAR-COCO-VOC-training
If you find it useful, please give it a star!
All the code below has been tested under PyTorch 1.4 and verified to run correctly.

How to load the COCO pretrained weights

Other datasets such as VOC do not necessarily have the same number of classes as COCO, so before loading the COCO pretrained weights we must first remove the weights of the convolution layers that depend on the class count. We redefine RetinaNet.py as follows:

import os
import sys
import numpy as np

BASE_DIR = os.path.dirname(
    os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
sys.path.append(BASE_DIR)

from public.path import pretrained_models_path
from public.detection.models.retinanet import RetinaNet

import torch
import torch.nn as nn
import torch.nn.functional as F

__all__ = [
    'resnet18_retinanet',
    'resnet34_retinanet',
    'resnet50_retinanet',
    'resnet101_retinanet',
    'resnet152_retinanet',
]

model_urls = {
    'resnet18_retinanet': 'empty',
    'resnet34_retinanet': 'empty',
    'resnet50_retinanet': '{}/detection_models/resnet50_retinanet_map_0.286.pth'.format(pretrained_models_path),
    'resnet101_retinanet': 'empty',
    'resnet152_retinanet': 'empty',
}


def _retinanet(arch, pretrained, progress, **kwargs):
    model = RetinaNet(arch, **kwargs)

    # only load state_dict()
    if pretrained:
        pretrained_models = torch.load(model_urls[arch + "_retinanet"],
                                       map_location=torch.device('cpu'))
        # remove the class-count-dependent final conv weights of ClsHead
        # (and the final conv of RegHead) before loading the COCO weights
        del pretrained_models['cls_head.cls_head.8.weight']
        del pretrained_models['cls_head.cls_head.8.bias']
        del pretrained_models['reg_head.reg_head.8.weight']
        del pretrained_models['reg_head.reg_head.8.bias']

        model.load_state_dict(pretrained_models, strict=False)

    return model


def resnet18_retinanet(pretrained=False, progress=True, **kwargs):
    return _retinanet('resnet18', pretrained, progress, **kwargs)


def resnet34_retinanet(pretrained=False, progress=True, **kwargs):
    return _retinanet('resnet34', pretrained, progress, **kwargs)


def resnet50_retinanet(pretrained=False, progress=True, **kwargs):
    return _retinanet('resnet50', pretrained, progress, **kwargs)


def resnet101_retinanet(pretrained=False, progress=True, **kwargs):
    return _retinanet('resnet101', pretrained, progress, **kwargs)


def resnet152_retinanet(pretrained=False, progress=True, **kwargs):
    return _retinanet('resnet152', pretrained, progress, **kwargs)


if __name__ == '__main__':
    net = RetinaNet(resnet_type="resnet50")
    image_h, image_w = 600, 600
    cls_heads, reg_heads, batch_anchors = net(
        torch.randn(3, 3, image_h, image_w))
    # example padded annotations: [x1, y1, x2, y2, class_id]; rows of -1 are padding
    annotations = torch.FloatTensor([[[113, 120, 183, 255, 5],
                                      [13, 45, 175, 210, 2]],
                                     [[11, 18, 223, 225, 1],
                                      [-1, -1, -1, -1, -1]],
                                     [[-1, -1, -1, -1, -1],
                                      [-1, -1, -1, -1, -1]]])

We first load the pretrained weights, which are stored as a dictionary, and then delete the weights of the last convolution layer of ClsHead and RegHead (strictly speaking, the last conv layer of RegHead does not have to be removed, as long as each position still has 9 anchors). We then build the model and load the pretrained weights into it. strict=False means the weights can still be loaded even when the dictionary keys and the model's layer names do not match exactly: each layer of the model looks up the key with the same name in the dictionary and loads that key's value (the weights) into the layer.
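
If you want to check exactly which parameters were skipped or left at their random initialization, load_state_dict returns the unmatched keys when strict=False. The sketch below (the checkpoint path is a placeholder for the resnet50_retinanet_map_0.286.pth file referenced in model_urls) also filters out any key whose shape no longer matches the new model, which covers the class-dependent ClsHead conv automatically:

import torch

from public.detection.models.retinanet import RetinaNet

# build a 20-class VOC model and load the 80-class COCO checkpoint
model = RetinaNet("resnet50", num_classes=20)
state_dict = torch.load("resnet50_retinanet_map_0.286.pth",
                        map_location=torch.device('cpu'))

# keep only the keys whose shapes still match the new model
model_state = model.state_dict()
filtered = {
    k: v
    for k, v in state_dict.items()
    if k in model_state and v.shape == model_state[k].shape
}

incompatible = model.load_state_dict(filtered, strict=False)
# missing_keys: layers left at random init (e.g. the final ClsHead conv)
# unexpected_keys: checkpoint entries that have no matching layer
print("missing keys:", incompatible.missing_keys)
print("unexpected keys:", incompatible.unexpected_keys)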

Complete training and testing code

config.py:

import os
import sys

BASE_DIR = os.path.dirname(
    os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
sys.path.append(BASE_DIR)

from public.path import VOCdataset_path
from public.detection.dataset.vocdataset import VocDetection, Resize, RandomFlip, RandomCrop, RandomTranslate

import torchvision.transforms as transforms
import torchvision.datasets as datasets


class Config(object):
    log = './log'  # Path to save log
    checkpoint_path = './checkpoints'  # Path to store checkpoint model
    resume = './checkpoints/latest.pth'  # load checkpoint model
    evaluate = None  # evaluate model path
    dataset_path = VOCdataset_path

    network = "resnet50_retinanet"
    pretrained = True
    num_classes = 20
    seed = 0
    input_image_size = 600

    train_dataset = VocDetection(root_dir=dataset_path,
                                 image_sets=[('2007', 'trainval'),
                                             ('2012', 'trainval')],
                                 transform=transforms.Compose([
                                     RandomFlip(flip_prob=0.5),
                                     RandomCrop(crop_prob=0.5),
                                     RandomTranslate(translate_prob=0.5),
                                     Resize(resize=input_image_size),
                                 ]),
                                 keep_difficult=False)
    val_dataset = VocDetection(root_dir=dataset_path,
                               image_sets=[('2007', 'test')],
                               transform=transforms.Compose([
                                   Resize(resize=input_image_size),
                               ]),
                               keep_difficult=False)

    epochs = 20
    per_node_batch_size = 20
    lr = 1e-4
    num_workers = 4
    print_interval = 100
    apex = True
    sync_bn = False

Since the VOC dataset is relatively small, I train for only 20 epochs here. pretrained = True means the COCO pretrained weights are loaded; if it is False, no pretrained weights are loaded.
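
For reference, num_classes = 20 corresponds to the 20 PASCAL VOC categories. The per-class AP printed during evaluation is keyed by numeric class index, so a small lookup table makes those logs easier to read. The list below assumes VocDetection assigns indices 0-19 in the conventional alphabetical VOC order; check the dataset class if your mapping differs:

# the 20 PASCAL VOC categories; assumption: VocDetection uses this
# alphabetical order for class indices 0-19
VOC_CLASSES = [
    'aeroplane', 'bicycle', 'bird', 'boat', 'bottle',
    'bus', 'car', 'cat', 'chair', 'cow',
    'diningtable', 'dog', 'horse', 'motorbike', 'person',
    'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor'
]

def class_name(class_index):
    # map a numeric class index from the eval logs back to a VOC name
    return VOC_CLASSES[int(class_index)]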

train.py:

import sys
import os
import argparse
import random
import shutil
import time
import warnings
import json

BASE_DIR = os.path.dirname(
    os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
sys.path.append(BASE_DIR)
warnings.filterwarnings('ignore')

import numpy as np
from thop import profile
from thop import clever_format
from tqdm import tqdm
import apex
from apex import amp
from apex.parallel import convert_syncbn_model
from apex.parallel import DistributedDataParallel
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.distributed as dist
import torch.backends.cudnn as cudnn
from torch.utils.data import DataLoader
from torchvision import transforms
from config import Config
from public.detection.dataset.cocodataset import COCODataPrefetcher, collater
from public.detection.models.loss import RetinaLoss
from public.detection.models.decode import RetinaDecoder
from retinanet import resnet50_retinanet
from public.imagenet.utils import get_logger
from pycocotools.cocoeval import COCOeval

os.environ["CUDA_VISIBLE_DEVICES"] = "1"


def parse_args():
    parser = argparse.ArgumentParser(
        description='PyTorch VOC Detection Distributed Training')
    parser.add_argument('--network',
                        type=str,
                        default=Config.network,
                        help='name of network')
    parser.add_argument('--lr',
                        type=float,
                        default=Config.lr,
                        help='learning rate')
    parser.add_argument('--epochs',
                        type=int,
                        default=Config.epochs,
                        help='num of training epochs')
    parser.add_argument('--per_node_batch_size',
                        type=int,
                        default=Config.per_node_batch_size,
                        help='per_node batch size')
    parser.add_argument('--pretrained',
                        type=bool,
                        default=Config.pretrained,
                        help='load pretrained model params or not')
    parser.add_argument('--num_classes',
                        type=int,
                        default=Config.num_classes,
                        help='model classification num')
    parser.add_argument('--input_image_size',
                        type=int,
                        default=Config.input_image_size,
                        help='input image size')
    parser.add_argument('--num_workers',
                        type=int,
                        default=Config.num_workers,
                        help='number of worker to load data')
    parser.add_argument('--resume',
                        type=str,
                        default=Config.resume,
                        help='put the path to resuming file if needed')
    parser.add_argument('--checkpoints',
                        type=str,
                        default=Config.checkpoint_path,
                        help='path for saving trained models')
    parser.add_argument('--log',
                        type=str,
                        default=Config.log,
                        help='path to save log')
    parser.add_argument('--evaluate',
                        type=str,
                        default=Config.evaluate,
                        help='path for evaluate model')
    parser.add_argument('--seed', type=int, default=Config.seed, help='seed')
    parser.add_argument('--print_interval',
                        type=int,
                        default=Config.print_interval,
                        help='print interval')
    parser.add_argument('--apex',
                        type=bool,
                        default=Config.apex,
                        help='use apex or not')
    parser.add_argument('--sync_bn',
                        type=bool,
                        default=Config.sync_bn,
                        help='use sync bn or not')
    parser.add_argument('--local_rank',
                        type=int,
                        default=0,
                        help='LOCAL_PROCESS_RANK')

    return parser.parse_args()


def compute_voc_ap(recall, precision, use_07_metric=True):
    if use_07_metric:
        # use voc 2007 11 point metric
        ap = 0.
        for t in np.arange(0., 1.1, 0.1):
            if np.sum(recall >= t) == 0:
                p = 0
            else:
                # get max precision  for recall >= t
                p = np.max(precision[recall >= t])
            # average 11 recall point precision
            ap = ap + p / 11.
    else:
        # use voc>=2010 metric,average all different recall precision as ap
        # recall add first value 0. and last value 1.
        mrecall = np.concatenate(([0.], recall, [1.]))
        # precision add first value 0. and last value 0.
        mprecision = np.concatenate(([0.], precision, [0.]))

        # compute the precision envelope
        for i in range(mprecision.size - 1, 0, -1):
            mprecision[i - 1] = np.maximum(mprecision[i - 1], mprecision[i])

        # to calculate area under PR curve, look for points where X axis (recall) changes value
        i = np.where(mrecall[1:] != mrecall[:-1])[0]

        # sum (\Delta recall) * prec
        ap = np.sum((mrecall[i + 1] - mrecall[i]) * mprecision[i + 1])

    return ap


def compute_ious(a, b):
    """
    :param a: [N,(x1,y1,x2,y2)]
    :param b: [M,(x1,y1,x2,y2)]
    :return:  IoU [N,M]
    """

    a = np.expand_dims(a, axis=1)  # [N,1,4]
    b = np.expand_dims(b, axis=0)  # [1,M,4]

    overlap = np.maximum(0.0,
                         np.minimum(a[..., 2:], b[..., 2:]) -
                         np.maximum(a[..., :2], b[..., :2]))  # [N,M,(w,h)]

    overlap = np.prod(overlap, axis=-1)  # [N,M]

    area_a = np.prod(a[..., 2:] - a[..., :2], axis=-1)
    area_b = np.prod(b[..., 2:] - b[..., :2], axis=-1)

    iou = overlap / (area_a + area_b - overlap)

    return iou


def validate(val_dataset, model, decoder):
    model = model.module
    # switch to evaluate mode
    model.eval()
    with torch.no_grad():
        all_ap, mAP = evaluate_voc(val_dataset,
                                   model,
                                   decoder,
                                   num_classes=20,
                                   iou_thread=0.5)

    return all_ap, mAP


def evaluate_voc(val_dataset, model, decoder, num_classes=20, iou_thread=0.5):
    preds, gts = [], []
    for index in tqdm(range(len(val_dataset))):
        data = val_dataset[index]
        img, gt_annot, scale = data['img'], data['annot'], data['scale']

        gt_bboxes, gt_classes = gt_annot[:, 0:4], gt_annot[:, 4]
        gt_bboxes /= scale

        gts.append([gt_bboxes, gt_classes])

        cls_heads, reg_heads, batch_anchors = model(img.cuda().permute(
            2, 0, 1).float().unsqueeze(dim=0))
        preds_scores, preds_classes, preds_boxes = decoder(
            cls_heads, reg_heads, batch_anchors)
        preds_scores, preds_classes, preds_boxes = preds_scores.cpu(
        ), preds_classes.cpu(), preds_boxes.cpu()
        preds_boxes /= scale

        # make sure decode batch_size=1
        # preds_scores shape:[1,max_detection_num]
        # preds_classes shape:[1,max_detection_num]
        # preds_bboxes shape[1,max_detection_num,4]
        assert preds_scores.shape[0] == 1

        preds_scores = preds_scores.squeeze(0)
        preds_classes = preds_classes.squeeze(0)
        preds_boxes = preds_boxes.squeeze(0)

        preds_scores = preds_scores[preds_classes > -1]
        preds_boxes = preds_boxes[preds_classes > -1]
        preds_classes = preds_classes[preds_classes > -1]

        preds.append([preds_boxes, preds_classes, preds_scores])

    print("all val sample decode done.")

    all_ap = {}
    for class_index in tqdm(range(num_classes)):
        per_class_gt_boxes = [
            image[0][image[1] == class_index] for image in gts
        ]
        per_class_pred_boxes = [
            image[0][image[1] == class_index] for image in preds
        ]
        per_class_pred_scores = [
            image[2][image[1] == class_index] for image in preds
        ]

        fp = np.zeros((0, ))
        tp = np.zeros((0, ))
        scores = np.zeros((0, ))
        total_gts = 0

        # loop for each sample
        for per_image_gt_boxes, per_image_pred_boxes, per_image_pred_scores in zip(
                per_class_gt_boxes, per_class_pred_boxes,
                per_class_pred_scores):
            total_gts = total_gts + len(per_image_gt_boxes)
            # one gt can only be assigned to one predicted bbox
            assigned_gt = []
            # loop for each predicted bbox
            for index in range(len(per_image_pred_boxes)):
                scores = np.append(scores, per_image_pred_scores[index])
                if per_image_gt_boxes.shape[0] == 0:
                    # if no gts found for the predicted bbox, assign the bbox to fp
                    fp = np.append(fp, 1)
                    tp = np.append(tp, 0)
                    continue
                pred_box = np.expand_dims(per_image_pred_boxes[index], axis=0)
                iou = compute_ious(per_image_gt_boxes, pred_box)
                gt_for_box = np.argmax(iou, axis=0)
                max_overlap = iou[gt_for_box, 0]
                if max_overlap >= iou_thread and gt_for_box not in assigned_gt:
                    fp = np.append(fp, 0)
                    tp = np.append(tp, 1)
                    assigned_gt.append(gt_for_box)
                else:
                    fp = np.append(fp, 1)
                    tp = np.append(tp, 0)
        # sort by score
        indices = np.argsort(-scores)
        fp = fp[indices]
        tp = tp[indices]
        # compute cumulative false positives and true positives
        fp = np.cumsum(fp)
        tp = np.cumsum(tp)
        # compute recall and precision
        recall = tp / total_gts
        precision = tp / np.maximum(tp + fp, np.finfo(np.float64).eps)
        ap = compute_voc_ap(recall, precision)
        all_ap[class_index] = ap

    mAP = 0.
    for _, class_mAP in all_ap.items():
        mAP += float(class_mAP)
    mAP /= num_classes

    return all_ap, mAP


def main():
    args = parse_args()
    global local_rank
    local_rank = args.local_rank
    if local_rank == 0:
        global logger
        logger = get_logger(__name__, args.log)

    torch.cuda.empty_cache()

    if args.seed is not None:
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        torch.cuda.manual_seed_all(args.seed)
        cudnn.deterministic = True

    torch.cuda.set_device(local_rank)
    dist.init_process_group(backend='nccl', init_method='env://')
    global gpus_num
    gpus_num = torch.cuda.device_count()
    if local_rank == 0:
        logger.info(f'use {gpus_num} gpus')
        logger.info(f"args: {args}")

    cudnn.benchmark = True
    cudnn.enabled = True
    start_time = time.time()

    # dataset and dataloader
    if local_rank == 0:
        logger.info('start loading data')
    train_sampler = torch.utils.data.distributed.DistributedSampler(
        Config.train_dataset, shuffle=True)
    train_loader = DataLoader(Config.train_dataset,
                              batch_size=args.per_node_batch_size,
                              shuffle=False,
                              num_workers=args.num_workers,
                              collate_fn=collater,
                              sampler=train_sampler)
    if local_rank == 0:
        logger.info('finish loading data')

    model = resnet50_retinanet(**{
        "pretrained": args.pretrained,
        "num_classes": args.num_classes,
    })

    for name, param in model.named_parameters():
        if local_rank == 0:
            logger.info(f"{name},{param.requires_grad}")

    flops_input = torch.randn(1, 3, args.input_image_size,
                              args.input_image_size)
    flops, params = profile(model, inputs=(flops_input, ))
    flops, params = clever_format([flops, params], "%.3f")
    if local_rank == 0:
        logger.info(
            f"model: '{args.network}', flops: {flops}, params: {params}")

    criterion = RetinaLoss(image_w=args.input_image_size,
                           image_h=args.input_image_size,
                           alpha=0.25,
                           gamma=1.5).cuda()
    decoder = RetinaDecoder(image_w=args.input_image_size,
                            image_h=args.input_image_size).cuda()

    model = model.cuda()
    optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                           patience=3,
                                                           verbose=True)

    if args.sync_bn:
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)

    if args.apex:
        amp.register_float_function(torch, 'sigmoid')
        amp.register_float_function(torch, 'softmax')
        model, optimizer = amp.initialize(model, optimizer, opt_level='O1')
        model = apex.parallel.DistributedDataParallel(model,
                                                      delay_allreduce=True)
        if args.sync_bn:
            model = apex.parallel.convert_syncbn_model(model)
    else:
        model = nn.parallel.DistributedDataParallel(model,
                                                    device_ids=[local_rank],
                                                    output_device=local_rank)

    if args.evaluate:
        if not os.path.isfile(args.evaluate):
            if local_rank == 0:
                logger.exception(
                    '{} is not a file, please check it again'.format(
                        args.resume))
            sys.exit(-1)
        if local_rank == 0:
            logger.info('start only evaluating')
            logger.info(f"start resuming model from {args.evaluate}")
        checkpoint = torch.load(args.evaluate,
                                map_location=torch.device('cpu'))
        model.load_state_dict(checkpoint['model_state_dict'])
        if local_rank == 0:
            all_ap, mAP = validate(Config.val_dataset, model, decoder)
            logger.info(f"eval done.")
            for class_index, class_AP in all_ap.items():
                logger.info(f"class: {class_index},AP: {class_AP:.3f}")
            logger.info(f"mAP: {mAP:.3f}") 

        return

    best_map = 0.0
    start_epoch = 1
    # resume training
    if os.path.exists(args.resume):
        if local_rank == 0:
            logger.info(f"start resuming model from {args.resume}")
        checkpoint = torch.load(args.resume, map_location=torch.device('cpu'))
        start_epoch += checkpoint['epoch']
        best_map = checkpoint['best_map']
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
        if local_rank == 0:
            logger.info(
                f"finish resuming model from {args.resume}, epoch {checkpoint['epoch']}, best_map: {checkpoint['best_map']}, "
                f"loss: {checkpoint['loss']:3f}, cls_loss: {checkpoint['cls_loss']:2f}, reg_loss: {checkpoint['reg_loss']:2f}"
            )

    if not os.path.exists(args.checkpoints):
        os.makedirs(args.checkpoints)

    if local_rank == 0:
        logger.info('start training')
    for epoch in range(start_epoch, args.epochs + 1):
        train_sampler.set_epoch(epoch)
        cls_losses, reg_losses, losses = train(train_loader, model, criterion,
                                               optimizer, scheduler, epoch,
                                               args)
        if local_rank == 0:
            logger.info(
                f"train: epoch {epoch:0>3d}, cls_loss: {cls_losses:.2f}, reg_loss: {reg_losses:.2f}, loss: {losses:.2f}"
            )

        if epoch % 5 == 0 or epoch == args.epochs:
            if local_rank == 0:
                all_ap, mAP = validate(Config.val_dataset, model, decoder)
                logger.info(f"eval done.")
                for class_index, class_AP in all_ap.items():
                    logger.info(f"class: {class_index},AP: {class_AP:.3f}")
                logger.info(f"mAP: {mAP:.3f}") 
                if mAP > best_map:
                    torch.save(model.module.state_dict(),
                                os.path.join(args.checkpoints, "best.pth"))
                    best_map = mAP
        if local_rank == 0:
            torch.save(
                {
                    'epoch': epoch,
                    'best_map': best_map,
                    'cls_loss': cls_losses,
                    'reg_loss': reg_losses,
                    'loss': losses,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'scheduler_state_dict': scheduler.state_dict(),
                }, os.path.join(args.checkpoints, 'latest.pth'))

    if local_rank == 0:
        logger.info(f"finish training, best_map: {best_map:.3f}")
    training_time = (time.time() - start_time) / 3600
    if local_rank == 0:
        logger.info(
            f"finish training, total training time: {training_time:.2f} hours")


def train(train_loader, model, criterion, optimizer, scheduler, epoch, args):
    cls_losses, reg_losses, losses = [], [], []

    # switch to train mode
    model.train()

    iters = len(train_loader.dataset) // (args.per_node_batch_size * gpus_num)
    prefetcher = COCODataPrefetcher(train_loader)
    images, annotations = prefetcher.next()
    iter_index = 1

    while images is not None:
        images, annotations = images.cuda().float(), annotations.cuda()
        cls_heads, reg_heads, batch_anchors = model(images)
        cls_loss, reg_loss = criterion(cls_heads, reg_heads, batch_anchors,
                                       annotations)
        loss = cls_loss + reg_loss
        if cls_loss == 0.0 or reg_loss == 0.0:
            # skip batches that produce no loss (e.g. no positive anchors),
            # but still fetch the next batch to avoid looping on the same one
            optimizer.zero_grad()
            images, annotations = prefetcher.next()
            continue

        if args.apex:
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()

        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.1)
        optimizer.step()
        optimizer.zero_grad()

        cls_losses.append(cls_loss.item())
        reg_losses.append(reg_loss.item())
        losses.append(loss.item())

        images, annotations = prefetcher.next()

        if local_rank == 0 and iter_index % args.print_interval == 0:
            logger.info(
                f"train: epoch {epoch:0>3d}, iter [{iter_index:0>5d}, {iters:0>5d}], cls_loss: {cls_loss.item():.2f}, reg_loss: {reg_loss.item():.2f}, loss_total: {loss.item():.2f}"
            )

        iter_index += 1

    scheduler.step(np.mean(losses))

    return np.mean(cls_losses), np.mean(reg_losses), np.mean(losses)


if __name__ == '__main__':
    main()

The best.pth saved at the end contains the model weights that achieved the highest mAP during evaluation.
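
As a quick sanity check, best.pth can be loaded back into a bare (non-DDP) model for inference. The sketch below mirrors the preprocessing and decoding used in evaluate_voc; the checkpoint path and settings are assumptions based on the Config above:

import torch

from config import Config
from retinanet import resnet50_retinanet
from public.detection.models.decode import RetinaDecoder

# rebuild the 20-class model without pretrained weights and load best.pth,
# which was saved from model.module.state_dict() (so no 'module.' prefix)
model = resnet50_retinanet(pretrained=False, num_classes=Config.num_classes)
model.load_state_dict(
    torch.load('./checkpoints/best.pth', map_location=torch.device('cpu')))
model = model.cuda().eval()

decoder = RetinaDecoder(image_w=Config.input_image_size,
                        image_h=Config.input_image_size).cuda()

with torch.no_grad():
    # take one sample from the val dataset (already resized by the transform)
    data = Config.val_dataset[0]
    img, scale = data['img'], data['scale']
    cls_heads, reg_heads, batch_anchors = model(
        img.cuda().permute(2, 0, 1).float().unsqueeze(dim=0))
    scores, classes, boxes = decoder(cls_heads, reg_heads, batch_anchors)
    boxes /= scale  # map boxes back to the original image resolution
    print(scores.shape, classes.shape, boxes.shape)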

Experimental results

Network                     batch  gpu-num  apex  syncbn  epoch5 mAP,loss  epoch10 mAP,loss  epoch15 mAP,loss  epoch20 mAP,loss
ResNet50-RetinaNet          20     1        yes   no      0.666,0.60       0.730,0.44        0.739,0.35        0.737,0.30
ResNet50-RetinaNet-cocopre  20     1        yes   no      0.790,0.34       0.796,0.26        0.797,0.22        0.795,0.19

During testing, mAP is computed with the VOC2007 11-point metric. The -cocopre suffix means the model was initialized and then loaded with the COCO pretrained weights.
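
To make the 11-point metric concrete, here is a toy call of the compute_voc_ap function defined in train.py above, using a hand-made precision/recall curve (the numbers are purely illustrative):

import numpy as np

# cumulative recall/precision for an imaginary class with a few detections;
# assumes compute_voc_ap from train.py is in scope
recall = np.array([0.1, 0.2, 0.3, 0.4, 0.5])
precision = np.array([1.0, 1.0, 0.75, 0.6, 0.5])

# 11-point metric: average the max precision at recall thresholds 0.0-1.0;
# recall never reaches 0.6 or above here, so those five points contribute 0
ap_07 = compute_voc_ap(recall, precision, use_07_metric=True)

# area-under-curve variant used for VOC >= 2010
ap_10 = compute_voc_ap(recall, precision, use_07_metric=False)

print(f"11-point AP: {ap_07:.3f}, AUC AP: {ap_10:.3f}")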
