“華爲雲杯”2019人工智能創新應用大賽:版本二customize_service.py文件

# -*- coding: utf-8 -*-
from PIL import Image
from collections import OrderedDict

import torch
import torch.nn.functional as F
import torchvision.models as models
import torchvision.transforms as transforms
from model_service.pytorch_model_service import PTServingBaseService

import time
from metric.metrics_manager import MetricsManager
import log
logger = log.getLogger(__name__)
input_size = 456

class ImageClassificationService(PTServingBaseService):
    def __init__(self, model_name, model_path):
        self.model_name = model_name
        self.model_path = model_path

        self.model = models.__dict__['resnet50'](num_classes=54)
        self.use_cuda = False
        if torch.cuda.is_available():
            print('Using GPU for inference')
            self.use_cuda = True
            checkpoint = torch.load(self.model_path)
            self.model = torch.nn.DataParallel(self.model).cuda()
            self.model.load_state_dict(checkpoint['state_dict'])
        else:
            print('Using CPU for inference')
            checkpoint = torch.load(self.model_path, map_location='cpu')
            self.model.load_state_dict(state_dict)

        self.model.eval()
        self.normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225])

        self.transforms = transforms.Compose([
            transforms.Resize(input_size),
            transforms.CenterCrop(input_size),
            transforms.ToTensor(),
            self.normalize
        ])

        self.label_id_name_dict = \
            {
                "0": "工藝品/仿唐三彩",
                "1": "工藝品/仿宋木葉盞",
                "2": "工藝品/布貼繡",
                "3": "工藝品/景泰藍",
                "4": "工藝品/木馬勺臉譜",
                "5": "工藝品/柳編",
                "6": "工藝品/葡萄花鳥紋銀香囊",
                "7": "工藝品/西安剪紙",
                "8": "工藝品/陝歷博唐妞系列",
                "9": "景點/關中書院",
                "10": "景點/兵馬俑",
                "11": "景點/南五臺",
                "12": "景點/大興善寺",
                "13": "景點/大觀樓",
                "14": "景點/大雁塔",
                "15": "景點/小雁塔",
                "16": "景點/未央宮城牆遺址",
                "17": "景點/水陸庵壁塑",
                "18": "景點/漢長安城遺址",
                "19": "景點/西安城牆",
                "20": "景點/鐘樓",
                "21": "景點/長安華嚴寺",
                "22": "景點/阿房宮遺址",
                "23": "民俗/嗩吶",
                "24": "民俗/皮影",
                "25": "特產/臨潼火晶柿子",
                "26": "特產/山茱萸",
                "27": "特產/玉器",
                "28": "特產/閻良甜瓜",
                "29": "特產/陝北紅小豆",
                "30": "特產/高陵冬棗",
                "31": "美食/八寶玫瑰鏡糕",
                "32": "美食/涼皮",
                "33": "美食/涼魚",
                "34": "美食/德懋恭水晶餅",
                "35": "美食/攪團",
                "36": "美食/枸杞燉銀耳",
                "37": "美食/柿子餅",
                "38": "美食/漿水面",
                "39": "美食/灌湯包",
                "40": "美食/燒肘子",
                "41": "美食/石子餅",
                "42": "美食/神仙粉",
                "43": "美食/粉湯羊血",
                "44": "美食/羊肉泡饃",
                "45": "美食/肉夾饃",
                "46": "美食/蕎麪餄餎",
                "47": "美食/菠菜面",
                "48": "美食/蜂蜜涼糉子",
                "49": "美食/蜜餞張口酥餃",
                "50": "美食/西安油茶",
                "51": "美食/貴妃雞翅",
                "52": "美食/醪糟",
                "53": "美食/金線油塔"
            }

    def _preprocess(self, data):
        preprocessed_data = {}
        for k, v in data.items():
            for file_name, file_content in v.items():
                img = Image.open(file_content)
                img = self.transforms(img)
                preprocessed_data[k] = img
        return preprocessed_data

    def _inference(self, data):
        img = data["input_img"]
        img = img.unsqueeze(0)

        if self.use_cuda:
            img = img.cuda()

        with torch.no_grad():
            pred_score = self.model(img)
            pred_score = F.softmax(pred_score.data, dim=1)
            if pred_score is not None:
                pred_label = torch.argsort(pred_score[0], descending=True)[:1][0].item()
                pred_label = int(pred_label)
                result = {'result': self.label_id_name_dict[str(pred_label)]}
            else:
                result = {'result': 'predict score is None'}

        return result

    def _postprocess(self, data):
        return data

    def inference(self, data):
        """
        Wrapper function to run preprocess, inference and postprocess functions.

        Parameters
        ----------
        data : map of object
            Raw input from request.

        Returns
        -------
        list of outputs to be sent back to client.
            data to be sent back
        """
        pre_start_time = time.time()
        data = self._preprocess(data)
        infer_start_time = time.time()

        # Update preprocess latency metric
        pre_time_in_ms = (infer_start_time - pre_start_time) * 1000
        logger.info('preprocess time: ' + str(pre_time_in_ms) + 'ms')

        if self.model_name + '_LatencyPreprocess' in MetricsManager.metrics:
            MetricsManager.metrics[self.model_name + '_LatencyPreprocess'].update(pre_time_in_ms)

        data = self._inference(data)
        infer_end_time = time.time()
        infer_in_ms = (infer_end_time - infer_start_time) * 1000

        logger.info('infer time: ' + str(infer_in_ms) + 'ms')
        data = self._postprocess(data)

        # Update inference latency metric
        post_time_in_ms = (time.time() - infer_end_time) * 1000
        logger.info('postprocess time: ' + str(post_time_in_ms) + 'ms')
        if self.model_name + '_LatencyInference' in MetricsManager.metrics:
            MetricsManager.metrics[self.model_name + '_LatencyInference'].update(post_time_in_ms)

        # Update overall latency metric
        if self.model_name + '_LatencyOverall' in MetricsManager.metrics:
            MetricsManager.metrics[self.model_name + '_LatencyOverall'].update(pre_time_in_ms + post_time_in_ms)

        logger.info('latency: ' + str(pre_time_in_ms + infer_in_ms + post_time_in_ms) + 'ms')
        data['latency_time'] = pre_time_in_ms + infer_in_ms + post_time_in_ms
        return data

發佈了113 篇原創文章 · 獲贊 107 · 訪問量 23萬+
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章