示例
代碼
from sklearn.metrics import roc_curve, auc
import matplotlib as mpl
import matplotlib.pyplot as plt
plt.figure(figsize=(15, 10))
def plot_roc(labels, predict_probs, titles):
color = ['r', 'g', 'b', 'y']
shape = ['o', 'v', '^']
for idx, predict_prob in enumerate(predict_probs):
false_positive_rate,true_positive_rate,thresholds=roc_curve(labels, predict_prob)
roc_auc=auc(false_positive_rate, true_positive_rate)
plt.title('ROC')
c = color[idx%len(color)]
s = shape[idx%len(shape)]
plt.plot(false_positive_rate, true_positive_rate,'b',label='AUC K:{} = {:.4}'.format(titles[idx], roc_auc), color=c, marker=s, markevery=20)
plt.legend(loc='lower right')
plt.plot([0,1],[0,1],'r--')
plt.ylabel('TPR')
plt.xlabel('FPR')
plot_roc(pca_test_label, predict_probs)
解釋
該代碼參數含義爲:
- label: 長度爲N的列表,二分類的真實標籤
- predict_probs:二級列表,每個元素爲長度爲N的列表,記錄的是正類的概率。
- titles:圖例中的
K:xxx
的名稱列表