ACNet Code Walkthrough

Paper: https://arxiv.org/abs/1908.03930v1

Code: https://github.com/ShawnDing1994/ACN

Preface: This is my first time walking through a codebase. My knowledge is limited, so if you spot any mistakes, please don't hesitate to point them out!

# Code Block 1
# The author's source code is very thorough; here we walk through it with ResNet56 as the backbone
python acnet/acnet_rc56.py --try_arg=acnet_lrs3_warmup_bias

# acnet_rc56.py mainly defines the backbone, learning rate, dataset, save paths, etc.
def acnet_rc56():
    try_arg = start_exp()  # receives the hyperparameter string: conv type, LR schedule, warmup, bias

    network_type = 'rc56'  # backbone
    dataset_name = 'cifar10'  # training dataset
    log_dir = 'acnet_exps/{}_{}_train'.format(network_type, try_arg)  # folder for training logs
    save_weights = 'acnet_exps/{}_{}_savedweights.pth'.format(network_type, try_arg)  # path and filename for the saved weights
    weight_decay_strength = 1e-4 
    batch_size = 64   

    # Define the initial LR, max_epochs, the epochs at which the LR decays, and the decay factor
    # (a sketch of what this returns follows this code block)
    lrs = parse_usual_lr_schedule(try_arg)

    if 'bias' in try_arg:
        weight_decay_bias = weight_decay_strength  # this branch is taken: biases get weight decay too
    else:
        weight_decay_bias = 0

    if 'warmup' in try_arg:
        warmup_factor = 0  # this branch is taken: warm the LR up starting from 0
    else:
        warmup_factor = 1

    # Assemble the basic config
    config = get_baseconfig_by_epoch(network_type=network_type,  # backbone
                dataset_name=dataset_name, dataset_subset='train',  # training data
                global_batch_size=batch_size, num_node=1,    # batch size; single-node training
                weight_decay=weight_decay_strength,  # L2 weight decay strength
                optimizer_type='sgd', momentum=0.9,  # optimizer type and momentum
                max_epochs=lrs.max_epochs, base_lr=lrs.base_lr,  # total epochs and initial LR
                lr_epoch_boundaries=lrs.lr_epoch_boundaries,  # epochs at which the LR decays
                lr_decay_factor=lrs.lr_decay_factor,  # LR decay factor
                # During the first 5 epochs, linearly warm the LR up, starting from warmup_factor (0) times the base LR
                warmup_epochs=5, warmup_method='linear', warmup_factor=warmup_factor,
                # Save a checkpoint every 20000 steps, log every 100 steps, output to log_dir
                ckpt_iter_period=20000, tb_iter_period=100, output_dir=log_dir,
                # TensorBoard log folder, weight save path, and evaluate the model every 2 epochs
                tb_dir=log_dir, save_weights=save_weights, val_epoch_period=2,
                linear_final_lr=lrs.linear_final_lr, weight_decay_bias=weight_decay_bias)

    if 'normal' in try_arg:
        builder = None
    elif 'acnet' in try_arg:  # this path is taken
        from acnet.acnet_builder import ACNetBuilder  # import the ACNet builder
        builder = ACNetBuilder(base_config=config, deploy=False)  # see Code Block 2
    else:
        assert False

    # Training function
    ding_train(config, show_variables=True, convbuilder=builder,
               use_nesterov='nest' in try_arg)  # see Code Block 3
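
For reference: parse_usual_lr_schedule is not reproduced in this post. Judging from how lrs is used above (lrs.max_epochs, lrs.base_lr, lrs.lr_epoch_boundaries, lrs.lr_decay_factor, lrs.linear_final_lr), it returns an object carrying roughly the following fields. This is a minimal sketch: the field names come from the usage above, but the concrete numbers are my own placeholders, not the values behind 'lrs3' in the repo.

from collections import namedtuple

# Hypothetical container mirroring the fields accessed above
LRSchedule = namedtuple('LRSchedule',
                        ['max_epochs', 'base_lr', 'lr_epoch_boundaries',
                         'lr_decay_factor', 'linear_final_lr'])

# Placeholder numbers only -- the real 'lrs3' values live in the repo
lrs = LRSchedule(max_epochs=200, base_lr=0.1,
                 lr_epoch_boundaries=[100, 150],  # decay the LR at these epochs
                 lr_decay_factor=0.1,             # multiply the LR by 0.1 at each boundary
                 linear_final_lr=None)            # None => step decay rather than linear decay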

# Code Block 2
class ACNetBuilder(ConvBuilder):

    def __init__(self, base_config, deploy):
        super(ACNetBuilder, self).__init__(base_config=base_config)
        self.deploy = deploy  # training mode for now: deploy=False

    def switch_to_deploy(self):
        self.deploy = True

    # For 1x1 convs an asymmetric branch adds nothing, so they fall back to ordinary convs;
    # every other conv is replaced by an ACBlock (sketched after this code block).
    def Conv2d(self, in_channels, out_channels, kernel_size, stride=1, padding=0,
               dilation=1, groups=1, bias=True,
               padding_mode='zeros', use_original_conv=False):
        if use_original_conv or kernel_size == 1 or kernel_size == (1, 1):
            return super(ACNetBuilder, self).Conv2d(in_channels=in_channels,
                         out_channels=out_channels, kernel_size=kernel_size,
                         stride=stride, padding=padding, dilation=dilation,
                         groups=groups, bias=bias, padding_mode=padding_mode,
                         use_original_conv=True)
        else:
            return ACBlock(in_channels, out_channels, kernel_size=kernel_size,
                         stride=stride, padding=padding, dilation=dilation,
                         groups=groups, padding_mode=padding_mode,
                         deploy=self.deploy)

    def Conv2dBN(self, in_channels, out_channels, kernel_size, stride=1, padding=0,
              dilation=1, groups=1, padding_mode='zeros', use_original_conv=False):
        if use_original_conv or kernel_size == 1 or kernel_size == (1, 1):
            return super(ACNetBuilder, self).Conv2dBN(in_channels=in_channels,
                         out_channels=out_channels, kernel_size=kernel_size,
                         stride=stride, padding=padding, dilation=dilation,
                         groups=groups, padding_mode=padding_mode, use_original_conv=True)
        else:
            return ACBlock(in_channels, out_channels, kernel_size=kernel_size,
                         stride=stride, padding=padding, dilation=dilation,
                         groups=groups, padding_mode=padding_mode, deploy=self.deploy)

    def Conv2dBNReLU(self, in_channels, out_channels, kernel_size, stride=1, padding=0,
              dilation=1, groups=1, padding_mode='zeros', use_original_conv=False):
        if use_original_conv or kernel_size == 1 or kernel_size == (1, 1):
            return super(ACNetBuilder, self).Conv2dBNReLU(in_channels=in_channels,
                         out_channels=out_channels, kernel_size=kernel_size,
                         stride=stride, padding=padding, dilation=dilation,
                         groups=groups, padding_mode=padding_mode, use_original_conv=True)
        else:
            # ACBlock already ends in BN, so only a ReLU needs to be appended
            se = nn.Sequential()
            se.add_module('acb', ACBlock(in_channels, out_channels,
                         kernel_size=kernel_size, stride=stride, padding=padding,
                         dilation=dilation, groups=groups, padding_mode=padding_mode,
                         deploy=self.deploy))
            se.add_module('relu', self.ReLU())
            return se

    def BNReLUConv2d(self, in_channels, out_channels, kernel_size, stride=1, padding=0,
               dilation=1, groups=1, padding_mode='zeros', use_original_conv=False):
        if use_original_conv or kernel_size == 1 or kernel_size == (1, 1):
            return super(ACNetBuilder, self).BNReLUConv2d(in_channels=in_channels,
                         out_channels=out_channels, kernel_size=kernel_size,
                         stride=stride, padding=padding, dilation=dilation, groups=groups,
                         padding_mode=padding_mode, use_original_conv=True)
        bn_layer = self.BatchNorm2d(num_features=in_channels)
        conv_layer = ACBlock(in_channels, out_channels, kernel_size=kernel_size,
                         stride=stride, padding=padding, dilation=dilation,
                         groups=groups, padding_mode=padding_mode,
                         deploy=self.deploy)
        se = self.Sequential()
        se.add_module('bn', bn_layer)
        se.add_module('relu', self.ReLU())
        se.add_module('acb', conv_layer)
        return se
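
Every branch above bottoms out in ACBlock, which this post does not reproduce. Per the paper, an Asymmetric Convolution Block replaces one KxK conv with three parallel branches -- the original KxK plus a 1xK and a Kx1 conv, each followed by its own BN -- whose outputs are summed. A minimal training-mode sketch (it ignores the dilation/groups/padding_mode arguments and the deploy path; the class name and these simplifications are mine, not the repo's):

import torch.nn as nn

class ACBlockSketch(nn.Module):
    """Training-mode asymmetric convolution block: KxK + 1xK + Kx1, summed."""
    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1):
        super().__init__()
        # Square branch: the ordinary KxK convolution
        self.square = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=False),
            nn.BatchNorm2d(out_channels))
        # Horizontal branch: 1xK kernel, so no vertical padding
        self.hor = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, (1, kernel_size), stride, (0, padding), bias=False),
            nn.BatchNorm2d(out_channels))
        # Vertical branch: Kx1 kernel, so no horizontal padding
        self.ver = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, (kernel_size, 1), stride, (padding, 0), bias=False),
            nn.BatchNorm2d(out_channels))

    def forward(self, x):
        # The three branch outputs have identical shapes and are simply added
        return self.square(x) + self.hor(x) + self.ver(x)

At deploy time (switch_to_deploy), each branch's BN is folded into its conv, the 1xK and Kx1 kernels are zero-padded to KxK, and the three kernels (and biases) are summed into a single ordinary KxK conv. The output is identical, which is why the extra branches cost nothing at inference.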
# Code Block 3
def ding_train(cfg: BaseConfigByEpoch, net=None, train_dataloader=None, 
               val_dataloader=None, show_variables=False,
               convbuilder=None, beginning_msg=None,
               init_hdf5=None, no_l2_keywords=None, gradient_mask=None, 
               use_nesterov=False):
    # LOCAL_RANK = 0
    #
    # num_gpus = int(os.environ["WORLD_SIZE"]) if "WORLD_SIZE" in os.environ else 1
    # is_distributed = num_gpus > 1
    #
    # if is_distributed:
    #     torch.cuda.set_device(LOCAL_RANK)
    #     torch.distributed.init_process_group(
    #         backend="nccl", init_method="env://"
    #     )
    #     synchronize()
    #
    # torch.backends.cudnn.benchmark = True

    ensure_dir(cfg.output_dir)  # make sure the output folder exists
    ensure_dir(cfg.tb_dir)      # make sure the TensorBoard log folder exists
    with Engine() as engine:    # the reusable PyTorch training engine

        is_main_process = (engine.world_rank == 0)  # TODO correct?

        # Set up the logger
        logger = engine.setup_log(
            name='train', log_dir=cfg.output_dir, file_name='log.txt')

        # -- typical model components model, opt,  scheduler,  dataloder --#
        if net is None:  # this path is taken: build the network architecture, see Code Block 4
            net = get_model_fn(cfg.dataset_name, cfg.network_type)

        if convbuilder is None:
            convbuilder = ConvBuilder(base_config=cfg)

        model = net(cfg, convbuilder).cuda()  # instantiate the model and move it to the GPU

        if train_dataloader is None:  # this path is taken: build the training data loader  TODO
            train_dataloader = create_dataset(cfg.dataset_name, cfg.dataset_subset,
                               cfg.global_batch_size)
        # this path is taken: build the validation data loader
        if cfg.val_epoch_period > 0 and val_dataloader is None:
            val_dataloader = create_dataset(cfg.dataset_name,
                               'val', batch_size=100)

        print('NOTE: Data prepared')
        # Print training info: batch size, number of GPUs, and allocated GPU memory
        print('NOTE: We have global_batch_size={} on {} GPUs, the allocated GPU memory is {}'.format(
              cfg.global_batch_size, torch.cuda.device_count(), torch.cuda.memory_allocated()))

        # device = torch.device(cfg.device)
        # model.to(device)
        # model.cuda()

        if no_l2_keywords is None:
            no_l2_keywords = []
        # Build the SGD optimizer
        optimizer = get_optimizer(cfg, model, no_l2_keywords=no_l2_keywords,
                                 use_nesterov=use_nesterov)
        # Build the LR scheduler
        scheduler = get_lr_scheduler(cfg, optimizer)
        # Build the cross-entropy loss
        criterion = get_criterion(cfg).cuda()

        # model, optimizer = amp.initialize(model, optimizer, opt_level="O0")

        engine.register_state(
            scheduler=scheduler, model=model, optimizer=optimizer)

        if engine.distributed:
            print('Distributed training, engine.world_rank={}'.format(engine.world_rank))
            model = torch.nn.parallel.DistributedDataParallel(model, device_ids= 
                    [engine.world_rank],broadcast_buffers=False, )
        elif torch.cuda.device_count() > 1:
            print('Single machine multiple GPU training')
            model = torch.nn.parallel.DataParallel(model)

        if cfg.init_weights:
            engine.load_checkpoint(cfg.init_weights, is_restore=True)  # initialize from the given weights

        if init_hdf5:
            engine.load_hdf5(init_hdf5)  # resume from an HDF5 checkpoint

        if show_variables:
            engine.show_variables()  # print variable names and shapes

        # ------------ do training ---------------------------- #
        if beginning_msg:
            engine.log(beginning_msg)
        # Log the PyTorch version
        logger.info("\n\nStart training with pytorch version {}".format(torch.__version__))

        iteration = engine.state.iteration
        iters_per_epoch = num_iters_per_epoch(cfg)  # steps per epoch
        max_iters = iters_per_epoch * cfg.max_epochs  # total training steps
        tb_writer = SummaryWriter(cfg.tb_dir)
        tb_tags = ['Top1-Acc', 'Top5-Acc', 'Loss']    # metrics to log

        model.train()   # training mode: BN uses batch statistics

        done_epochs = iteration // iters_per_epoch

        engine.save_hdf5(os.path.join(cfg.output_dir, 'init.hdf5'))  # save the initial weights

        recorded_train_time = 0
        recorded_train_examples = 0

        if gradient_mask is not None:
            gradient_mask_tensor = {}
            for name, value in gradient_mask.items():
                gradient_mask_tensor[name] = torch.Tensor(value).cuda()
        else:
            gradient_mask_tensor = None

        for epoch in range(done_epochs, cfg.max_epochs):

            pbar = tqdm(range(iters_per_epoch))
            top1 = AvgMeter()
            top5 = AvgMeter()
            losses = AvgMeter()
            discrip_str = 'Epoch-{}/{}'.format(epoch, cfg.max_epochs)
            pbar.set_description('Train' + discrip_str)

            if cfg.val_epoch_period > 0 and epoch % cfg.val_epoch_period == 0:
                model.eval()  # eval mode: BN switches to its running statistics
                val_iters = 500 if cfg.dataset_name == 'imagenet' else 100
                # run_eval performs the actual evaluation
                eval_dict, _ = run_eval(val_dataloader, val_iters, model, criterion,
                                       discrip_str, dataset_name=cfg.dataset_name)
                val_top1_value = eval_dict['top1'].item()
                val_top5_value = eval_dict['top5'].item()
                val_loss_value = eval_dict['loss'].item()
                for tag, value in zip(tb_tags, [val_top1_value, val_top5_value, val_loss_value]):
                    tb_writer.add_scalars(tag, {'Val': value}, iteration)
                engine.log('validate at epoch {}, top1={:.5f}, top5={:.5f}, loss={:.6f}'.format(
                    epoch, val_top1_value, val_top5_value, val_loss_value))
                model.train()

            for _ in pbar:

                start_time = time.time()
                data, label = load_cuda_data(train_dataloader, cfg.dataset_name)
                data_time = time.time() - start_time

                if_accum_grad = ((iteration % cfg.grad_accum_iters) != 0)

                train_net_time_start = time.time()
                # The actual per-step update  TODO -- see the hedged sketch after this code block
                acc, acc5, loss = train_one_step(model, data, label, optimizer,
                                  criterion, if_accum_grad,
                                  gradient_mask_tensor=gradient_mask_tensor)
                train_net_time_end = time.time()

                if iteration > TRAIN_SPEED_START * max_iters and iteration < TRAIN_SPEED_END * max_iters:
                    recorded_train_examples += cfg.global_batch_size
                    recorded_train_time += train_net_time_end - train_net_time_start

                scheduler.step()  # advance the LR schedule

                if iteration % cfg.tb_iter_period == 0 and is_main_process:
                    for tag, value in zip(tb_tags, [acc.item(), acc5.item(), loss.item()]):
                        tb_writer.add_scalars(tag, {'Train': value}, iteration)

                top1.update(acc.item())
                top5.update(acc5.item())
                losses.update(loss.item())

                pbar_dic = OrderedDict()
                pbar_dic['data-time'] = '{:.2f}'.format(data_time)
                pbar_dic['cur_iter'] = iteration
                pbar_dic['lr'] = scheduler.get_lr()[0]
                pbar_dic['top1'] = '{:.5f}'.format(top1.mean)
                pbar_dic['top5'] = '{:.5f}'.format(top5.mean)
                pbar_dic['loss'] = '{:.5f}'.format(losses.mean)
                pbar.set_postfix(pbar_dic)

                if iteration >= max_iters or iteration % cfg.ckpt_iter_period == 0:
                    engine.update_iteration(iteration)
                    if (not engine.distributed) or (engine.distributed and is_main_process):
                        engine.save_and_link_checkpoint(cfg.output_dir)

                iteration += 1
                if iteration >= max_iters:
                    break

            #   do something after an epoch?
            if iteration >= max_iters:
                break
        #   do something after the training
        if recorded_train_time > 0:
            exp_per_sec = recorded_train_examples / recorded_train_time
        else:
            exp_per_sec = 0
        engine.log('TRAIN speed: from {} to {} iterations, batch_size={}, examples={}, total_net_time={:.4f}, examples/sec={}'.format(
                int(TRAIN_SPEED_START * max_iters), int(TRAIN_SPEED_END * max_iters),
                cfg.global_batch_size, recorded_train_examples, recorded_train_time, exp_per_sec))
        if cfg.save_weights:
            engine.save_checkpoint(cfg.save_weights)
            print('NOTE: training finished, saved to {}'.format(cfg.save_weights))
        engine.save_hdf5(os.path.join(cfg.output_dir, 'finish.hdf5'))
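
train_one_step is where the forward/backward actually happens; the repo's version is not shown in this post. Judging only from the call site above (model, data, label, optimizer, criterion, an accumulate-gradients flag, and an optional gradient mask), a plausible implementation looks roughly like this. It is a sketch under those assumptions, not the repo's exact code; the accuracy computation in particular is my own:

import torch

def train_one_step_sketch(model, data, label, optimizer, criterion,
                          if_accum_grad=False, gradient_mask_tensor=None):
    """One training step: forward, loss, backward, (maybe) optimizer update."""
    pred = model(data)
    loss = criterion(pred, label)
    loss.backward()
    # Optionally mask out gradients of selected parameters (the gradient_mask argument)
    if gradient_mask_tensor is not None:
        for name, param in model.named_parameters():
            if name in gradient_mask_tensor and param.grad is not None:
                param.grad *= gradient_mask_tensor[name]
    # With gradient accumulation, the update is skipped on accumulation steps
    if not if_accum_grad:
        optimizer.step()
        optimizer.zero_grad()
    # Top-1 / top-5 accuracy as percentages, matching the AvgMeter usage above
    with torch.no_grad():
        _, top5_idx = pred.topk(5, dim=1)
        correct = top5_idx.eq(label.view(-1, 1))
        acc = correct[:, :1].float().sum() * (100.0 / label.size(0))
        acc5 = correct.float().sum() * (100.0 / label.size(0))
    return acc, acc5, loss
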
# Code Block 4
class ResNet(nn.Module):
    def __init__(self, builder: ConvBuilder, block, num_blocks, num_classes=10):
        super(ResNet, self).__init__()
        self.bd = builder
        self.in_planes = 64
        # Calls the ACNetBuilder method from Code Block 2; used exactly like a standard square conv
        self.conv1 = builder.Conv2dBNReLU(3, 64, kernel_size=7, stride=2, padding=3)
        # builder is the ACNetBuilder; block is the residual block class (e.g. Bottleneck); see self._make_stage()
        self.stage1 = self._make_stage(block, 64, num_blocks[0], stride=1)
        self.stage2 = self._make_stage(block, 128, num_blocks[1], stride=2)
        self.stage3 = self._make_stage(block, 256, num_blocks[2], stride=2)
        self.stage4 = self._make_stage(block, 512, num_blocks[3], stride=2)
        self.linear = self.bd.Linear(512 * block.expansion, num_classes)

    def _make_stage(self, block, planes, num_blocks, stride):
        # e.g. stage2 with stride=2 and num_blocks=4 gives strides = [2, 1, 1, 1] (illustrated after this code block)
        strides = [stride] + [1] * (num_blocks - 1)
        blocks = []
        for stride in strides:
            # e.g. for stage1, block=Bottleneck
            blocks.append(block(builder=self.bd, in_planes=self.in_planes, planes=planes, 
                         stride=stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*blocks)

    def forward(self, x):
        out = self.conv1(x)
        out = self.bd.max_pool2d(out, kernel_size=3, stride=2, padding=1)
        out = self.stage1(out)
        out = self.stage2(out)
        out = self.stage3(out)
        out = self.stage4(out)
        out = self.bd.avg_pool2d(out, 7, 1, 0)
        out = self.bd.flatten(out)
        out = self.linear(out)
        return out
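
One detail worth spelling out in _make_stage: only the first block of a stage downsamples; the stage's stride is expanded into a per-block stride list. A tiny runnable illustration:

# How _make_stage spreads a stage's stride over its blocks
def stage_strides(stride, num_blocks):
    return [stride] + [1] * (num_blocks - 1)

print(stage_strides(1, 3))  # stage1 -> [1, 1, 1]: no downsampling
print(stage_strides(2, 4))  # stage2 -> [2, 1, 1, 1]: downsample once, in the first block

Note also that because the network only ever talks to the builder, passing an ACNetBuilder silently swaps every eligible KxK conv for an ACBlock, while a plain ConvBuilder reproduces the ordinary baseline network from the same ResNet definition.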

A closing word from Zhoulang: doing my first code walkthrough showed me how much I still have to learn, and I will keep working to improve.
