目錄
4.修改src/lib/utils/debugger.py文件
參照博客:
https://blog.csdn.net/weixin_42634342/article/details/97756458#commentsedit
https://blog.csdn.net/weixin_42634342/article/details/97756458#commentsedit
一.數據準備
1.製作COCO數據集
這裏我用的是VOC數據集轉COCO
參照博客:
https://blog.csdn.net/weixin_41765699/article/details/100124689
主要trian,val,test三個文件夾下txt轉化爲json
2.計算數據集的均值方差
import cv2, os, argparse
import numpy as np
from tqdm import tqdm
def main():
dirs = '/home/zbb/CenterNet/data/plane/images' # 修改你自己的圖片路徑
img_file_names = os.listdir(dirs)
m_list, s_list = [], []
for img_filename in tqdm(img_file_names):
img = cv2.imread(dirs + '/' + img_filename)
img = img / 255.0
m, s = cv2.meanStdDev(img)
m_list.append(m.reshape((3,)))
s_list.append(s.reshape((3,)))
m_array = np.array(m_list)
s_array = np.array(s_list)
m = m_array.mean(axis=0, keepdims=True)
s = s_array.mean(axis=0, keepdims=True)
print("mean = ", m[0][::-1])
print("std = ", s[0][::-1])
if __name__ == '__main__':
main()
二.代碼修改
1.新建類別
src/lib/datasets/dataset
裏面新建一個“plane. py”,文件內容照着文件夾下coco.py改成自己的
1).把COCO關鍵字改爲Plane
2)路徑格式
使用相對路徑報錯,改成了絕對路徑
3)訓練修改
修改爲val,train,測試再修改回來
類別名字和類別id改成自己
2.加入dataset
將數據集加入src/lib/datasets/dataset_factory
裏面
一定要記得import,否則會報你的類別未定義
3.修改/src/lib/opts.py
將自己的數據集設爲默認數據集,加入到help裏面
修改ctdet任務使用的默認數據集爲新添加的數據集,如下(修改分辨率,類別數,均值,方差,數據集名字):
4.修改src/lib/utils/debugger.py文件
變成自己數據的類別和名字,前後數據集名字一定保持一致
再加上自己數據的類別,不包括背景__background__
二 訓練與測試:
1訓練:
輸入命令:
python main.py ctdet --exp_id coco_dla --batch_size 4 --master_batch 1 --lr 1.25e-4 --gpus 0,1
如果顯示顯存不夠之類的那種錯誤,需要在opts.py文件中將--num_workers改成0,batch_size小
2測試:
建立的plane.py中修改如下部分,加入if split == ‘test’:…,作用是當test時指定標籤文件爲之前建立的測試文件
運行test.py
python test.py --exp_id coco_dla --not_prefetch_test ctdet --load_model /home/zbb/CenterNet/exp/ctdet/coco_dla/model_best.pth
結果:
其中,一般使用的是第二行,也就是IOU=0.5,全區域的AP值,其他的分別是不同IOU以及不同目標尺寸區域的結果。
3繪製loss曲線
訓練生成的日誌文件一般在exp/ctdet/../../logs.txt
參照博主但是,val—loss繪製不好,先繪製total—loss
import matplotlib.pyplot as plt
import numpy as np
def plot_loss_curve(log_file):
loss_data = open(log_file)
all_lines = loss_data.readlines()
print(all_lines[4].split(' '))
# losses
total_loss = [] # 4
hm_loss = [] # 7
wh_loss = [] # 10
off_loss = [] # 13
val_loss = [] # 19
spend_time = [] # 16
num_lines = len(all_lines)
for line in range(num_lines):
total_loss1 = all_lines[line].split(' ')[4]
hm_loss1 = all_lines[line].split(' ')[7]
wh_loss1 = all_lines[line].split(' ')[10]
off_loss1 = all_lines[line].split(' ')[13]
#val_loss1 = all_lines[line].split(' ')[19]
spend_time1 = all_lines[line].split(' ')[16]
print(total_loss1)
print(spend_time1)
total_loss.append(float(total_loss1))
#val_loss.append(float(val_loss1))
hm_loss.append(float(hm_loss1))
wh_loss.append(float(wh_loss1))
off_loss.append(float(off_loss1))
spend_time.append(float(spend_time1))
return total_loss
if __name__ == '__main__':
# 標準圖形繪製
# sns.set()
loss_res18 = plot_loss_curve(
'/home/zbb/CenterNet/exp/ctdet/coco_dla/logs_2019-10-17-15-41/log.txt') # 讀取訓練時生成的日誌文件
fig = plt.figure(figsize=(10, 4))
ax = fig.add_subplot(111)
ax.plot(range(len(loss_res18)), loss_res18, 'c', label='building', linewidth=1) # 這個label是圖線自己的標籤;
# ax.set_xlim([0, 800]) # 設置刻度;
# ax.set_xticks(range(0, 500, 100)) # 設置顯示的刻度;
# ax.set_yticklabels(['jan', 'feb', 'mar']) # 設置刻度標籤;
ax.set_xlabel('epochs') # 設置座標軸標籤;
ax.set_ylabel('loss_value')
ax.text(8750, 20, "plane", color='red') # 加入文本
ax.set_title('loss_of_CenterNet')
ax.legend(loc='best') # 將圖例擺放在不遮擋圖線的位置即可
ax.grid() # 添加網格
plt.savefig('/home/zbb/CenterNet/loss_of_CenterNet.png') # 保存文件到指定文件夾
plt.show()
total——loss結果圖: