pytorch實用筆記系列1

打印訓練輸出

from termcolor import cprint
def print_train_log(info_list):
    """Print one bold-yellow training-progress line.

    :param info_list: sequence unpacked into the template, in order:
        [tag, epoch, total_epochs, batch, total_batches, percent, loss].
    """
    template = '{} Epoch: {}/{} [{}/{} ({:.0f}%)] \tLoss: {:.6f}'
    cprint(template.format(*info_list), 'yellow', attrs=['bold'])

def print_info(info, _type=None):
    """Print *info*, optionally colored via termcolor.

    :param info: a string, or a list of strings printed one per line.
    :param _type: None for a plain print, or a two-element sequence
        [color, attribute] forwarded to cprint, e.g. ['yellow', 'bold'].
    """
    # The original docstring was tab-indented while the body used spaces,
    # which raises TabError in Python 3; indentation normalized to spaces.
    if _type is not None:
        if isinstance(info, str):
            cprint(info, _type[0], attrs=[_type[1]])
        elif isinstance(info, list):
            # iterate the items themselves; `range(info)` on a list
            # raised TypeError in the original
            for item in info:
                cprint(item, _type[0], attrs=[_type[1]])
    else:
        print(info)

調用方式:

info_list=['Train peleeNet', epoch, epochs, i_batch, len(train_loader), 
           100. * i_batch / len(train_loader), loss.item()]
print_train_log(info_list)

or

print_info(imgs_result_path, ['yellow', 'bold'])

設置日誌記錄

import argparse
import logging
import shutil
from logging import handlers

from colorlog import ColoredFormatter

def setup_logger(log_filename, level=logging.DEBUG, when='midnight', back_count=0):
    """Create a logger writing colored records to console and a rotating file.

    :param log_filename: log file path; also used as the logger name.
    :param level: level applied to the logger and both handlers.
    :param when: rotation interval of the file handler:
        S: seconds
        M: minutes
        H: hours
        D: days
        W: every week (interval==0 means Monday)
        midnight: daily at midnight
    :param back_count: number of backup files to keep; older backups are
        deleted automatically.
    :return: the configured logging.Logger instance.
    """
    logger = logging.getLogger(log_filename)
    logger.setLevel(level)
    # calling this twice with the same name used to stack duplicate
    # handlers (every record printed N times); reuse the configured logger
    if logger.handlers:
        return logger
    formatter = ColoredFormatter(
        "%(asctime)s %(log_color)s%(levelname)s %(reset)s %(filename)s[%(lineno)d]: %(message)s",
        datefmt='%Y-%m-%d %H:%M:%S', reset=True,
        log_colors={'DEBUG': 'blue', 'INFO': 'green', 'WARNING': 'yellow', 'ERROR': 'red', 'CRITICAL': 'red'})
    # console handler
    ch = logging.StreamHandler()
    ch.setLevel(level)
    ch.setFormatter(formatter)
    # rotating file handler; the original `os.path.join(log_filename)`
    # with a single argument was a no-op and has been removed
    fh = handlers.TimedRotatingFileHandler(
        filename=log_filename, when=when, backupCount=back_count, encoding='utf-8')
    fh.setLevel(level)
    fh.setFormatter(formatter)
    logger.addHandler(fh)
    logger.addHandler(ch)

    return logger

python運行傳參設置

def parse_args():
    """Parse command-line options for Pelee training.

    Note: `parse_args(args=[])` makes the parser ignore sys.argv, so the
    defaults always apply (useful inside a notebook); pass real argv to
    `parser.parse_args` if CLI control is wanted.

    :return: argparse.Namespace holding the options.
    """
    def _str2bool(value):
        # argparse's `type=bool` treats any non-empty string (even
        # 'False') as True; parse the text explicitly instead.
        if isinstance(value, bool):
            return value
        return value.lower() in ('yes', 'true', 't', '1')

    parser = argparse.ArgumentParser(description='Pelee Training')
    parser.add_argument('-c', '--config', default='configs/Pelee_VOC.py')
    parser.add_argument('-d', '--dataset', default='VOC', help='VOC or COCO dataset')
    parser.add_argument('--ngpu', default=3, type=int, help='gpus')
    parser.add_argument('--resume_net', default=None, help='resume net for retraining')
    parser.add_argument('--resume_epoch', default=0, type=int, help='resume iter for retraining')
    parser.add_argument('-t', '--tensorboard', type=_str2bool, default=False,
                        help='Use tensorborad to show the Loss Graph')
    args = parser.parse_args(args=[])
    return args

tensor2numpy

def to_numpy(tensor):
    """Convert a torch tensor to a numpy array on the CPU, detaching it
    from the autograd graph first when gradients are tracked."""
    if tensor.requires_grad:
        tensor = tensor.detach()
    return tensor.cpu().numpy()

train

def train(model, epoch, criterion, optimizer, data_loader):
    """Run one training epoch over *data_loader*.

    :param model: network to train (moved to GPU when `cuda_gpu` is set).
    :param epoch: current epoch index, used only for logging.
    :param criterion: loss function.
    :param optimizer: optimizer stepping the model parameters.
    :param data_loader: iterable of (data, target) batches.

    NOTE(review): relies on module-level globals `cuda_gpu` and `epochs`,
    on `Variable` from torch.autograd, and on `print_train_log` defined
    earlier in this file — confirm they are in scope at call time.
    """
    model.train()
    for batch_idx, (data, target) in enumerate(data_loader):
        if cuda_gpu:
            data, target = data.cuda(), target.cuda()
            model.cuda()
        data, target = Variable(data), Variable(target)
        output = model(data)

        optimizer.zero_grad()
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()

        # the last three lines were tab-indented against the
        # space-indented body above, which raises TabError; normalized
        info_list = ['Train VGG NET', epoch, epochs, batch_idx, len(data_loader),
                     100. * batch_idx / len(data_loader), loss.item()]
        print_train_log(info_list)

val

def val(model, epoch, criterion, data_loader):
    """Evaluate *model* on a validation set.

    :param model: network to evaluate.
    :param epoch: unused here; kept for interface symmetry with train().
    :param criterion: loss function.
    :param data_loader: iterable of (data, target) batches.
    :return: (accuracy, average loss) tuple.

    NOTE(review): relies on the module-level global `cuda_gpu` and on
    `Variable` from torch.autograd — confirm they are in scope.
    """
    model.eval()
    val_loss = 0
    correct = 0
    for data, target in data_loader:
        if cuda_gpu:
            data, target = data.cuda(), target.cuda()
            model.cuda()
        data, target = Variable(data), Variable(target)
        output = model(data)
        # `.data[0]` was removed after PyTorch 0.4; `.item()` is the
        # supported way to extract a scalar loss
        val_loss += criterion(output, target).item()

        pred = output.data.max(1)[1]  # index of the max log-probability
        correct += pred.eq(target.data).cpu().sum().item()

    val_loss /= len(data_loader)  # criterion already averages over batch size
    acc = correct / len(data_loader.dataset)
    # the original printed len(data_loader) (number of batches) as the
    # denominator; the dataset size matches `correct` and the percentage
    print('Val set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        val_loss, correct, len(data_loader.dataset), 100. * acc))

    return (acc, val_loss)

save_checkpoint

import os, sys
def save_checkpoint(state, is_best, checkpoint_name):
    """Serialize *state* to *checkpoint_name*; when *is_best*, also copy
    it to 'best_<filename>' in the same directory.

    :param state: object to serialize (typically a dict of state_dicts).
    :param is_best: also write the 'best_' copy when True.
    :param checkpoint_name: target path for the checkpoint file.
    """
    import shutil  # the original used shutil without importing it (NameError)

    model_dir = os.path.dirname(checkpoint_name)
    model_fn = os.path.basename(checkpoint_name)
    # make dir if needed; exist_ok avoids the check-then-create race
    if model_dir != '':
        os.makedirs(model_dir, exist_ok=True)

    torch.save(state, checkpoint_name)
    if is_best:
        shutil.copyfile(checkpoint_name, os.path.join(model_dir, 'best_' + model_fn))

調用方式:

save_checkpoint({
            'epoch': epoch + 1,
            'state_dict': model.state_dict(),
            'best_pred': best_pred,
            'optimizer': optimizer.state_dict(),
        }, is_best, checkpoint_name)

torchvision

from torchvision import transforms, utils
from torchvision import datasets
import torch
import matplotlib.pyplot as plt 
%matplotlib inline
 
 
# Demo: load an image folder with torchvision transforms and preview the
# first batch as a grid.
my_trans = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor()
])
train_data = datasets.ImageFolder('./data/torchvision_data', transform=my_trans)
# the original used `data.DataLoader`, but no `data` module was imported;
# use the fully-qualified torch.utils.data.DataLoader instead
train_loader = torch.utils.data.DataLoader(train_data, batch_size=8, shuffle=True)

for i_batch, img in enumerate(train_loader):
    if i_batch == 0:
        print(img[1])  # labels of the first batch
        fig = plt.figure()
        grid = utils.make_grid(img[0])
        # make_grid returns CHW; matplotlib expects HWC
        plt.imshow(grid.numpy().transpose((1, 2, 0)))
        plt.show()
        utils.save_image(grid, 'test01.png')
    break  # preview the first batch only

配置文件解析----parse_config.py

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
import os, sys

from argparse import ArgumentParser
from collections import Iterable
from importlib import import_module

from addict import Dict

class ConfigDict(Dict):
    """An addict.Dict that raises instead of silently creating missing
    entries: KeyError on item access, AttributeError on attribute access.
    """

    def __missing__(self, name):
        raise KeyError(name)

    def __getattr__(self, name):
        try:
            return super(ConfigDict, self).__getattr__(name)
        except KeyError:
            raise AttributeError("'{}' object has no attribute '{}'".format(
                self.__class__.__name__, name))


def add_args(parser, cfg, prefix=''):
    """Recursively register one argparse option per config entry.

    :param parser: argparse.ArgumentParser to extend.
    :param cfg: dict mapping config keys to values; nested dicts become
        dotted option names (e.g. --outer.inner).
    :param prefix: dotted prefix accumulated for nested keys.
    :return: the same parser, for chaining.

    NOTE(review): `Iterable` is imported from `collections` at the top of
    this snippet; on Python 3.10+ it must come from `collections.abc`.
    """
    for k, v in cfg.items():
        if isinstance(v, str):
            parser.add_argument('--' + prefix + k)
        # bool must be tested before int: bool is a subclass of int, so
        # the original int-first ordering made this branch unreachable
        elif isinstance(v, bool):
            parser.add_argument('--' + prefix + k, action='store_true')
        elif isinstance(v, int):
            parser.add_argument('--' + prefix + k, type=int)
        elif isinstance(v, float):
            parser.add_argument('--' + prefix + k, type=float)
        elif isinstance(v, dict):
            # keep the accumulated prefix; the original passed only
            # `k + '.'`, flattening doubly-nested keys incorrectly
            add_args(parser, v, prefix + k + '.')
        elif isinstance(v, Iterable):
            parser.add_argument('--' + prefix + k, type=type(v[0]), nargs='+')
        else:
            print('cannot parse key {} of type {}'.format(prefix + k, type(v)))

    return parser


def check_file_exist(filename, msg_tmpl='file "{}" does not exist'):
    """Raise IOError when *filename* is not an existing regular file."""
    if os.path.isfile(filename):
        return
    raise IOError(msg_tmpl.format(filename))

        
class Config(object):
    """A facility for config and config files.
    It supports common file formats as configs: python/json/yaml. The interface
    is the same as a dict object and also allows access config values as
    attributes.
    Example:
        >>> cfg = Config(dict(a=1, b=dict(b1=[0, 1])))
        >>> cfg.a
        1
        >>> cfg.b
        {'b1': [0, 1]}
        >>> cfg.b.b1
        [0, 1]
        >>> cfg = Config.fromfile('tests/data/config/a.py')
        >>> cfg.filename
        "/home/kchen/projects/mmcv/tests/data/config/a.py"
        >>> cfg.item4
        'test'
        >>> cfg
        "Config [path: /home/kchen/projects/mmcv/tests/data/config/a.py]: "
        "{'item1': [1, 2], 'item2': {'a': 0}, 'item3': True, 'item4': 'test'}"
    """
    @staticmethod
    def fromfile(filename):
        """Build a Config from a .py, .yaml or .json file.

        :param filename: path of the config file.
        :return: Config instance.
        :raises IOError: if the file is missing or of an unsupported type.
        """
        filename = os.path.abspath(os.path.expanduser(filename))
        check_file_exist(filename)
        if filename.endswith('.py'):
            module_name = os.path.basename(filename)[:-3]
            if '.' in module_name:
                raise ValueError('Dots are not allowed in config file path.')

            # temporarily put the config's directory on sys.path so that
            # the config module can be imported by name
            config_dir = os.path.dirname(filename)
            sys.path.insert(0, config_dir)
            mod = import_module(module_name)
            sys.path.pop(0)
            # keep only the user-defined top-level names
            cfg_dict = {
                name: value
                for name, value in mod.__dict__.items()
                if not name.startswith('__')
            }
        elif filename.endswith(('.yaml', '.json')):
            import mmcv
            cfg_dict = mmcv.load(filename)
        else:
            raise IOError('Only py/yaml/json type are supported now!')

        return Config(cfg_dict, filename=filename)

    @staticmethod
    def auto_argparser(description=None):
        """Generate argparser from config file automatically (experimental)
        """
        partial_parser = ArgumentParser(description=description)
        partial_parser.add_argument('config', help='config file path')
        cfg_file = partial_parser.parse_known_args()[0].config
        # the original called Config.from_file, which does not exist;
        # the static method is named fromfile
        cfg = Config.fromfile(cfg_file)
        parser = ArgumentParser(description=description)
        parser.add_argument('config', help='config file path')
        add_args(parser, cfg)
        return parser, cfg

    def __init__(self, cfg_dict=None, filename=None):
        """
        :param cfg_dict: initial config values (defaults to an empty dict).
        :param filename: source file; its raw text is kept in self.text.
        :raises TypeError: if cfg_dict is given but is not a dict.
        """
        if cfg_dict is None:
            cfg_dict = dict()
        elif not isinstance(cfg_dict, dict):
            raise TypeError('cfg_dict must be a dict, but got {}'.format(
                type(cfg_dict)))

        # bypass our own __setattr__ (which writes into _cfg_dict) so
        # these attributes land on the instance itself
        super(Config, self).__setattr__('_cfg_dict', ConfigDict(cfg_dict))
        super(Config, self).__setattr__('_filename', filename)
        if filename:
            with open(filename, 'r') as f:
                super(Config, self).__setattr__('_text', f.read())
        else:
            super(Config, self).__setattr__('_text', '')

    @property
    def filename(self):
        # path of the file this config was loaded from (or None)
        return self._filename

    @property
    def text(self):
        # raw text of the config file ('' when built from a dict)
        return self._text

    def __repr__(self):
        return 'Config (path: {}): {}'.format(self.filename,
                                              self._cfg_dict.__repr__())

    def __len__(self):
        return len(self._cfg_dict)

    def __getattr__(self, name):
        # delegate attribute reads to the wrapped ConfigDict
        return getattr(self._cfg_dict, name)

    def __getitem__(self, name):
        return self._cfg_dict.__getitem__(name)

    def __setattr__(self, name, value):
        if isinstance(value, dict):
            value = ConfigDict(value)
        self._cfg_dict.__setattr__(name, value)

    def __setitem__(self, name, value):
        if isinstance(value, dict):
            value = ConfigDict(value)
        self._cfg_dict.__setitem__(name, value)

    def __iter__(self):
        return iter(self._cfg_dict)

Finetune權值初始化

# Finetune weight initialization: load a pretrained state dict and copy
# over only the parameters whose names also exist in the new model.
torch.save(pre_net.state_dict(), 'pretrained.pkl')  # save a pretrained model
pretrained_dict = torch.load('pretrained.pkl')      # load its parameters
model = Net()  # create a new model
net_state_dict = model.state_dict()  # get new model param
# drop pretrained entries the new model does not have
pretrained_dict_new = {k: v for k, v in pretrained_dict.items() if k in net_state_dict}
net_state_dict.update(pretrained_dict_new)
# the original called `net.load_state_dict`, but `net` is undefined here;
# the freshly created model is named `model`
model.load_state_dict(net_state_dict)
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章