centernet下訓練自己的數據

目錄

一.數據準備

1.製作COCO數據集

2.計算數據集的均值方差

二.代碼修改

1.新建類別

 2.加入dataset

 3.修改/src/lib/opts.py

4.修改src/lib/utils/debugger.py文件

二 訓練與測試:

1訓練:

2測試:

3繪製loss曲線


參照博客:

https://blog.csdn.net/weixin_42634342/article/details/97756458#commentsedit

https://blog.csdn.net/weixin_42634342/article/details/97756458#commentsedit

一.數據準備

1.製作COCO數據集

這裏我用的是VOC數據集轉COCO

參照博客:

https://blog.csdn.net/weixin_41765699/article/details/100124689

主要trian,val,test三個文件夾下txt轉化爲json

2.計算數據集的均值方差

import cv2, os, argparse
import numpy as np
from tqdm import tqdm


def main():
    dirs = '/home/zbb/CenterNet/data/plane/images'   # 修改你自己的圖片路徑
    img_file_names = os.listdir(dirs)
    m_list, s_list = [], []
    for img_filename in tqdm(img_file_names):
        img = cv2.imread(dirs + '/' + img_filename)
        img = img / 255.0
        m, s = cv2.meanStdDev(img)
        m_list.append(m.reshape((3,)))
        s_list.append(s.reshape((3,)))
    m_array = np.array(m_list)
    s_array = np.array(s_list)
    m = m_array.mean(axis=0, keepdims=True)
    s = s_array.mean(axis=0, keepdims=True)
    print("mean = ", m[0][::-1])
    print("std = ", s[0][::-1])

if __name__ == '__main__':
    main()

二.代碼修改

1.新建類別

src/lib/datasets/dataset裏面新建一個“plane. py”,文件內容照着文件夾下coco.py改成自己的

1).把COCO關鍵字改爲Plane

2)路徑格式

使用相對路徑報錯,改成了絕對路徑

3)訓練修改

修改爲val,train,測試再修改回來

類別名字和類別id改成自己

 2.加入dataset

將數據集加入src/lib/datasets/dataset_factory裏面

一定要記得import,否則會報你的類別未定義

 3.修改/src/lib/opts.py

將自己的數據集設爲默認數據集,加入到help裏面

 修改ctdet任務使用的默認數據集爲新添加的數據集,如下(修改分辨率,類別數,均值,方差,數據集名字):

4.修改src/lib/utils/debugger.py文件

變成自己數據的類別和名字,前後數據集名字一定保持一致

再加上自己數據的類別,不包括背景__background__ 

二 訓練與測試:

1訓練:

 輸入命令:

python main.py ctdet --exp_id coco_dla --batch_size 4 --master_batch 1 --lr 1.25e-4  --gpus 0,1

如果顯示顯存不夠之類的那種錯誤,需要在opts.py文件中將--num_workers改成0,batch_size小

2測試:

  建立的plane.py中修改如下部分,加入if split == ‘test’:…,作用是當test時指定標籤文件爲之前建立的測試文件     

   運行test.py

       python test.py --exp_id coco_dla --not_prefetch_test ctdet --load_model /home/zbb/CenterNet/exp/ctdet/coco_dla/model_best.pth

結果:

其中,一般使用的是第二行,也就是IOU=0.5,全區域的AP值,其他的分別是不同IOU以及不同目標尺寸區域的結果。 

3繪製loss曲線

訓練生成的日誌文件一般在exp/ctdet/../../logs.txt

參照博主但是,val—loss繪製不好,先繪製total—loss

import matplotlib.pyplot as plt
import numpy as np


def plot_loss_curve(log_file):
    loss_data = open(log_file)
    all_lines = loss_data.readlines()
    print(all_lines[4].split(' '))
    # losses
    total_loss = []  # 4
    hm_loss = []  # 7
    wh_loss = []  # 10
    off_loss = []  # 13
    val_loss = []  # 19
    spend_time = []  # 16
    num_lines = len(all_lines)
    for line in range(num_lines):
        total_loss1 = all_lines[line].split(' ')[4]
        hm_loss1 = all_lines[line].split(' ')[7]
        wh_loss1 = all_lines[line].split(' ')[10]
        off_loss1 = all_lines[line].split(' ')[13]
        #val_loss1 = all_lines[line].split(' ')[19]
        spend_time1 = all_lines[line].split(' ')[16]
        print(total_loss1)
        print(spend_time1)

        total_loss.append(float(total_loss1))
        #val_loss.append(float(val_loss1))
        hm_loss.append(float(hm_loss1))
        wh_loss.append(float(wh_loss1))
        off_loss.append(float(off_loss1))
        spend_time.append(float(spend_time1))
    return total_loss

if __name__ == '__main__':
    # 標準圖形繪製
    # sns.set()
    loss_res18 = plot_loss_curve(
        '/home/zbb/CenterNet/exp/ctdet/coco_dla/logs_2019-10-17-15-41/log.txt')  # 讀取訓練時生成的日誌文件
    fig = plt.figure(figsize=(10, 4))
    ax = fig.add_subplot(111)
    ax.plot(range(len(loss_res18)), loss_res18, 'c', label='building', linewidth=1)  # 這個label是圖線自己的標籤;

    # ax.set_xlim([0, 800])                                      # 設置刻度;
    # ax.set_xticks(range(0, 500, 100))                          # 設置顯示的刻度;
    # ax.set_yticklabels(['jan', 'feb', 'mar'])                  # 設置刻度標籤;
    ax.set_xlabel('epochs')  # 設置座標軸標籤;
    ax.set_ylabel('loss_value')
    ax.text(8750, 20, "plane", color='red')  # 加入文本
    ax.set_title('loss_of_CenterNet')
    ax.legend(loc='best')  # 將圖例擺放在不遮擋圖線的位置即可
    ax.grid()  # 添加網格
    plt.savefig('/home/zbb/CenterNet/loss_of_CenterNet.png')  # 保存文件到指定文件夾
    plt.show()

total——loss結果圖:

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章