可視化: Python—MatPlotLib—折線圖帶子圖

文章目錄

圖示

在這裏插入圖片描述

代碼

import  matplotlib.pyplot as plt   
import numpy as np
def plot_epoch_for_performance_and_loss(model_name, res_dict):                                        
    """Function: 評價指標以及訓練集loss和epoch的關係曲線                                              
    - param:                                                                                          
        model_name: (str) 模型的名稱                                                                  
        res_dict: (dict) 包含loss, 表現隨epoch變化的list                                              
    """                                                                                               
    color = ['r', 'g', 'b', 'y']                                                                 
    shape = ['o', 'v', '^']                                                                           
    loss = res_dict['epoch_loss']                                                                     
    fig = plt.figure(figsize=(15,6)) # figsize指定給個圖大小(兩個數字分別表示橫軸縱軸)                 
    ax1 = fig.add_subplot(1, 2, 1) # 1行2列的圖,相當於四個圖,1是第一個                              
    ax2 = fig.add_subplot(1, 2, 2) # 1行2列的圖,相當於四個圖,3是第三個                              
    ax1.plot(np.arange(len(loss)), np.array(loss))                                                   
    ax1.set_xlabel("Epoch")                                                                               
    ax1.set_ylabel("Loss")                                                                                
                                                                                                      
    legend = []                                                                                       
    for idx, key in enumerate(list(res_dict.keys())):                                                  
        if 'loss' in key:                                                                             
            continue                                                                                  
        c = color[idx%len(color)]                                                                     
        s = shape[idx%len(shape)]                                                                     
        ax2.plot(np.arange(len(res_dict[key])), np.array(res_dict[key]), color=c, marker=s)            
        legend.append(key)                                                                            
    ax2.legend(legend)                                                                                
    ax2.set_xlabel("Epoch")                                                                               
    ax2.set_ylabel("Perfermance")                                                                         
                                                                                                      
    plt.show()                                                                                        
    plt.savefig(model_name + '.png')    

model_name = 'BertCrf'
res_dict = {
    'epoch_loss':[10,6,5,5,3,2,1],
    'train_f1':[1,2,3,4,5,6,7],
    'dev_f1':[1,3,5,4,5,1,2,3]
}     

plot_epoch_for_performance_and_loss(model_name, res_dict)                                                             
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章