使用caffe的python接口預測多張圖片

一、前言

       根據前面博文 使用lenet模型訓練及預測自己的圖片數據 可得到訓練得的caffemodel及其他相關的文件,回顧下My_FIle文件夾如下,predictPic文件夾中保存的是名爲“0“~“9“的文件夾,分別保存相應的0~9的多張字符圖片:


         使用classification.bin只能預測單張圖片,或者使用caffe.bin test   ***.prototxt  ***.caffemodel -iterations n的方法可以預測多張圖片,這裏則使用pyhton接口預測多張圖片的類別。需要使用My_File中的deploy.prototxt、caffemodel文件,均值文件mean.binaryproto。首先需要將mean.binaryproto轉換爲python接口需要的.npy文件,可參考博文http://blog.csdn.net/hyman_yx/article/details/51732656,轉換的mean.npy保存於My_File/Mean/。在caffe/python/下新建python文件predict_all.py,輸入以下內容。

#!/usr/bin/env python
#-*- coding:utf-8 -*-
import cv2
import numpy as np
import sys,os

import caffe

def GetFileList(dir, fileList):
    newDir = dir
    if os.path.isfile(dir):
        fileList.append(dir.decode('gbk'))
    elif os.path.isdir(dir):
        for s in os.listdir(dir):
            newDir=os.path.join(dir,s)
            GetFileList(newDir, fileList)
    return fileList



if __name__=='__main__':
    caffe_root='/home/jyang/caffe/'
    sys.path.insert(0,caffe_root+'python')
    os.chdir(caffe_root)

    net_file=caffe_root+'My_File/Deploy/deploy.prototxt'
    caffe_model=caffe_root+'My_File/lenet_iter_10000.caffemodel'
    mean_file=caffe_root+'My_File/Mean/mean.npy'

    net=caffe.Net(net_file,caffe_model,caffe.TEST)
    transformer=caffe.io.Transformer({'data':net.blobs['data'].data.shape})
    transformer.set_transpose('data',(2,0,1))
    transformer.set_mean('data',np.load(mean_file).mean(1).mean(1))
    transformer.set_raw_scale('data',225)
    transformer.set_channel_swap('data',(2,1,0))

    imagenet_labels_filename=caffe_root+'My_File/Synset/synset_words.txt'
    labels=np.loadtxt(imagenet_labels_filename,str,delimiter='\t')

    MyPicList=GetFileList('My_File/predictPic',[])
    f=open('My_File/res.txt','w')
    for imgPath in MyPicList:
        img=caffe.io.load_image(caffe_root+imgPath)
        img=img[...,::-1]

        net.blobs['data'].data[...]=transformer.preprocess('data',img)
        out=net.forward()
        top_k=net.blobs['prob'].data[0].flatten().argsort()[-1:-5:-1]
        f.writelines(imgPath+' '+labels[top_k[0]]+'\n' )
    f.close()


二、結果

         將所有的預測結果保存至res.txt文件,準確率沒計算出來,最右邊爲該圖片的預測分類


正確率很高,訓練和測試的數字圖片都是身份證圖片上切分下來的號碼字符,當然也有錯分的情況,如下3被誤判爲1了:


三、參考博文

http://blog.csdn.net/u010142666/article/details/60469393



發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章