Pose Estimation 1-06: HR-Net (Human Pose Estimation) - Exhaustive Source Code Analysis (2) - Data Loading and Preprocessing

The link below collects all of my posts on HR-Net (human pose estimation). If you spot any mistakes, please point them out and I will correct them right away. Readers who are interested can add me on WeChat (a944284742) to discuss the technology. And if this helps you, remember to give it a like! That is the greatest encouragement for me.
Pose Estimation 1-00: HR-Net (Human Pose Estimation) - Table of Contents - The Most Complete Walkthrough

Preface

From the previous post, we can find the following code in tools/train.py:

    # Build the training and validation Dataset objects
    train_dataset = eval('dataset.'+cfg.DATASET.DATASET)(
        cfg, cfg.DATASET.ROOT, cfg.DATASET.TRAIN_SET, True,
        transforms.Compose([
            transforms.ToTensor(),
            normalize,
        ])
    )
    valid_dataset = eval('dataset.'+cfg.DATASET.DATASET)(
        cfg, cfg.DATASET.ROOT, cfg.DATASET.TEST_SET, False,
        transforms.Compose([
            transforms.ToTensor(),
            normalize,
        ])
    )

These lines construct the Dataset objects (later wrapped in DataLoaders) for training and validation. In this post we analyze how they are implemented.
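For context, the datasets created above are then handed to standard PyTorch DataLoaders. The sketch below only illustrates that wrapping; it is not a verbatim quote of tools/train.py, and the batch size and worker count are assumptions:

    # Hypothetical sketch: wrap the dataset in a PyTorch DataLoader
    # (assumes torch is imported, as in tools/train.py).
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=32,       # illustrative value, not the repo default
        shuffle=True,        # reshuffle the samples every epoch
        num_workers=4,       # illustrative value
        pin_memory=True
    )
    # Each iteration yields (input, target, target_weight, meta), the tuple
    # returned by JointsDataset.__getitem__, which is analyzed below.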

coco.py

First, open lib/dataset/coco.py. The initialization-related parts of COCODataset are annotated below:

class COCODataset(JointsDataset):
    '''
    "keypoints": {
        0: "nose",
        1: "left_eye",
        2: "right_eye",
        3: "left_ear",
        4: "right_ear",
        5: "left_shoulder",
        6: "right_shoulder",
        7: "left_elbow",
        8: "right_elbow",
        9: "left_wrist",
        10: "right_wrist",
        11: "left_hip",
        12: "right_hip",
        13: "left_knee",
        14: "right_knee",
        15: "left_ankle",
        16: "right_ankle"
    },
	"skeleton": [
        [16,14],[14,12],[17,15],[15,13],[12,13],[6,12],[7,13], [6,7],[6,8],
        [7,9],[8,10],[9,11],[2,3],[1,2],[1,3],[2,4],[3,5],[4,6],[5,7]]
    '''
    def __init__(self, cfg, root, image_set, is_train, transform=None):
        super().__init__(cfg, root, image_set, is_train, transform)
        # NMS threshold, default 1.0
        self.nms_thre = cfg.TEST.NMS_THRE
        # image score threshold, default 0.0
        self.image_thre = cfg.TEST.IMAGE_THRE
        # whether to use soft-NMS, default False
        self.soft_nms = cfg.TEST.SOFT_NMS
        # OKS threshold
        self.oks_thre = cfg.TEST.OKS_THRE
        # keypoint visibility threshold, default 0.2
        self.in_vis_thre = cfg.TEST.IN_VIS_THRE
        # bbox file, which records detected person boxes
        self.bbox_file = cfg.TEST.COCO_BBOX_FILE
        # whether to use the ground-truth bboxes
        self.use_gt_bbox = cfg.TEST.USE_GT_BBOX
        # width and height of the model input image
        self.image_width = cfg.MODEL.IMAGE_SIZE[0]
        self.image_height = cfg.MODEL.IMAGE_SIZE[1]
        # aspect ratio (width/height) of the input image
        self.aspect_ratio = self.image_width * 1.0 / self.image_height
        # pixel standard used to normalize the person scale
        self.pixel_std = 200

        # Load the dataset from the annotation file; only person keypoint annotations are loaded here
        self.coco = COCO(self._get_ann_file_keypoint())

        # deal with class names: get the annotated categories; here there is only the 'person' class
        cats = [cat['name']
                for cat in self.coco.loadCats(self.coco.getCatIds())]

        # prepend a background class to all categories
        self.classes = ['__background__'] + cats
        logger.info('=> classes: {}'.format(self.classes))
        # total number of classes, including background
        self.num_classes = len(self.classes)
        # dict: class name -> class index
        self._class_to_ind = dict(zip(self.classes, range(self.num_classes)))
        # dict: class name -> COCO category id
        self._class_to_coco_ind = dict(zip(cats, self.coco.getCatIds()))
        # dict: COCO category id -> class index
        self._coco_ind_to_class_ind = dict(
            [
                (self._class_to_coco_ind[cls], self._class_to_ind[cls])
                for cls in self.classes[1:]
            ]
        )

        # load image file names
        # get the ids of the images in this set
        self.image_set_index = self._load_image_set_index()
        # total number of images
        self.num_images = len(self.image_set_index)
        logger.info('=> num_images: {}'.format(self.num_images))

        # number of keypoints to detect
        self.num_joints = 17
        # pairs of horizontally symmetric keypoints, used for flipping
        self.flip_pairs = [[1, 2], [3, 4], [5, 6], [7, 8],
                           [9, 10], [11, 12], [13, 14], [15, 16]]
        # parent ids (not used for COCO)
        self.parent_ids = None

        # indices of the upper-body and lower-body keypoints
        self.upper_body_ids = (0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10)
        self.lower_body_ids = (11, 12, 13, 14, 15, 16)

        # per-keypoint loss weights
        self.joints_weight = np.array(
            [
                1., 1., 1., 1., 1., 1., 1., 1.2, 1.2,
                1.5, 1.5, 1., 1., 1.2, 1.2, 1.5, 1.5
            ],
            dtype=np.float32
        ).reshape((self.num_joints, 1))


        self.db = self._get_db()

        if is_train and cfg.DATASET.SELECT_DATA:
            self.db = self.select_data(self.db)

        logger.info('=> load {} samples'.format(len(self.db)))

    def _get_ann_file_keypoint(self):
        """ self.root / annotations / person_keypoints_train2017.json """
        prefix = 'person_keypoints' \
            if 'test' not in self.image_set else 'image_info'
        return os.path.join(
            self.root,
            'annotations',
            prefix + '_' + self.image_set + '.json'
        )

    def _load_image_set_index(self):
        """ image id: int """
        image_ids = self.coco.getImgIds()
        return image_ids

    def _get_db(self):
        # if training, or if self.use_gt_bbox == True
        if self.is_train or self.use_gt_bbox:
            # use ground truth bbox
            gt_db = self._load_coco_keypoint_annotations()
        # otherwise, use boxes produced by a person detector
        else:
            # use bbox from detection
            gt_db = self._load_coco_person_detection_results()
        return gt_db

    # load the keypoint annotations of every image in the set
    def _load_coco_keypoint_annotations(self):
        """ ground truth bbox and keypoints """
        gt_db = []
        for index in self.image_set_index:
            gt_db.extend(self._load_coco_keypoint_annotation_kernal(index))
        return gt_db

    def _load_coco_keypoint_annotation_kernal(self, index):
        """
        Load the person keypoint annotations of a single image, given its index.
        coco ann: [u'segmentation', u'area', u'iscrowd', u'image_id', u'bbox', u'category_id', u'id']
        iscrowd:
            crowd instances are handled by marking their overlaps with all categories to -1
            and later excluded in training
        bbox:
            [x1, y1, w, h]
        :param index: coco image id
        :return: db entry
        """
        # load the image record
        im_ann = self.coco.loadImgs(index)[0]
        # image width and height
        width = im_ann['width']
        height = im_ann['height']
        # get the annotation ids for this image
        annIds = self.coco.getAnnIds(imgIds=index, iscrowd=False)
        # load the corresponding annotations
        objs = self.coco.loadAnns(annIds)


        # sanitize bboxes:
        # clip the boxes to the image and drop ill-formed ones
        valid_objs = []
        for obj in objs:
            x, y, w, h = obj['bbox']
            x1 = np.max((0, x))
            y1 = np.max((0, y))
            x2 = np.min((width - 1, x1 + np.max((0, w - 1))))
            y2 = np.min((height - 1, y1 + np.max((0, h - 1))))
            if obj['area'] > 0 and x2 >= x1 and y2 >= y1:
                obj['clean_bbox'] = [x1, y1, x2-x1, y2-y1]
                valid_objs.append(obj)
        objs = valid_objs

        rec = []
        for obj in objs:
            # class index of the object; person is 1, any other class is skipped
            cls = self._coco_ind_to_class_ind[obj['category_id']]
            if cls != 1:
                continue

            # ignore objs without keypoints annotation:
            # if this obj has no keypoint information, skip it as well
            if max(obj['keypoints']) == 0:
                continue

            # keypoint coordinates, stored with 3 values per joint (the 3rd is unused in 2D)
            joints_3d = np.zeros((self.num_joints, 3), dtype=np.float)
            joints_3d_vis = np.zeros((self.num_joints, 3), dtype=np.float)
            for ipt in range(self.num_joints):
                joints_3d[ipt, 0] = obj['keypoints'][ipt * 3 + 0]
                joints_3d[ipt, 1] = obj['keypoints'][ipt * 3 + 1]
                joints_3d[ipt, 2] = 0
                t_vis = obj['keypoints'][ipt * 3 + 2]
                if t_vis > 1:
                    t_vis = 1
                joints_3d_vis[ipt, 0] = t_vis
                joints_3d_vis[ipt, 1] = t_vis
                joints_3d_vis[ipt, 2] = 0

            # convert the box into a center point and a scale
            center, scale = self._box2cs(obj['clean_bbox'][:4])


            rec.append({
                'image': self.image_path_from_index(index),
                'center': center,
                'scale': scale,
                'joints_3d': joints_3d,
                'joints_3d_vis': joints_3d_vis,
                'filename': '',
                'imgnum': 0,
            })

        return rec

From the comments above, we can see that COCODataset's initialization mainly builds the rec list, which contains every person instance in COCO together with its keypoint annotations, plus the image path and the normalized center/scale of each person.
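The loop above converts each cleaned box into its (center, scale) pair via self._box2cs, which is not shown in the excerpt. The sketch below reconstructs that conversion from the conventions visible in this file (aspect_ratio = width/height, pixel_std = 200); treat the exact enlargement factor as an assumption:

    def _box2cs(self, box):
        # box is [x, y, w, h]
        x, y, w, h = box[:4]
        return self._xywh2cs(x, y, w, h)

    def _xywh2cs(self, x, y, w, h):
        # center of the box
        center = np.zeros((2), dtype=np.float32)
        center[0] = x + w * 0.5
        center[1] = y + h * 0.5
        # pad the short side so that w/h matches self.aspect_ratio
        if w > self.aspect_ratio * h:
            h = w * 1.0 / self.aspect_ratio
        elif w < self.aspect_ratio * h:
            w = h * self.aspect_ratio
        # express the box size in units of pixel_std (200) pixels
        scale = np.array(
            [w * 1.0 / self.pixel_std, h * 1.0 / self.pixel_std],
            dtype=np.float32
        )
        if center[0] != -1:
            scale = scale * 1.25   # enlargement margin (assumption)
        return center, scale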

But we are not done yet. Further processing is needed, because computing the loss requires heatmaps. In other words, based on the information in rec we still have to read the image pixels (for training) and convert the label information (the keypoint locations) into heatmaps. This is implemented in lib/dataset/JointsDataset.py.
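Before walking through the full implementation, here is a minimal, self-contained sketch of the heatmap idea, using the default configuration (192x256 input, 48x64 heatmap, sigma=2). It omits the boundary handling that the real generate_target performs:

    import numpy as np

    # A keypoint at (100, 120) in the 192x256 input maps onto the 48x64
    # heatmap with a stride of 4 in each direction.
    sigma = 2
    heatmap_w, heatmap_h = 48, 64
    feat_stride = np.array([192 / heatmap_w, 256 / heatmap_h])  # [4., 4.]
    mu_x = int(100 / feat_stride[0] + 0.5)                      # 25
    mu_y = int(120 / feat_stride[1] + 0.5)                      # 30
    # Unnormalized Gaussian centered at (mu_x, mu_y); the peak value is 1.
    xs = np.arange(heatmap_w)
    ys = np.arange(heatmap_h)[:, np.newaxis]
    heatmap = np.exp(-((xs - mu_x) ** 2 + (ys - mu_y) ** 2) / (2 * sigma ** 2))
    print(heatmap.shape, heatmap.max())                         # (64, 48) 1.0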

JointsDataset.py

# ------------------------------------------------------------------------------
# Copyright (c) Microsoft
# Licensed under the MIT License.
# Written by Bin Xiao ([email protected])
# ------------------------------------------------------------------------------

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import copy
import logging
import random

import cv2
import numpy as np
import torch
from torch.utils.data import Dataset

from utils.transforms import get_affine_transform
from utils.transforms import affine_transform
from utils.transforms import fliplr_joints


logger = logging.getLogger(__name__)


class JointsDataset(Dataset):
    def __init__(self, cfg, root, image_set, is_train, transform=None):
        # number of human joints
        self.num_joints = 0
        # pixel standard used to normalize the person scale
        self.pixel_std = 200
        # pairs of keypoints swapped by horizontal flipping
        self.flip_pairs = []
        # parent ids (not used for COCO)
        self.parent_ids = []
        # whether this dataset is used for training
        self.is_train = is_train
        # root directory of the dataset
        self.root = root
        # image set name, e.g. 'train2017'
        self.image_set = image_set
        # output directory
        self.output_path = cfg.OUTPUT_DIR
        # data format, e.g. 'jpg'
        self.data_format = cfg.DATASET.DATA_FORMAT
        # scale augmentation factor
        self.scale_factor = cfg.DATASET.SCALE_FACTOR
        # rotation augmentation factor (degrees)
        self.rotation_factor = cfg.DATASET.ROT_FACTOR
        # whether to apply random horizontal flipping
        self.flip = cfg.DATASET.FLIP
        # minimum number of visible joints required for the half-body transform, default 8
        self.num_joints_half_body = cfg.DATASET.NUM_JOINTS_HALF_BODY
        # probability of applying the half-body transform
        self.prob_half_body = cfg.DATASET.PROB_HALF_BODY
        # image color format, default rgb
        self.color_rgb = cfg.DATASET.COLOR_RGB
        # target type, default 'gaussian'
        self.target_type = cfg.MODEL.TARGET_TYPE
        # network input size, e.g. [192, 256]
        self.image_size = np.array(cfg.MODEL.IMAGE_SIZE)
        # size of the target heatmaps
        self.heatmap_size = np.array(cfg.MODEL.HEATMAP_SIZE)
        # sigma of the Gaussian, default 2
        self.sigma = cfg.MODEL.SIGMA
        # whether to weight each joint differently, default False
        self.use_different_joints_weight = cfg.LOSS.USE_DIFFERENT_JOINTS_WEIGHT
        # joint weights
        self.joints_weight = 1
        # data augmentation / transforms
        self.transform = transform
        # sample records, filled in by the subclass
        self.db = []

    # implemented by the subclass
    def _get_db(self):
        raise NotImplementedError

    # implemented by the subclass
    def evaluate(self, cfg, preds, output_dir, *args, **kwargs):
        raise NotImplementedError


    def half_body_transform(self, joints, joints_vis):
        """
        Restrict the sample to one half of the body (upper or lower).
        :param joints: keypoint locations, shape=[17,3]; the 3rd dim is always 0 in 2D
        :param joints_vis: keypoint visibility, shape=[17,3]
        :return:
        """
        # upper-body joints
        upper_joints = []
        # lower-body joints
        lower_joints = []


        for joint_id in range(self.num_joints):
            # if this keypoint is visible
            if joints_vis[joint_id][0] > 0:
                # if it belongs to the upper body
                if joint_id in self.upper_body_ids:
                    upper_joints.append(joints[joint_id])
                # otherwise it belongs to the lower body
                else:
                    lower_joints.append(joints[joint_id])

        # choose upper- or lower-body keypoints (note: np.random.randn() samples a standard normal, so this condition holds with probability ~0.69, not exactly 0.5)
        if np.random.randn() < 0.5 and len(upper_joints) > 2:
            selected_joints = upper_joints
        else:
            selected_joints = lower_joints \
                if len(lower_joints) > 2 else upper_joints

        # if fewer than 2 keypoints were selected, return None: this sample cannot use the half-body transform
        if len(selected_joints) < 2:
            return None, None

        selected_joints = np.array(selected_joints, dtype=np.float32)

        # mean x, y coordinates of the selected keypoints
        center = selected_joints.mean(axis=0)[:2]

        # top-left corner of the box covering the selected keypoints
        left_top = np.amin(selected_joints, axis=0)
        # bottom-right corner
        right_bottom = np.amax(selected_joints, axis=0)

        # minimal width and height of the covering box
        w = right_bottom[0] - left_top[0]
        h = right_bottom[1] - left_top[1]


        # enlarge w or h so that w/h equals self.aspect_ratio (e.g. 0.75)
        if w > self.aspect_ratio * h:
            h = w * 1.0 / self.aspect_ratio
        elif w < self.aspect_ratio * h:
            w = h * self.aspect_ratio

        # record w, h in units of pixel_std (200) pixels
        scale = np.array(
            [
                w * 1.0 / self.pixel_std,
                h * 1.0 / self.pixel_std
            ],
            dtype=np.float32
        )


        # enlarge by a margin so the half body stays fully inside the crop
        scale = scale * 1.5

        return center, scale


    def __len__(self,):
        return len(self.db)


    def __getitem__(self, idx):
        # fetch the sample record for this idx from db
        db_rec = copy.deepcopy(self.db[idx])
        # image file name
        image_file = db_rec['image']

        # filename and imgnum are not used for now
        filename = db_rec['filename'] if 'filename' in db_rec else ''
        imgnum = db_rec['imgnum'] if 'imgnum' in db_rec else ''

        # if the data is stored in a zip archive, read it through zipreader
        if self.data_format == 'zip':
            from utils import zipreader
            data_numpy = zipreader.imread(
                image_file, cv2.IMREAD_COLOR | cv2.IMREAD_IGNORE_ORIENTATION
            )
        # otherwise read the image file directly to get the pixel values
        else:
            data_numpy = cv2.imread(
                image_file, cv2.IMREAD_COLOR | cv2.IMREAD_IGNORE_ORIENTATION
            )

        # convert to RGB
        if self.color_rgb:
            data_numpy = cv2.cvtColor(data_numpy, cv2.COLOR_BGR2RGB)

        # if the image could not be read, raise an error
        if data_numpy is None:
            logger.error('=> fail to read {}'.format(image_file))
            raise ValueError('Fail to read {}'.format(image_file))

        # keypoint coordinates and visibility
        joints = db_rec['joints_3d']
        joints_vis = db_rec['joints_3d_vis']

        # center and scale of this sample
        c = db_rec['center']
        s = db_rec['scale']

        # if the record has no score attribute, default it to 1
        score = db_rec['score'] if 'score' in db_rec else 1
        r = 0

        # training-time augmentation
        if self.is_train:
            # if more joints are visible than num_joints_half_body and a random draw falls below self.prob_half_body (e.g. 0.3)
            if (np.sum(joints_vis[:, 0]) > self.num_joints_half_body and np.random.rand() < self.prob_half_body):
                # recompute center and scale from half of the body
                c_half_body, s_half_body = self.half_body_transform(joints, joints_vis)

                if c_half_body is not None and s_half_body is not None:
                    c, s = c_half_body, s_half_body

            # scale factor sf (e.g. 0.35) and rotation factor rf (e.g. 45 degrees)
            sf = self.scale_factor
            rf = self.rotation_factor

            # s is multiplied by a factor clipped to [1-sf, 1+sf], e.g. [0.65, 1.35]
            s = s * np.clip(np.random.randn()*sf + 1, 1 - sf, 1 + sf)

            # with probability 0.6, r is drawn and clipped to [-2*rf, 2*rf], e.g. [-90, 90]; otherwise r stays 0
            r = np.clip(np.random.randn()*rf, -rf*2, rf*2) \
                if random.random() <= 0.6 else 0

            # random horizontal flip
            if self.flip and random.random() <= 0.5:
                data_numpy = data_numpy[:, ::-1, :]
                joints, joints_vis = fliplr_joints(
                    joints, joints_vis, data_numpy.shape[1], self.flip_pairs)
                c[0] = data_numpy.shape[1] - c[0] - 1

        # apply an affine transform: after the sample is rotated and scaled,
        # every pixel moves to its corresponding position.
        # first compute the affine transformation matrix
        trans = get_affine_transform(c, s, r, self.image_size)
        # then warp the image with it
        input = cv2.warpAffine(
            data_numpy,
            trans,
            (int(self.image_size[0]), int(self.image_size[1])),
            flags=cv2.INTER_LINEAR)


        # normalization, conversion to tensor, etc.
        if self.transform:
            input = self.transform(input)

        # apply the same affine transform to the keypoints
        for i in range(self.num_joints):
            if joints_vis[i, 0] > 0.0:
                joints[i, 0:2] = affine_transform(joints[i, 0:2], trans)

        # build the ground truth: heatmaps target[17,64,48] and target_weight[17,1]
        target, target_weight = self.generate_target(joints, joints_vis)


        target = torch.from_numpy(target)
        target_weight = torch.from_numpy(target_weight)

        meta = {
            'image': image_file,
            'filename': filename,
            'imgnum': imgnum,
            'joints': joints,
            'joints_vis': joints_vis,
            'center': c,
            'scale': s,
            'rotation': r,
            'score': score
        }
        return input, target, target_weight, meta


    # Filter db, keeping only samples whose visible-keypoint centroid is
    # consistent with the bbox center (an OKS-style check).
    def select_data(self, db):
        db_selected = []
        for rec in db:
            num_vis = 0
            joints_x = 0.0
            joints_y = 0.0
            for joint, joint_vis in zip(
                    rec['joints_3d'], rec['joints_3d_vis']):
                if joint_vis[0] <= 0:
                    continue
                num_vis += 1

                joints_x += joint[0]
                joints_y += joint[1]
            if num_vis == 0:
                continue

            joints_x, joints_y = joints_x / num_vis, joints_y / num_vis

            # person area implied by the scale (pixel_std = 200)
            area = rec['scale'][0] * rec['scale'][1] * (self.pixel_std**2)
            joints_center = np.array([joints_x, joints_y])
            bbox_center = np.array(rec['center'])
            # distance between the joint centroid and the bbox center,
            # turned into a Gaussian similarity score ks
            diff_norm2 = np.linalg.norm((joints_center-bbox_center), 2)
            ks = np.exp(-1.0*(diff_norm2**2) / ((0.2)**2*2.0*area))

            # the acceptance threshold grows with the number of visible joints
            metric = (0.2 / 16) * num_vis + 0.45 - 0.2 / 16
            if ks > metric:
                db_selected.append(rec)

        logger.info('=> num db: {}'.format(len(db)))
        logger.info('=> num selected db: {}'.format(len(db_selected)))
        return db_selected

    def generate_target(self, joints, joints_vis):
        '''
        :param joints:  [num_joints, 3]
        :param joints_vis: [num_joints, 3]
        :return: target, target_weight(1: visible, 0: invisible)
        '''
        # target_weight has shape [17,1]
        target_weight = np.ones((self.num_joints, 1), dtype=np.float32)
        target_weight[:, 0] = joints_vis[:, 0]

        # only the 'gaussian' target type is supported; anything else raises
        assert self.target_type == 'gaussian', \
            'Only support gaussian map now!'

        # build the heatmaps with a Gaussian kernel
        if self.target_type == 'gaussian':
            # target has shape [17, 64, 48]
            target = np.zeros((self.num_joints,
                               self.heatmap_size[1],
                               self.heatmap_size[0]),
                              dtype=np.float32)

            # self.sigma defaults to 2, so tmp_size = 6
            tmp_size = self.sigma * 3

            # generate a heatmap (and its weight) for every keypoint
            for joint_id in range(self.num_joints):
                # stride from the input image to the output heatmap
                feat_stride = self.image_size / self.heatmap_size

                # map the keypoint from input-image coordinates to heatmap coordinates
                mu_x = int(joints[joint_id][0] / feat_stride[0] + 0.5)
                mu_y = int(joints[joint_id][1] / feat_stride[1] + 0.5)


                # Check that any part of the gaussian is in-bounds:
                # compute the top-left and bottom-right corners of the Gaussian patch
                ul = [int(mu_x - tmp_size), int(mu_y - tmp_size)]
                br = [int(mu_x + tmp_size + 1), int(mu_y + tmp_size + 1)]

                # if the patch lies entirely outside the heatmap, set this
                # joint's target_weight to 0 and skip the joint
                if ul[0] >= self.heatmap_size[0] or ul[1] >= self.heatmap_size[1] \
                        or br[0] < 0 or br[1] < 0:
                    target_weight[joint_id] = 0
                    continue


                # Generate gaussian
                # size of the Gaussian patch
                size = 2 * tmp_size + 1
                # x = [ 0.  1.  2. ... 12.]
                x = np.arange(0, size, 1, np.float32)
                # y is x as a column vector
                y = x[:, np.newaxis]
                # x0 = y0 = 6
                x0 = y0 = size // 2
                # The gaussian is not normalized; we want the center value to equal 1.
                # g has shape [13,13]; g[6,6] = 1, and values decay away from the center
                g = np.exp(- ((x - x0) ** 2 + (y - y0) ** 2) / (2 * self.sigma ** 2))


                # Usable gaussian range:
                # the part of the Gaussian patch that falls inside the heatmap
                g_x = max(0, -ul[0]), min(br[0], self.heatmap_size[0]) - ul[0]
                g_y = max(0, -ul[1]), min(br[1], self.heatmap_size[1]) - ul[1]

                # Image range:
                # the corresponding region inside the heatmap
                img_x = max(0, ul[0]), min(br[0], self.heatmap_size[0])
                img_y = max(0, ul[1]), min(br[1], self.heatmap_size[1])


                # if target_weight > 0.5 (i.e. the keypoint is visible), copy the Gaussian patch into the heatmap
                v = target_weight[joint_id]
                if v > 0.5:
                    target[joint_id][img_y[0]:img_y[1], img_x[0]:img_x[1]] = g[g_y[0]:g_y[1], g_x[0]:g_x[1]]


        # apply per-joint weights if configured
        if self.use_different_joints_weight:
            target_weight = np.multiply(target_weight, self.joints_weight)

        # img = np.transpose(target.copy(),[1,2,0])*255
        # img = img[:,:,0].astype(np.uint8)
        # img = np.expand_dims(img,axis=-1)
        # cv2.imwrite('./test.jpg', img)
        return target, target_weight
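With both classes covered, a quick smoke test shows what __getitem__ returns for the default 256x192 configuration. This is a hypothetical snippet: it assumes train_dataset was built exactly as shown at the top of this post:

    # Fetch one sample and inspect the shapes produced by __getitem__.
    input, target, target_weight, meta = train_dataset[0]
    print(input.shape)          # torch.Size([3, 256, 192])
    print(target.shape)         # torch.Size([17, 64, 48])
    print(target_weight.shape)  # torch.Size([17, 1])
    print(meta['image'])        # path of the source image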

At the end of generate_target, you can see these commented-out lines:

        # img = np.transpose(target.copy(),[1,2,0])*255
        # img = img[:,:,0].astype(np.uint8)
        # img = np.expand_dims(img,axis=-1)
        # cv2.imwrite('./test.jpg', img)

You can uncomment them and rerun the training code; the saved image looks like this:
[Figure: the saved single-keypoint heatmap]
The white blob at its center marks the keypoint location: this is the heatmap. Note, however, that these lines only save the heatmap of a single keypoint.
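If you want to see all joints at once, a small variation (a sketch; it assumes target still holds the [17, 64, 48] numpy array at that point in generate_target) is to collapse the joint channels with a per-pixel maximum before saving:

    # Hypothetical debugging snippet: merge all 17 joint heatmaps into one
    # image via a per-pixel maximum, then save it.
    img = (np.max(target, axis=0) * 255).astype(np.uint8)   # [64, 48]
    cv2.imwrite('./test_all_joints.jpg', img)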

Summary

At this point we have a fairly complete picture of the data processing pipeline. Next, we will dissect the network architecture.
