打印訓練輸出
from termcolor import cprint
def print_train_log(info_list):
    """
    Print one training-progress line in bold yellow.

    info_list is unpacked, in order, into: a title, the current epoch, the
    total number of epochs, the current batch index, the number of batches,
    the percentage complete and the loss value.
    """
    template = '{} Epoch: {}/{} [{}/{} ({:.0f}%)] \tLoss: {:.6f}'
    message = template.format(*info_list)
    cprint(message, 'yellow', attrs=['bold'])
def print_info(info, _type=None):
    """
    Print *info*, optionally coloured with termcolor.

    :param info: a string, a list/tuple of items (each printed on its own
        line), or any other printable object.
    :param _type: None for a plain ``print``; otherwise a two-item sequence
        ``[color, attribute]`` forwarded to ``cprint``, e.g. ``['yellow', 'bold']``.
    """
    if _type is None:
        print(info)
        return
    color, attr = _type[0], _type[1]
    if isinstance(info, (list, tuple)):
        # Iterate the items themselves; range() on a list raises TypeError.
        for item in info:
            cprint(item, color, attrs=[attr])
    else:
        # Strings and any other value are printed as one line instead of
        # being silently dropped.
        cprint(info, color, attrs=[attr])
調用方式:
# Usage example: report progress of batch `i_batch` in epoch `epoch`.
progress = 100. * i_batch / len(train_loader)
info_list = ['Train peleeNet', epoch, epochs, i_batch, len(train_loader),
             progress, loss.item()]
print_train_log(info_list)
or
# Usage example: print a value in bold yellow.
print_info(imgs_result_path, _type=['yellow', 'bold'])
設置日誌記錄
from logging import handlers
from colorlog import ColoredFormatter
def setup_logger(log_filename, level=None, when='midnight', back_count=0):
    """
    :brief Create (or fetch) a logger that writes to the console in colour and
           to a time-rotated, plain-text log file.
    :param log_filename: path of the log file; also used as the logger name
    :param level: level for the logger and both handlers; None (the default)
                  means logging.DEBUG
    :param when: rollover interval unit for TimedRotatingFileHandler:
        S: seconds
        M: minutes
        H: hours
        D: days
        W0-W6: weekly on the given weekday (W0 = Monday)
        midnight: roll over at midnight
    :param back_count: number of rotated backup files to keep; older ones are
        deleted automatically (0 keeps all of them)
    :return: logger
    """
    # Imported locally: the module only does `from logging import handlers`,
    # which does not bind the name `logging`.
    import logging
    import logging.handlers

    if level is None:
        level = logging.DEBUG

    logger = logging.getLogger(log_filename)
    logger.setLevel(level)
    # getLogger returns the same object for the same name; attaching handlers
    # again would emit every record twice.
    if logger.handlers:
        return logger

    datefmt = '%Y-%m-%d %H:%M:%S'
    # Coloured output for the console.
    console_formatter = ColoredFormatter(
        "%(asctime)s %(log_color)s%(levelname)s %(reset)s %(filename)s[%(lineno)d]: %(message)s",
        datefmt=datefmt, reset=True,
        log_colors={'DEBUG': 'blue', 'INFO': 'green', 'WARNING': 'yellow', 'ERROR': 'red', 'CRITICAL': 'red'})
    # Plain formatter for the file so it is not filled with ANSI colour codes.
    file_formatter = logging.Formatter(
        "%(asctime)s %(levelname)s %(filename)s[%(lineno)d]: %(message)s",
        datefmt=datefmt)

    # output to console
    console_handler = logging.StreamHandler()
    console_handler.setLevel(level)
    console_handler.setFormatter(console_formatter)

    # output to a file that is rotated according to `when`
    file_handler = logging.handlers.TimedRotatingFileHandler(
        filename=log_filename, when=when, backupCount=back_count, encoding='utf-8')
    file_handler.setLevel(level)
    file_handler.setFormatter(file_formatter)

    logger.addHandler(file_handler)
    logger.addHandler(console_handler)
    return logger
python運行傳參設置
def parse_args(argv=None):
    '''
    Build the Pelee training argument parser and parse arguments.

    :param argv: list of argument strings to parse. None (the default) parses
        an empty list, so every option keeps its default value. Presumably this
        is so the function works inside a notebook, where sys.argv holds
        kernel arguments.
    :return: argparse.Namespace with config, dataset, ngpu, resume_net,
        resume_epoch and tensorboard.
    '''
    import argparse

    parser = argparse.ArgumentParser(description='Pelee Training')
    parser.add_argument('-c', '--config', default='configs/Pelee_VOC.py')
    parser.add_argument('-d', '--dataset', default='VOC', help='VOC or COCO dataset')
    parser.add_argument('--ngpu', default=3, type=int, help='gpus')
    parser.add_argument('--resume_net', default=None, help='resume net for retraining')
    parser.add_argument('--resume_epoch', default=0, type=int, help='resume iter for retraining')
    # type=bool would turn any non-empty string (including "False") into True;
    # a store_true flag is the reliable way to express an on/off switch.
    parser.add_argument('-t', '--tensorboard', action='store_true', default=False,
                        help='Use tensorboard to show the Loss Graph')
    return parser.parse_args(args=[] if argv is None else list(argv))
tensor2numpy
def to_numpy(tensor):
    """Return a host-side NumPy copy/view of *tensor*, detaching it first if it tracks gradients."""
    source = tensor.detach() if tensor.requires_grad else tensor
    return source.cpu().numpy()
train
def train(model, epoch, criterion, optimizer, data_loader):
    """
    Run one training epoch and print a progress line after every batch.

    Relies on module-level globals defined elsewhere: ``cuda_gpu`` (truthy to
    move the model and batches to the GPU) and ``epochs`` (total epoch count,
    used only in the progress message).

    :param model: network to train; switched to train mode here.
    :param epoch: current epoch number, used for logging.
    :param criterion: loss function called as ``criterion(output, target)``.
    :param optimizer: optimizer stepped once per batch.
    :param data_loader: iterable of ``(data, target)`` batches supporting ``len()``.
    """
    model.train()
    if cuda_gpu:
        # Move the model once instead of on every batch.
        model.cuda()
    num_batches = len(data_loader)
    for batch_idx, (data, target) in enumerate(data_loader):
        if cuda_gpu:
            data, target = data.cuda(), target.cuda()
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        info_list = ['Train VGG NET', epoch, epochs, batch_idx, num_batches,
                     100. * batch_idx / num_batches, loss.item()]
        print_train_log(info_list)
val
def val(model, epoch, criterion, data_loader):
    """
    Evaluate *model* on *data_loader*, print a summary line and return metrics.

    Relies on the module-level flag ``cuda_gpu`` (defined elsewhere) to decide
    whether the model and each batch are moved to the GPU.

    :param model: classification network; switched to eval mode here.
    :param epoch: current epoch (unused; kept for call-site compatibility).
    :param criterion: loss function called as ``criterion(output, target)``;
        assumed to return a per-batch mean loss -- TODO confirm.
    :param data_loader: iterable of ``(data, target)`` batches exposing
        ``len()`` (number of batches) and ``.dataset`` (sized sample set).
    :return: ``(acc, val_loss)`` as Python floats: the fraction of samples whose
        arg-max prediction equals the target, and the mean per-batch loss.
    """
    import torch

    model.eval()
    if cuda_gpu:
        model.cuda()
    val_loss = 0.0
    correct = 0
    # Evaluation needs no gradients; skipping autograd bookkeeping saves memory.
    with torch.no_grad():
        for data, target in data_loader:
            if cuda_gpu:
                data, target = data.cuda(), target.cuda()
            output = model(data)
            # .item() works on the 0-dim loss tensor; indexing it with [0] raises.
            val_loss += criterion(output, target).item()
            pred = output.max(1)[1]  # index of the max log-probability
            correct += int(pred.eq(target).sum().item())
    num_batches = len(data_loader)
    num_samples = len(data_loader.dataset)
    # Guard against empty loaders instead of dividing by zero.
    val_loss = val_loss / num_batches if num_batches else 0.0
    acc = float(correct) / num_samples if num_samples else 0.0
    # Report correct predictions out of the number of samples, not batches.
    print('Val set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        val_loss, correct, num_samples, 100. * acc))
    return (acc, val_loss)
save_checkpoint
import os, sys
def save_checkpoint(state, is_best, checkpoint_name):
    """
    Save *state* with ``torch.save``; if *is_best*, also copy the file to
    ``best_<file name>`` in the same directory.

    :param state: object to serialise, e.g. a dict of epoch, model and
        optimizer state dicts.
    :param is_best: when true, keep an extra copy prefixed with ``best_``.
    :param checkpoint_name: destination path; missing parent directories are created.
    """
    # shutil is not imported at module level, so bring it in here.
    import shutil

    model_dir = os.path.dirname(checkpoint_name)
    model_fn = os.path.basename(checkpoint_name)
    # An empty dirname means the current directory, which needs no creation;
    # exist_ok avoids a race between an existence check and makedirs.
    if model_dir:
        os.makedirs(model_dir, exist_ok=True)
    torch.save(state, checkpoint_name)
    if is_best:
        shutil.copyfile(checkpoint_name, os.path.join(model_dir, 'best_' + model_fn))
調用方式:
# Usage example: persist training state after an epoch.
checkpoint_state = {
    'epoch': epoch + 1,
    'state_dict': model.state_dict(),
    'best_pred': best_pred,
    'optimizer': optimizer.state_dict(),
}
save_checkpoint(checkpoint_state, is_best, checkpoint_name)
torchvision
from torchvision import transforms, utils
from torchvision import datasets
import torch
import matplotlib.pyplot as plt
%matplotlib inline
# Augmentation pipeline: random resized crop (224), random horizontal flip,
# then conversion to a tensor.
my_trans = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
])
train_data = datasets.ImageFolder('./data/torchvision_data', transform=my_trans)
# `data` is never imported; reach DataLoader through the imported torch package.
train_loader = torch.utils.data.DataLoader(train_data, batch_size=8, shuffle=True)

# Preview only the first batch: take one (images, labels) pair directly.
images, labels = next(iter(train_loader))
print(labels)
plt.figure()
grid = utils.make_grid(images)
# Reorder the grid's (C, H, W) layout to (H, W, C) for imshow.
plt.imshow(grid.numpy().transpose((1, 2, 0)))
plt.show()
utils.save_image(grid, 'test01.png')
配置文件解析----parse_config.py
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

import os
import sys
from argparse import ArgumentParser
from importlib import import_module

try:
    # Python 3.3+: the ABCs live in collections.abc (removed from collections in 3.10).
    from collections.abc import Iterable
except ImportError:  # Python 2
    from collections import Iterable

from addict import Dict
class ConfigDict(Dict):
    '''
    addict ``Dict`` subclass whose missing keys raise instead of being filled in.

    Item access on a missing key raises ``KeyError`` (via ``__missing__``).
    Attribute access on a missing key is turned into ``AttributeError`` by
    ``__getattr__`` below -- assuming addict's ``__getattr__`` looks the name up
    as an item, which reaches ``__missing__``; confirm against the installed
    addict version.
    '''
    def __missing__(self, name):
        # Raise rather than letting the base class handle the missing key.
        raise KeyError(name)
    def __getattr__(self, name):
        try:
            value = super(ConfigDict, self).__getattr__(name)
        except KeyError:
            # Translate the KeyError into the AttributeError that
            # getattr()/hasattr() callers expect.
            ex = AttributeError("'{}' object has no attribute '{}'".format(
                self.__class__.__name__, name))
        except Exception as e:
            ex = e
        else:
            return value
        # Raised after the try statement, outside any except clause, so the
        # exception is not chained to the original KeyError.
        raise ex
def add_args(parser, cfg, prefix=''):
    '''
    Register one ``--<prefix><key>`` option on *parser* per entry of *cfg*.

    The option type follows the config value:
      - str: plain string option
      - bool: ``store_true`` flag
      - int / float: typed option
      - dict: recursed into with a dotted prefix, e.g. ``--model.backbone.depth``
      - other iterable: ``nargs='+'`` typed after its first element
        (untyped if the iterable is empty)
    Any other value is reported and skipped.

    :param parser: argparse.ArgumentParser to extend.
    :param cfg: mapping of config keys to default values.
    :param prefix: string prepended to every option name (used for nesting).
    :return: the same parser, for chaining.
    '''
    for k, v in cfg.items():
        name = '--' + prefix + k
        if isinstance(v, str):
            parser.add_argument(name)
        # bool must be tested before int: bool is a subclass of int, so the
        # int branch would otherwise swallow every flag.
        elif isinstance(v, bool):
            parser.add_argument(name, action='store_true')
        elif isinstance(v, int):
            parser.add_argument(name, type=int)
        elif isinstance(v, float):
            parser.add_argument(name, type=float)
        elif isinstance(v, dict):
            # Carry the full prefix so nested keys stay unique (a.b.c, not b.c).
            add_args(parser, v, prefix + k + '.')
        elif isinstance(v, Iterable):
            items = list(v)
            if items:
                parser.add_argument(name, type=type(items[0]), nargs='+')
            else:
                # No element to infer a type from; keep values as strings.
                parser.add_argument(name, nargs='+')
        else:
            print('cannot parse key {} of type {}'.format(prefix + k, type(v)))
    return parser
def check_file_exist(filename, msg_tmpl='file "{}" does not exist'):
    """Raise IOError, built from *msg_tmpl*, unless *filename* is an existing regular file."""
    if os.path.isfile(filename):
        return
    raise IOError(msg_tmpl.format(filename))
class Config(object):
    """A facility for config and config files.

    It supports common file formats as configs: python/json/yaml. The interface
    is the same as a dict object and also allows access config values as
    attributes.

    Example:
        >>> cfg = Config(dict(a=1, b=dict(b1=[0, 1])))
        >>> cfg.a
        1
        >>> cfg.b
        {'b1': [0, 1]}
        >>> cfg.b.b1
        [0, 1]
        >>> cfg = Config.fromfile('tests/data/config/a.py')
        >>> cfg.filename
        "/home/kchen/projects/mmcv/tests/data/config/a.py"
        >>> cfg.item4
        'test'
        >>> cfg
        "Config (path: /home/kchen/projects/mmcv/tests/data/config/a.py): "
        "{'item1': [1, 2], 'item2': {'a': 0}, 'item3': True, 'item4': 'test'}"
    """

    @staticmethod
    def fromfile(filename):
        """Load a Config from a ``.py``, ``.yaml`` or ``.json`` file.

        A ``.py`` config is imported as a module and every top-level name not
        starting with ``__`` becomes a config key; YAML/JSON files are read
        with ``mmcv.load``.

        :raises IOError: if the file does not exist or has another extension.
        :raises ValueError: if the ``.py`` file's module name contains a dot.
        """
        filename = os.path.abspath(os.path.expanduser(filename))
        check_file_exist(filename)
        if filename.endswith('.py'):
            module_name = os.path.basename(filename)[:-3]
            if '.' in module_name:
                raise ValueError('Dots are not allowed in config file path.')
            config_dir = os.path.dirname(filename)
            # Remember whether this name was already imported, so only a module
            # loaded by this call is evicted below.
            preloaded = module_name in sys.modules
            sys.path.insert(0, config_dir)
            try:
                mod = import_module(module_name)
            finally:
                # Restore sys.path even if the config fails to import.
                sys.path.pop(0)
            if not preloaded:
                # Evict the config module so a later config with the same base
                # name (in another directory) is imported anew, not served from cache.
                sys.modules.pop(module_name, None)
            cfg_dict = {
                name: value
                for name, value in mod.__dict__.items()
                if not name.startswith('__')
            }
        elif filename.endswith(('.yaml', '.json')):
            import mmcv
            cfg_dict = mmcv.load(filename)
        else:
            raise IOError('Only py/yaml/json type are supported now!')
        return Config(cfg_dict, filename=filename)

    @staticmethod
    def auto_argparser(description=None):
        """Generate an argparser from a config file automatically (experimental).

        The config path is read first as the positional ``config`` argument;
        each key of that config then becomes a ``--key`` option via ``add_args``.

        :return: (parser, cfg)
        """
        partial_parser = ArgumentParser(description=description)
        partial_parser.add_argument('config', help='config file path')
        cfg_file = partial_parser.parse_known_args()[0].config
        cfg = Config.fromfile(cfg_file)
        parser = ArgumentParser(description=description)
        parser.add_argument('config', help='config file path')
        add_args(parser, cfg)
        return parser, cfg

    def __init__(self, cfg_dict=None, filename=None):
        """
        :param cfg_dict: plain dict of config values (None means empty).
        :param filename: source file; its raw text is exposed as ``text``.
        :raises TypeError: if *cfg_dict* is neither None nor a dict.
        """
        if cfg_dict is None:
            cfg_dict = dict()
        elif not isinstance(cfg_dict, dict):
            raise TypeError('cfg_dict must be a dict, but got {}'.format(
                type(cfg_dict)))
        # object.__setattr__ bypasses this class's __setattr__, which would
        # store these into the config dict instead of on the instance.
        super(Config, self).__setattr__('_cfg_dict', ConfigDict(cfg_dict))
        super(Config, self).__setattr__('_filename', filename)
        if filename:
            with open(filename, 'r') as f:
                super(Config, self).__setattr__('_text', f.read())
        else:
            super(Config, self).__setattr__('_text', '')

    @property
    def filename(self):
        """Absolute path of the source file, or None."""
        return self._filename

    @property
    def text(self):
        """Raw text of the source file ('' when built from a dict)."""
        return self._text

    def __repr__(self):
        return 'Config (path: {}): {}'.format(self.filename,
                                              self._cfg_dict.__repr__())

    def __len__(self):
        return len(self._cfg_dict)

    def __getattr__(self, name):
        # Only called when normal lookup fails. Guard _cfg_dict itself so an
        # instance lacking it (e.g. during unpickling) raises AttributeError
        # instead of recursing forever.
        if name == '_cfg_dict':
            raise AttributeError(name)
        return getattr(self._cfg_dict, name)

    def __getitem__(self, name):
        return self._cfg_dict.__getitem__(name)

    def __setattr__(self, name, value):
        # Plain dicts are wrapped so nested values stay attribute-accessible.
        if isinstance(value, dict):
            value = ConfigDict(value)
        self._cfg_dict.__setattr__(name, value)

    def __setitem__(self, name, value):
        if isinstance(value, dict):
            value = ConfigDict(value)
        self._cfg_dict.__setitem__(name, value)

    def __iter__(self):
        return iter(self._cfg_dict)
Finetune權值初始化
# Save the weights of the pre-trained network.
torch.save(pre_net.state_dict(), 'pretrained.pkl')
# Load the saved parameters back as a name -> tensor mapping.
pretrained_dict = torch.load('pretrained.pkl')
model = Net()  # the new network to fine-tune
net_state_dict = model.state_dict()
# Keep only pretrained entries that exist in the new model with the same
# shape; a same-named layer whose shape changed (e.g. a resized classifier
# head) would otherwise make load_state_dict fail with a size mismatch.
pretrained_dict_new = {
    k: v for k, v in pretrained_dict.items()
    if k in net_state_dict and v.shape == net_state_dict[k].shape
}
net_state_dict.update(pretrained_dict_new)
# Load the merged weights into the new model.
model.load_state_dict(net_state_dict)