Tensorflow實現最近鄰

import tensorflow as tf
import numpy as np

#導入mnist數據
from tensorflow.examples.tutorials.mnist import input_data
mnist=input_data.read_data_sets("data/mnist",one_hot=True)

#選取訓練集、測試集數目
X_train,Y_train=mnist.train.next_batch(50000)
X_test,Y_test=mnist.test.next_batch(500)

#定義變量大小()
xtr=tf.placeholder("float", [None,784])
xte=tf.placeholder("float", [784])


#計算測試數據與訓練數據L1範數大小(1表示從橫軸進行降維)

distance=tf.reduce_sum(tf.abs(tf.add(xtr, tf.negative(xte, ))), 1)

#求得distance最小的下標(0表示從豎軸計算)
predict=tf.arg_min(distance, 0)

#準確率初始0

Accuracy=0


#數據初始化
init=tf.global_variables_initializer()
sess=tf.Session()
sess.run(init)

#開始預測
for i in range(len(X_test)):
   
    #近鄰算法:測試集與訓練集對比,返回誤差最小的下標
    nn_index=sess.run(predict,feed_dict={xtr:X_train,xte:X_test[i,:]})
    
    #np.argmax  返回標籤Y中最大數下標(既數值爲1的下標),也就是該標籤所對應的數字
    print("Test :",i,"Prection :",np.argmax(Y_train[nn_index]),"True class :",np.argmax(Y_test[i]))
    
    #統計準確率
    if np.argmax(Y_train[nn_index])==np.argmax(Y_test[i]):
        
        Accuracy+=1/len(X_test)
        
print("Accuracy :",Accuracy)

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章