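Preface: Giving someone a fish is not as good as teaching them to fish. First learn to use a tool, then learn its principles, then learn to create with it. You may never need this ability, but you should still have it. This article introduces the "Hello World" of machine learning, a handwritten digit recognition model, and shows how to implement it in Java, with explanations of the key parts. Most articles on AI, machine learning and neural networks are written in Python, which is not very friendly to backend developers who mainly work in Java. This article shows how to do the same in Java and opens the door to a new world. What follows is my personal understanding; corrections are welcome.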
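I. Goal: train a handwritten digit recognition model on the MNIST dataset
Knowledge needed to implement a model:
1. Machine learning fundamentals: basic concepts such as supervised, unsupervised and reinforcement learning.
2. Data processing and analysis: data cleaning, feature engineering, data visualization, etc.
3. Programming languages: e.g. Python, for implementing machine learning algorithms.
4. Mathematics: linear algebra, probability and statistics, calculus, etc.
5. Machine learning algorithms: linear regression, decision trees, neural networks, support vector machines, etc.
6. Deep learning frameworks: e.g. TensorFlow and PyTorch, for building and training deep learning models.
7. Model evaluation and optimization: cross-validation, hyperparameter tuning, evaluation metrics, etc.
8. Practical experience: gaining experience through real projects and competitions to keep improving your modeling skills.
The machine learning "Hello World" here is handwritten digit recognition, implemented with the TensorFlow framework.
What you mainly need:
1. Understand the handwritten digit dataset: the training images have shape (60000, 28, 28), and the training labels have shape (60000,), one integer from 0 to 9 per image
2. Understand what activation functions do
3. Understand what forward propagation and backpropagation do, and how they are implemented
4. Train the model and save it
5. Load the saved model and use it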
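To make item 2 concrete, here is a minimal plain-Java sketch (no TensorFlow; the class name `Activations` is only illustrative) of the two activation functions used in this article: ReLU for the Keras hidden layer and softmax for the output layer. Subtracting the largest score before exponentiating is a common numerical-stability trick and does not change the softmax result.

```java
import java.util.Arrays;

/** Illustrative activation functions in plain Java (not the TensorFlow implementation). */
final class Activations {
    /** ReLU: max(0, x), applied element-wise. */
    static double[] relu(double[] x) {
        double[] out = new double[x.length];
        for (int i = 0; i < x.length; i++) {
            out[i] = Math.max(0.0, x[i]);
        }
        return out;
    }

    /** Softmax: exp(x_i) / sum_j exp(x_j); the max is subtracted first to avoid overflow. */
    static double[] softmax(double[] logits) {
        double max = Arrays.stream(logits).max().orElse(0.0);
        double[] out = new double[logits.length];
        double sum = 0.0;
        for (int i = 0; i < logits.length; i++) {
            out[i] = Math.exp(logits[i] - max);
            sum += out[i];
        }
        for (int i = 0; i < out.length; i++) {
            out[i] /= sum; // the outputs now sum to 1
        }
        return out;
    }

    public static void main(String[] args) {
        System.out.println(Arrays.toString(relu(new double[] {-2.0, 0.0, 3.0}))); // prints [0.0, 0.0, 3.0]
        System.out.println(Arrays.toString(softmax(new double[] {1.0, 2.0, 3.0})));
    }
}
```

ReLU zeroes negative values and passes positive ones through unchanged; softmax turns arbitrary scores into probabilities, with the largest score getting the largest probability.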
II. Comparing the Java and Python code
The Python code is already explained extensively online, so it is not explained again here.
1. Loading the dataset
In Python:
def load_data(data_folder):
    files = ["train-labels-idx1-ubyte.gz", "train-images-idx3-ubyte.gz",
             "t10k-labels-idx1-ubyte.gz", "t10k-images-idx3-ubyte.gz"]
    paths = []
    for fname in files:
        paths.append(os.path.join(data_folder, fname))
    # Label files start with an 8-byte header, image files with a 16-byte header
    with gzip.open(paths[0], 'rb') as lbpath:
        train_y = np.frombuffer(lbpath.read(), np.uint8, offset=8)
    with gzip.open(paths[1], 'rb') as imgpath:
        train_x = np.frombuffer(imgpath.read(), np.uint8, offset=16).reshape(len(train_y), 28, 28)
    with gzip.open(paths[2], 'rb') as lbpath:
        test_y = np.frombuffer(lbpath.read(), np.uint8, offset=8)
    with gzip.open(paths[3], 'rb') as imgpath:
        test_x = np.frombuffer(imgpath.read(), np.uint8, offset=16).reshape(len(test_y), 28, 28)
    return (train_x, train_y), (test_x, test_y)
(train_x, train_y), (test_x, test_y) = load_data("mnistDataSet/")
print('\n train_x:%s, train_y:%s, test_x:%s, test_y:%s' % (train_x.shape, train_y.shape, test_x.shape, test_y.shape))
print(train_x.ndim)  # number of dimensions of the dataset
print(train_x.shape)  # shape of the dataset
print(len(train_x))  # number of samples in the dataset
print(train_x)  # the dataset itself
print("---view a single sample")
print(train_x[0])
print(len(train_x[0]))
print(len(train_x[0][1]))
print(train_x[0][6])
print("---view a single label")
print(train_y[3])
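The `offset=8` and `offset=16` arguments skip the IDX file headers (label files: magic number and item count; image files: magic number, image count, row count and column count, each a 4-byte integer), and `reshape(len(train_y), 28, 28)` reinterprets the remaining bytes row by row. The following sketch, a hypothetical helper class `IdxOffsets`, shows the index arithmetic that maps `train_x[i][row][col]` to a byte position in the uncompressed file.

```java
/** Hypothetical helper: index arithmetic behind np.frombuffer(..., offset=16).reshape(n, 28, 28). */
final class IdxOffsets {
    static final int IMAGE_HEADER_BYTES = 16; // magic + count + rows + cols, 4 bytes each
    static final int LABEL_HEADER_BYTES = 8;  // magic + count, 4 bytes each
    static final int ROWS = 28;
    static final int COLS = 28;

    /** Byte position of pixel (row, col) of image i in the uncompressed image file. */
    static long pixelOffset(int image, int row, int col) {
        return IMAGE_HEADER_BYTES + (long) image * ROWS * COLS + (long) row * COLS + col;
    }

    /** Byte position of the label of image i in the uncompressed label file. */
    static long labelOffset(int image) {
        return LABEL_HEADER_BYTES + image;
    }

    public static void main(String[] args) {
        System.out.println(pixelOffset(0, 0, 0)); // prints 16
        System.out.println(pixelOffset(1, 0, 0)); // prints 800 (16 + 784)
        System.out.println(labelOffset(3));       // prints 11
    }
}
```

For example, the last pixel of the 60,000th training image sits at byte 47,040,015, so the uncompressed training-image file is 47,040,016 bytes long.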
In Java:
SimpleMnist.java
private static final String TRAINING_IMAGES_ARCHIVE = "mnist/train-images-idx3-ubyte.gz";
private static final String TRAINING_LABELS_ARCHIVE = "mnist/train-labels-idx1-ubyte.gz";
private static final String TEST_IMAGES_ARCHIVE = "mnist/t10k-images-idx3-ubyte.gz";
private static final String TEST_LABELS_ARCHIVE = "mnist/t10k-labels-idx1-ubyte.gz";
// Load the data; training image 3 is also kept as a one-image validation set
MnistDataset validationDataset = MnistDataset.getOneValidationImage(3, TRAINING_IMAGES_ARCHIVE, TRAINING_LABELS_ARCHIVE,TEST_IMAGES_ARCHIVE, TEST_LABELS_ARCHIVE);
MnistDataset.java
/**
* Loads the full dataset and keeps the training image at {@code index} (with its label) as a one-image
* validation set. The image also stays in the training set.
*
* @param index                 index of the training image to use for validation
* @param trainingImagesArchive path of the training images
* @param trainingLabelsArchive path of the training labels
* @param testImagesArchive     path of the test images
* @param testLabelsArchive     path of the test labels
* @return the dataset, or null if index is negative
*/
public static MnistDataset getOneValidationImage(int index, String trainingImagesArchive, String trainingLabelsArchive,String testImagesArchive, String testLabelsArchive) {
if (index < 0) {
return null;
}
try {
ByteNdArray trainingImages = readArchive(trainingImagesArchive);
ByteNdArray trainingLabels = readArchive(trainingLabelsArchive);
ByteNdArray testImages = readArchive(testImagesArchive);
ByteNdArray testLabels = readArchive(testLabelsArchive);
// Slice [index, index + 1): one image and its label, keeping a leading batch dimension of size 1
Index range = Indices.range(index, index + 1);
ByteNdArray validationImage = trainingImages.slice(range);
ByteNdArray validationLabel = trainingLabels.slice(range);
return new MnistDataset(trainingImages,trainingLabels,validationImage,validationLabel,testImages,testLabels);
} catch (IOException e) {
throw new AssertionError(e);
}
}
private static ByteNdArray readArchive(String archiveName) throws IOException {
System.out.println("archiveName = " + archiveName);
java.io.InputStream resource = MnistDataset.class.getClassLoader().getResourceAsStream(archiveName);
if (resource == null) {
throw new IOException("\"" + archiveName + "\" not found on the classpath");
}
// IDX layout: 2 zero bytes, 1 type byte (0x08 = unsigned byte), 1 byte holding the number of dimensions,
// one big-endian 4-byte size per dimension, then the data itself
try (DataInputStream archiveStream = new DataInputStream(new GZIPInputStream(resource))) {
archiveStream.readShort(); // first two bytes are always 0
byte magic = archiveStream.readByte();
if (magic != TYPE_UBYTE) {
throw new IllegalArgumentException("\"" + archiveName + "\" is not a valid archive");
}
int numDims = archiveStream.readByte();
long[] dimSizes = new long[numDims];
int size = 1; // for simplicity, we assume that the total size does not exceed Integer.MAX_VALUE
for (int i = 0; i < dimSizes.length; ++i) {
dimSizes[i] = archiveStream.readInt();
size *= dimSizes[i];
}
byte[] bytes = new byte[size];
archiveStream.readFully(bytes);
return NdArrays.wrap(Shape.of(dimSizes), DataBuffers.of(bytes, false, false));
}
}
/**
* MNIST dataset constructor
*/
private MnistDataset(ByteNdArray trainingImages, ByteNdArray trainingLabels,ByteNdArray validationImages,ByteNdArray validationLabels,ByteNdArray testImages,ByteNdArray testLabels
) {
this.trainingImages = trainingImages;
this.trainingLabels = trainingLabels;
this.validationImages = validationImages;
this.validationLabels = validationLabels;
this.testImages = testImages;
this.testLabels = testLabels;
this.imageSize = trainingImages.get(0).shape().size();
System.out.println(String.format("train_x:%s,train_y:%s, test_x:%s, test_y:%s", trainingImages.shape(), trainingLabels.shape(), testImages.shape(), testLabels.shape()));
System.out.println("Rank of the dataset: " + trainingImages.rank());
System.out.println("Shape of the dataset = " + trainingImages.shape());
System.out.println("Size of the dataset = " + trainingImages.shape().get(0));
System.out.println("Single sample = " + trainingImages.get(0));
}
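`readArchive` relies on the IDX layout: two zero bytes, one type byte (0x08 for unsigned bytes), one byte holding the number of dimensions, one big-endian 4-byte size per dimension, then the data. The sketch below re-implements that parse with only `java.io` and `java.util.zip`, so it can be tried without TensorFlow. The class `IdxReader` and its test helper `gzippedIdx` are hypothetical names, the result is raw arrays rather than a `ByteNdArray`, and the `record` needs Java 16 or later (the project uses JDK 21).

```java
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.util.zip.GZIPInputStream;
import java.util.zip.GZIPOutputStream;

/** Hypothetical plain-Java IDX reader mirroring readArchive (no TensorFlow types). */
final class IdxReader {
    static final int TYPE_UBYTE = 0x08;

    /** Dimension sizes plus the raw unsigned-byte payload. */
    record Idx(int[] dims, byte[] data) {}

    static Idx read(InputStream gzipped) throws IOException {
        try (DataInputStream in = new DataInputStream(new GZIPInputStream(gzipped))) {
            if (in.readShort() != 0) {                  // bytes 0-1 are always zero
                throw new IOException("bad magic prefix");
            }
            if (in.readUnsignedByte() != TYPE_UBYTE) {  // byte 2: element type
                throw new IOException("not an unsigned-byte IDX file");
            }
            int numDims = in.readUnsignedByte();        // byte 3: number of dimensions
            int[] dims = new int[numDims];
            long size = 1;
            for (int i = 0; i < numDims; i++) {
                dims[i] = in.readInt();                 // big-endian 32-bit size per dimension
                size *= dims[i];
            }
            if (size > Integer.MAX_VALUE) {
                throw new IOException("payload too large for one array");
            }
            byte[] data = new byte[(int) size];
            in.readFully(data);
            return new Idx(dims, data);
        }
    }

    /** Builds a tiny gzipped IDX file in memory, for trying out read(). */
    static byte[] gzippedIdx(int[] dims, byte[] data) throws IOException {
        ByteArrayOutputStream bytes = new ByteArrayOutputStream();
        try (DataOutputStream out = new DataOutputStream(new GZIPOutputStream(bytes))) {
            out.writeShort(0);
            out.writeByte(TYPE_UBYTE);
            out.writeByte(dims.length);
            for (int d : dims) {
                out.writeInt(d);
            }
            out.write(data);
        }
        return bytes.toByteArray();
    }

    public static void main(String[] args) throws IOException {
        byte[] file = gzippedIdx(new int[] {2, 2, 2}, new byte[] {0, 1, 2, 3, 4, 5, 6, (byte) 255});
        Idx idx = read(new ByteArrayInputStream(file));
        System.out.println(idx.dims().length);    // prints 3
        System.out.println(idx.data()[7] & 0xFF); // prints 255
    }
}
```

Reading `mnist/train-images-idx3-ubyte.gz` this way yields the dimensions {60000, 28, 28}. Java `byte` is signed, so pixel values must be read as `b & 0xFF`.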
2. Building the model
In Python:
model = tensorflow.keras.Sequential()
model.add(tensorflow.keras.layers.Flatten(input_shape=(28, 28)))  # Flatten layer: declares the input shape and flattens 28x28 to 784
model.add(tensorflow.keras.layers.Dense(128, activation='relu'))  # hidden layer: fully connected, 128 units, ReLU activation
model.add(tensorflow.keras.layers.Dense(10, activation='softmax'))  # output layer: fully connected, 10 units, softmax activation
print("Model structure")
# summary() prints the network structure and parameter counts itself and returns None
model.summary()
model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['sparse_categorical_accuracy'])
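`model.summary()` reports the parameter count of each layer. A fully connected layer with `n` inputs and `u` units has `n*u` weights plus `u` biases, and Flatten has none. This plain-Java sketch (hypothetical class `DenseParams`) computes those counts and shows the `x*W + b` forward pass of one dense layer for a single example.

```java
/** Hypothetical helper: parameter arithmetic and forward pass of a dense layer. */
final class DenseParams {
    /** A fully connected layer has inputs*units weights plus one bias per unit. */
    static long denseParams(long inputs, long units) {
        return inputs * units + units;
    }

    /** Forward pass of one dense layer for one example: out[j] = b[j] + sum_i x[i] * w[i][j]. */
    static double[] dense(double[] x, double[][] w, double[] b) {
        double[] out = b.clone();
        for (int j = 0; j < out.length; j++) {
            for (int i = 0; i < x.length; i++) {
                out[j] += x[i] * w[i][j];
            }
        }
        return out;
    }

    public static void main(String[] args) {
        long flattened = 28 * 28;                   // Flatten: 784 inputs, no parameters
        long hidden = denseParams(flattened, 128);  // 100480
        long output = denseParams(128, 10);         // 1290
        System.out.println(hidden + output);        // prints 101770
    }
}
```

The same formula gives 784 × 10 + 10 = 7,850 parameters for the Java graph below, which has a single weights matrix and bias vector and no hidden layer.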
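In Java (note: unlike the Keras model, this graph has no hidden layer; it is a single fully connected softmax layer, i.e. softmax regression, trained with plain gradient descent instead of Adam):
SimpleMnist.java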
Ops tf = Ops.create(graph);
// Create placeholders and variables, which should fit batches of an unknown number of images
Placeholder<TFloat32> images = tf.placeholder(TFloat32.class);
Placeholder<TFloat32> labels = tf.placeholder(TFloat32.class);
// Create weights with an initial value of 0
Shape weightShape = Shape.of(dataset.imageSize(), MnistDataset.NUM_CLASSES);
Variable<TFloat32> weights = tf.variable(tf.zeros(tf.constant(weightShape), TFloat32.class));
// Create biases with an initial value of 0
Shape biasShape = Shape.of(MnistDataset.NUM_CLASSES);
Variable<TFloat32> biases = tf.variable(tf.zeros(tf.constant(biasShape), TFloat32.class));
// Predict the class of each image in the batch and compute the loss
// tf.linalg.matMul multiplies the image matrix by the weight matrix, then the biases are added: wx + b
MatMul<TFloat32> matMul = tf.linalg.matMul(images, weights);
Add<TFloat32> add = tf.math.add(matMul, biases);
// Softmax is a common activation function that turns scores into a probability distribution: it exponentiates
// each element and divides by the sum of all exponentials, giving one probability per class (typical for multi-class problems)
Softmax<TFloat32> softmax = tf.nn.softmax(add);
// Cross-entropy loss: the batch mean of -sum(labels * log(softmax))
Mean<TFloat32> crossEntropy =
tf.math.mean( // average over the batch
tf.math.neg( // negate
tf.reduceSum( // sum over the classes
tf.math.mul(labels, tf.math.log(softmax)), // label times log of the predicted probability
tf.array(1) // axis 1 = classes
)
),
tf.array(0) // axis 0 = batch
);
// Back-propagate gradients to variables for training
// Create a gradient-descent optimizer and use it to minimize the cross-entropy loss crossEntropy
Optimizer optimizer = new GradientDescent(graph, LEARNING_RATE);
Op minimize = optimizer.minimize(crossEntropy);
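The graph above is softmax regression: `softmax(x*W + b)` with a cross-entropy loss, minimized by gradient descent. `optimizer.minimize` derives the gradients automatically; for this particular loss the gradient with respect to the scores is simply `softmax - label`, which is what backpropagation produces. The following plain-Java sketch (hypothetical class `SoftmaxRegression`, no TensorFlow, toy two-feature data) performs the same forward pass, loss and one update per `step` call, starting from zero weights like `tf.zeros`.

```java
/** Hypothetical plain-Java softmax regression: forward pass, cross-entropy, gradient descent. */
final class SoftmaxRegression {
    final double[][] w; // [features][classes], zero-initialized like tf.zeros
    final double[] b;   // [classes]

    SoftmaxRegression(int features, int classes) {
        w = new double[features][classes];
        b = new double[classes];
    }

    /** softmax(x * W + b) for a single example. */
    double[] predict(double[] x) {
        double[] p = b.clone();
        for (int j = 0; j < p.length; j++) {
            for (int i = 0; i < x.length; i++) {
                p[j] += x[i] * w[i][j];
            }
        }
        double max = Double.NEGATIVE_INFINITY;
        for (double v : p) max = Math.max(max, v);
        double sum = 0;
        for (int j = 0; j < p.length; j++) {
            p[j] = Math.exp(p[j] - max);
            sum += p[j];
        }
        for (int j = 0; j < p.length; j++) p[j] /= sum;
        return p;
    }

    /** Batch mean of -sum(labels * log(softmax)), with one-hot labels. */
    double loss(double[][] xs, double[][] ys) {
        double total = 0;
        for (int n = 0; n < xs.length; n++) {
            double[] p = predict(xs[n]);
            for (int j = 0; j < p.length; j++) {
                total -= ys[n][j] * Math.log(p[j]);
            }
        }
        return total / xs.length;
    }

    /** One gradient-descent step; d(loss)/d(scores) = (softmax - label) / batchSize. */
    void step(double[][] xs, double[][] ys, double learningRate) {
        double[][] gw = new double[w.length][b.length];
        double[] gb = new double[b.length];
        for (int n = 0; n < xs.length; n++) {
            double[] p = predict(xs[n]);
            for (int j = 0; j < b.length; j++) {
                double delta = (p[j] - ys[n][j]) / xs.length;
                gb[j] += delta;
                for (int i = 0; i < w.length; i++) {
                    gw[i][j] += xs[n][i] * delta;
                }
            }
        }
        for (int j = 0; j < b.length; j++) {
            b[j] -= learningRate * gb[j];
            for (int i = 0; i < w.length; i++) {
                w[i][j] -= learningRate * gw[i][j];
            }
        }
    }

    public static void main(String[] args) {
        double[][] xs = {{1, 0}, {0, 1}};
        double[][] ys = {{1, 0}, {0, 1}};
        SoftmaxRegression model = new SoftmaxRegression(2, 2);
        System.out.println(model.loss(xs, ys)); // about 0.693 (= ln 2) at zero initialization
        for (int k = 0; k < 50; k++) model.step(xs, ys, 0.2);
        System.out.println(model.loss(xs, ys)); // a smaller value after training
    }
}
```

With zero initialization every class gets probability 1/K, so the initial loss is ln K (ln 10 ≈ 2.303 for MNIST). In both the graph and this sketch, `log(softmax)` can become infinite, and the loss NaN, if a probability underflows to 0.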
3. Training the model
In Python:
history = model.fit(train_x, train_y, batch_size=64, epochs=5, validation_split=0.2)
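With `batch_size=64` and `validation_split=0.2`, Keras holds out the last 20% of the 60,000 training samples for validation and performs one gradient update per 64-sample batch of the rest. The Java loop instead iterates `dataset.trainingBatches(TRAINING_BATCH_SIZE)` once, with a batch size of 100. The sketch below (hypothetical class `Batching`; the article does not show `ImageBatchIterator`, so this only assumes it walks the data in consecutive chunks) computes the batch boundaries and the resulting number of steps.

```java
import java.util.ArrayList;
import java.util.List;

/** Hypothetical helper: consecutive mini-batch boundaries. */
final class Batching {
    /** Half-open [start, end) ranges covering n samples; the last batch may be smaller. */
    static List<int[]> batches(int n, int batchSize) {
        List<int[]> ranges = new ArrayList<>();
        for (int start = 0; start < n; start += batchSize) {
            ranges.add(new int[] {start, Math.min(start + batchSize, n)});
        }
        return ranges;
    }

    public static void main(String[] args) {
        int total = 60_000;
        int validation = (int) (total * 0.2);             // 12000 held out for validation
        int training = total - validation;                // 48000 used for gradient updates
        System.out.println(batches(training, 64).size()); // prints 750 (steps per Keras epoch)
        System.out.println(batches(total, 100).size());   // prints 600 (batch size 100 over 60000 images)
    }
}
```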
In Java:
SimpleMnist.java
// Train the model
for (ImageBatch trainingBatch : dataset.trainingBatches(TRAINING_BATCH_SIZE)) {
try (TFloat32 batchImages = preprocessImages(trainingBatch.images());
TFloat32 batchLabels = preprocessLabels(trainingBatch.labels())) {
// Create a session runner
session.runner()
// Add the target to minimize (one gradient-descent step)
.addTarget(minimize)
// Feed the image batch into the model
.feed(images.asOutput(), batchImages)
// Feed the label batch into the model
.feed(labels.asOutput(), batchLabels)
// Run the session
.run();
}
}
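The bodies of `preprocessImages` and `preprocessLabels` are not shown in full in this article (the listing stops after the first comment of `preprocessImages`, which says the images are flattened and their pixels normalized as floats). Because the graph computes `argMax(labels, 1)`, the labels must arrive one-hot encoded with shape [batch, 10]. Under those assumptions, a plain-Java equivalent could look like this; the class `Preprocess` is hypothetical, it returns Java arrays rather than `TFloat32` tensors, and the division by 255 is an assumption about the normalization.

```java
import java.util.Arrays;

/** Hypothetical plain-Java stand-ins for preprocessImages / preprocessLabels. */
final class Preprocess {
    static final int NUM_CLASSES = 10;

    /** Flattens n images of pixelsPerImage unsigned bytes each into floats in [0, 1]. */
    static float[][] images(byte[] rawImages, int n, int pixelsPerImage) {
        float[][] out = new float[n][pixelsPerImage];
        for (int i = 0; i < n; i++) {
            for (int p = 0; p < pixelsPerImage; p++) {
                int unsigned = rawImages[i * pixelsPerImage + p] & 0xFF; // Java bytes are signed
                out[i][p] = unsigned / 255.0f;
            }
        }
        return out;
    }

    /** One-hot encodes digit labels 0-9 into rows of length NUM_CLASSES. */
    static float[][] oneHot(byte[] rawLabels) {
        float[][] out = new float[rawLabels.length][NUM_CLASSES];
        for (int i = 0; i < rawLabels.length; i++) {
            out[i][rawLabels[i] & 0xFF] = 1.0f;
        }
        return out;
    }

    public static void main(String[] args) {
        System.out.println(Arrays.toString(oneHot(new byte[] {3})[0]));
        // prints [0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
    }
}
```

Whatever scaling is used in training must be applied again at prediction time. The Keras model in this article trains on raw values from 0 to 255 (its normalization lines are commented out), so its prediction code feeds raw values too.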
4. Evaluating the model
In Python:
test_loss, test_acc = model.evaluate(test_x, test_y)
model.evaluate(test_x, test_y, verbose=2)  # verbose=2 prints one summary line instead of a progress bar; the test-set result shows how well the model generalizes
print('Test loss: %.3f' % test_loss)
print('Test accuracy: %.3f' % test_acc)
In Java:
SimpleMnist.java
// Test the model
ImageBatch testBatch = dataset.testBatch();
try (TFloat32 testImages = preprocessImages(testBatch.images());
TFloat32 testLabels = preprocessLabels(testBatch.labels());
// accuracyValue holds the computed accuracy
TFloat32 accuracyValue = (TFloat32) session.runner()
// Fetch the accuracy (result 0), plus the predicted and expected classes
.fetch(accuracy)
.fetch(predicted)
.fetch(expected)
// Feed testImages into the images placeholder
.feed(images.asOutput(), testImages)
// Feed testLabels into the labels placeholder
.feed(labels.asOutput(), testLabels)
// Run the session and collect the results
.run()
// Take the first result, the accuracy, as accuracyValue
.get(0)) {
System.out.println("Accuracy: " + accuracyValue.getFloat());
}
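The `accuracy` fetched here is built from `argMax(softmax, 1)`, `argMax(labels, 1)`, `equal`, `cast` and `mean`. The following plain-Java sketch (hypothetical class `Accuracy`) computes the same quantity for arrays of probabilities and one-hot labels. On ties it returns the first maximum, whereas TensorFlow's `argMax` does not guarantee which index it returns.

```java
/** Hypothetical plain-Java version of mean(cast(equal(argMax(p, 1), argMax(labels, 1)))). */
final class Accuracy {
    static int argMax(float[] row) {
        int best = 0;
        for (int j = 1; j < row.length; j++) {
            if (row[j] > row[best]) best = j; // first maximum wins on ties
        }
        return best;
    }

    static float accuracy(float[][] probabilities, float[][] oneHotLabels) {
        int correct = 0;
        for (int n = 0; n < probabilities.length; n++) {
            if (argMax(probabilities[n]) == argMax(oneHotLabels[n])) correct++;
        }
        return (float) correct / probabilities.length;
    }

    public static void main(String[] args) {
        float[][] p = {{0.1f, 0.7f, 0.2f}, {0.6f, 0.3f, 0.1f}};
        float[][] y = {{0f, 1f, 0f}, {0f, 0f, 1f}};
        System.out.println(accuracy(p, y)); // prints 0.5 (first example right, second wrong)
    }
}
```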
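5. Saving the model
In Python:
# Save the complete model with save_model ('tf' = SavedModel format; 'pb' is not a documented save_format value)
# save_model(model, '/media/cfs/<user ERP name>/ea/saved_model', save_format='tf')
save_model(model, 'D:\\pythonProject\\mnistDemo\\number_model', save_format='tf')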
In Java:
SimpleMnist.java
// Save the model
SavedModelBundle.Exporter exporter = SavedModelBundle.exporter("D:\\ai\\ai-demo").withSession(session);
Signature.Builder builder = Signature.builder();
builder.input("images", images);
builder.input("labels", labels);
builder.output("accuracy", accuracy);
builder.output("expected", expected);
builder.output("predicted", predicted);
Signature signature = builder.build();
SessionFunction sessionFunction = SessionFunction.create(signature, session);
exporter.withFunction(sessionFunction);
exporter.export();
6. Loading the model
In Python:
# Load the saved SavedModel directory (it contains the .pb file); a local name avoids overwriting the load_model function
model = load_model('D:\\pythonProject\\mnistDemo\\number_model')
model.summary()
demo = tensorflow.reshape(test_x, (1, 28, 28))  # test_x here is a single 28x28 image; add a batch dimension of 1
input_data = np.array(demo)  # prepare the input data
input_data = tensorflow.convert_to_tensor(input_data, dtype=tensorflow.float32)
predictValue = model.predict(input_data)
print("predictValue")
print(predictValue)
y_pred = np.argmax(predictValue)
print('Label: ' + str(test_y) + '\nPredicted: ' + str(y_pred))
return y_pred, test_y
In Java:
SimpleMnist.java
// Load the model and predict
public void loadModel(String exportDir) {
// Load the saved model with the "serve" tag; try-with-resources closes the model and the tensors
try (SavedModelBundle model = SavedModelBundle.load(exportDir, "serve")) {
printSignature(model);
ByteNdArray validationImages = dataset.getValidationImages();
ByteNdArray validationLabels = dataset.getValidationLabels();
try (TFloat32 testImages = preprocessImages(validationImages);
TFloat32 testLabels = preprocessLabels(validationLabels)) {
System.out.println("testImages = " + testImages.shape());
System.out.println("testLabels = " + testLabels.shape());
// "Placeholder:0"/"Placeholder_1:0" are the auto-generated node names of the images/labels placeholders,
// "ArgMax:0"/"ArgMax_1:0" those of predicted/expected and "Mean_1:0" that of accuracy (printSignature lists them)
try (Result run = model.session().runner()
.feed("Placeholder:0", testImages)
.feed("Placeholder_1:0", testLabels)
.fetch("ArgMax:0")
.fetch("ArgMax_1:0")
.fetch("Mean_1:0")
.run()) {
// Process the outputs
TInt64 predicted = (TInt64) run.get("ArgMax:0").orElseThrow();
Long predictedValue = predicted.getObject(0);
System.out.println("predictedValue = " + predictedValue);
TInt64 expected = (TInt64) run.get("ArgMax_1:0").orElseThrow();
Long expectedValue = expected.getObject(0);
System.out.println("expectedValue = " + expectedValue);
TFloat32 accuracy = (TFloat32) run.get("Mean_1:0").orElseThrow();
System.out.println("accuracy = " + accuracy.getFloat());
}
}
} catch (Exception e) {
throw new RuntimeException(e);
}
}
// Print the model signature
private static void printSignature(SavedModelBundle model) throws Exception {
MetaGraphDef m = model.metaGraphDef();
SignatureDef sig = m.getSignatureDefOrThrow("serving_default");
int numInputs = sig.getInputsCount();
int i = 1;
System.out.println("MODEL SIGNATURE");
System.out.println("Inputs:");
for (Map.Entry<String, TensorInfo> entry : sig.getInputsMap().entrySet()) {
TensorInfo t = entry.getValue();
System.out.printf(
"%d of %d: %-20s (Node name in graph: %-20s, type: %s)\n",
i++, numInputs, entry.getKey(), t.getName(), t.getDtype());
}
int numOutputs = sig.getOutputsCount();
i = 1;
System.out.println("Outputs:");
for (Map.Entry<String, TensorInfo> entry : sig.getOutputsMap().entrySet()) {
TensorInfo t = entry.getValue();
System.out.printf(
"%d of %d: %-20s (Node name in graph: %-20s, type: %s)\n",
i++, numOutputs, entry.getKey(), t.getName(), t.getDtype());
}
}
III. Complete Python code
Environment used for this project:
Python: 3.7.9
https://www.python.org/downloads/windows/
Anaconda: Python 3.11 Anaconda3-2023.09-0-Windows-x86_64
https://www.anaconda.com/download#downloads
tensorflow: 2.0.0
installed directly through Anaconda
mnistTrainDemo.py
import gzip
import os.path
import tensorflow as tensorflow
from tensorflow import keras
# For visualizing images
import matplotlib.pyplot as plt
import numpy as np
from tensorflow.keras.models import save_model
# Load the data
# mnist = keras.datasets.mnist
# mnistData = mnist.load_data() #Exception: URL fetch failure on https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz: None -- unknown url type: https
"""
The data can be loaded directly with
mnist = keras.datasets.mnist
mnistData = mnist.load_data()
but the download sometimes fails, so the data is loaded from local files instead
"""
def load_data(data_folder):
    files = ["train-labels-idx1-ubyte.gz", "train-images-idx3-ubyte.gz",
             "t10k-labels-idx1-ubyte.gz", "t10k-images-idx3-ubyte.gz"]
    paths = []
    for fname in files:
        paths.append(os.path.join(data_folder, fname))
    with gzip.open(paths[0], 'rb') as lbpath:
        train_y = np.frombuffer(lbpath.read(), np.uint8, offset=8)
    with gzip.open(paths[1], 'rb') as imgpath:
        train_x = np.frombuffer(imgpath.read(), np.uint8, offset=16).reshape(len(train_y), 28, 28)
    with gzip.open(paths[2], 'rb') as lbpath:
        test_y = np.frombuffer(lbpath.read(), np.uint8, offset=8)
    with gzip.open(paths[3], 'rb') as imgpath:
        test_x = np.frombuffer(imgpath.read(), np.uint8, offset=16).reshape(len(test_y), 28, 28)
    return (train_x, train_y), (test_x, test_y)
(train_x, train_y), (test_x, test_y) = load_data("mnistDataSet/")
print('\n train_x:%s, train_y:%s, test_x:%s, test_y:%s' % (train_x.shape, train_y.shape, test_x.shape, test_y.shape))
print(train_x.ndim)  # number of dimensions of the dataset
print(train_x.shape)  # shape of the dataset
print(len(train_x))  # number of samples in the dataset
print(train_x)  # the dataset itself
print("---view a single sample")
print(train_x[0])
print(len(train_x[0]))
print(len(train_x[0][1]))
print(train_x[0][6])
# Visualize one image (the pixel data of a single image)
# plt.imshow(train_x[0].reshape(28, 28), cmap="binary")
# plt.show()
print("---view a single label")
print(train_y[0])
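# Data preprocessing
# Normalize and convert to float32 tensors. This is left commented out; if you enable it, apply the same scaling at prediction time too, otherwise recognition accuracy drops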
# train_x, test_x = tensorflow.cast(train_x / 255.0, tensorflow.float32), tensorflow.cast(test_x / 255.0,
# tensorflow.float32),
# train_y, test_y = tensorflow.cast(train_y, tensorflow.int16), tensorflow.cast(test_y, tensorflow.int16)
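# print("---view a single sample after normalization")
# print(train_x[0][6]) # 30/255=0.11764706 ---normalization divides each value by 255
# print(train_y[0])
# Step2: configure the network and build the model
'''
The code below defines a simple multilayer perceptron with three layers: an input (Flatten) layer,
one hidden layer (128 units in the active model; 100 in the commented-out variant) and an output layer of size 10.
MNIST consists of grayscale images of the handwritten digits 0 to 9, so there are 10 classes and the output size is 10.
The output layer uses the softmax activation, so it acts as a classifier.
The structure is: input layer -->> hidden layer -->> output layer.
Activation functions: https://zhuanlan.zhihu.com/p/337902763
'''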
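# Build the model
# model = keras.Sequential([
#     # In the first layer the input shape is 28*28, i.e. the height and width of the image.
#     keras.layers.Flatten(input_shape=(28, 28)),
#     # A neural network is a bit like a filter: a 28*28-pixel image goes in, a decision among 10 classes comes out. So what is the 128 for?
#     # Imagine the network contains 128 functions, each with its own parameters.
#     # Number them f0, f1 ... f127. When the pixels of an image are fed into these 128 functions, their combination should output a label value, here one of 0 to 9.
#     # To get there, the computer must work out the parameters of these 128 functions before it can compute the label of each image. Once it has, it can recognize the 10 different classes.
#     keras.layers.Dense(100, activation=tensorflow.nn.relu),
#     # The last layer has 10 units, one per class in the dataset, because there are 10 classes in total.
#     keras.layers.Dense(10, activation=tensorflow.nn.softmax)
# ])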
model = tensorflow.keras.Sequential()
model.add(tensorflow.keras.layers.Flatten(input_shape=(28, 28)))  # Flatten layer: declares the input shape
model.add(tensorflow.keras.layers.Dense(128, activation='relu'))  # hidden layer: fully connected, 128 units, ReLU activation
model.add(tensorflow.keras.layers.Dense(10, activation='softmax'))  # output layer: fully connected, 10 units, softmax activation
print("Model structure")
# Print the network structure and parameter counts; summary() prints by itself and returns None
model.summary()
'''
Next, configure the model: specify the optimization algorithm and the loss function used during training,
and optionally the metrics used to measure accuracy.
Optimizers: https://zhuanlan.zhihu.com/p/27449596
'''
# Configure how the model is trained
# Set the optimizer and loss function: Adam for optimization, cross-entropy for the loss, accuracy as the metric
# model.compile(optimizer=tensorflow.optimizers.Adam(), loss='sparse_categorical_crossentropy', metrics=['accuracy'])
# Adam with the Keras default parameters, sparse categorical cross-entropy loss, sparse categorical accuracy metric
model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['sparse_categorical_accuracy'])
# Step3: train the model
# Start training
# model.fit(x_train,  # training data
#           y_train,
#           epochs=5,  # number of epochs
#           batch_size=64,  # batch size
#           verbose=1)  # logging format
# Batch size 64, 5 epochs, validation split 0.2 (48,000 training samples and 12,000 validation samples)
history = model.fit(train_x, train_y, batch_size=64, epochs=5, validation_split=0.2)
# STEP4: evaluate the model
# Evaluate without producing predictions; returns the loss (test_loss) and the accuracy (test_acc)
test_loss, test_acc = model.evaluate(test_x, test_y)
model.evaluate(test_x, test_y, verbose=2)  # verbose=2 prints one summary line; shows how well the model generalizes to unseen data
# model.evaluate(test_dataset, verbose=1)
print('Test loss: %.3f' % test_loss)
print('Test accuracy: %.3f' % test_acc)
# Visualize the results
print(history.history)
loss = history.history['loss']  # training loss
val_loss = history.history['val_loss']  # validation loss
acc = history.history['sparse_categorical_accuracy']  # training accuracy
val_acc = history.history['val_sparse_categorical_accuracy']  # validation accuracy
plt.figure(figsize=(10, 3))
plt.subplot(121)
plt.plot(loss, color='b', label='train')
plt.plot(val_loss, color='r', label='validation')
plt.ylabel('loss')
plt.legend()
plt.subplot(122)
plt.plot(acc, color='b', label='train')
plt.plot(val_acc, color='r', label='validation')
plt.ylabel('Accuracy')
plt.legend()
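# Pause for 5 seconds, then close the figure; otherwise plt.show() blocks until the window is closed, keeping the process (and any GPU memory it holds) alive
# plt.ion()  # turn on interactive mode
# plt.show()
# plt.pause(5)
# plt.close()
# plt.show()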
# Step5: predict: feed test data, output the prediction
for i in range(1):
    num = np.random.randint(0, len(test_x))  # random test index in [0, 10000)
    plt.subplot(2, 5, i + 1)
    plt.axis('off')
    plt.imshow(test_x[num], cmap='gray')
    demo = tensorflow.reshape(test_x[num], (1, 28, 28))
    y_pred = np.argmax(model.predict(demo))
    plt.title('Label: ' + str(test_y[num]) + '\nPredicted: ' + str(y_pred))
# plt.show()
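'''
Save the model
A trained model can be loaded later to predict on new input data, so save it first
'''
# Save the complete model with save_model ('tf' = SavedModel format; 'pb' is not a documented save_format value)
save_model(model, 'D:\\pythonProject\\mnistDemo\\number_model', save_format='tf')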
mnistPredictDemo.py
import numpy as np
import tensorflow as tensorflow
import gzip
import os.path
from tensorflow.keras.models import load_model
# Prediction
def predict(test_x, test_y):
    '''
    V. Model evaluation
    Load the trained model first, then use it to predict new data and compute evaluation metrics
    '''
    # Load the model
    # Load the SavedModel directory (it contains the .pb model file); a local name avoids overwriting the load_model function
    # model = load_model('./saved_model')
    model = load_model('D:\\pythonProject\\mnistDemo\\number_model')
    model.summary()
    # make a prediction
    print("test_x")
    print(test_x)
    print(test_x.ndim)
    print(test_x.shape)
    demo = tensorflow.reshape(test_x, (1, 28, 28))  # add a batch dimension of 1
    input_data = np.array(demo)  # prepare the input data
    input_data = tensorflow.convert_to_tensor(input_data, dtype=tensorflow.float32)
    # test_x = tensorflow.cast(test_x / 255.0, tensorflow.float32)
    # test_y = tensorflow.cast(test_y, tensorflow.int16)
    predictValue = model.predict(input_data)
    print("predictValue")
    print(predictValue)
    y_pred = np.argmax(predictValue)
    print('Label: ' + str(test_y) + '\nPredicted: ' + str(y_pred))
    return y_pred, test_y
def load_data(data_folder):
    files = ["train-labels-idx1-ubyte.gz", "train-images-idx3-ubyte.gz",
             "t10k-labels-idx1-ubyte.gz", "t10k-images-idx3-ubyte.gz"]
    paths = []
    for fname in files:
        paths.append(os.path.join(data_folder, fname))
    with gzip.open(paths[0], 'rb') as lbpath:
        train_y = np.frombuffer(lbpath.read(), np.uint8, offset=8)
    with gzip.open(paths[1], 'rb') as imgpath:
        train_x = np.frombuffer(imgpath.read(), np.uint8, offset=16).reshape(len(train_y), 28, 28)
    with gzip.open(paths[2], 'rb') as lbpath:
        test_y = np.frombuffer(lbpath.read(), np.uint8, offset=8)
    with gzip.open(paths[3], 'rb') as imgpath:
        test_x = np.frombuffer(imgpath.read(), np.uint8, offset=16).reshape(len(test_y), 28, 28)
    return (train_x, train_y), (test_x, test_y)
(train_x, train_y), (test_x, test_y) = load_data("mnistDataSet/")
print(train_x[0])
predict(train_x[0], train_y[0])  # pass the label of the same sample, not the whole label array
IV. Complete Java code
TensorFlow Java version compatibility table: https://github.com/tensorflow/java/#tensorflow-version-support
Environment used for this project:
JDK: openjdk-21
pom dependencies:
<dependencies>
<dependency>
<groupId>org.tensorflow</groupId>
<artifactId>tensorflow-core-platform</artifactId>
<version>0.6.0-SNAPSHOT</version>
</dependency>
<dependency>
<groupId>org.tensorflow</groupId>
<artifactId>tensorflow-framework</artifactId>
<version>0.6.0-SNAPSHOT</version>
</dependency>
</dependencies>
<repositories>
<repository>
<id>tensorflow-snapshots</id>
<url>https://oss.sonatype.org/content/repositories/snapshots/</url>
<snapshots>
<enabled>true</enabled>
</snapshots>
</repository>
</repositories>
Dataset creation and parsing class
MnistDataset.java
package org.example.tensorDemo.datasets.mnist;
import org.example.tensorDemo.datasets.ImageBatch;
import org.example.tensorDemo.datasets.ImageBatchIterator;
import org.tensorflow.ndarray.*;
import org.tensorflow.ndarray.buffer.DataBuffers;
import org.tensorflow.ndarray.index.Index;
import org.tensorflow.ndarray.index.Indices;
import java.io.DataInputStream;
import java.io.IOException;
import java.util.zip.GZIPInputStream;
import static org.tensorflow.ndarray.index.Indices.sliceFrom;
import static org.tensorflow.ndarray.index.Indices.sliceTo;
public class MnistDataset {
public static final int NUM_CLASSES = 10;
private static final int TYPE_UBYTE = 0x08;
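/**
* Training images as a byte multi-dimensional array
*/
private final ByteNdArray trainingImages;
/**
* Training labels as a byte multi-dimensional array
*/
private final ByteNdArray trainingLabels;
/**
* Validation images as a byte multi-dimensional array
*/
public final ByteNdArray validationImages;
/**
* Validation labels as a byte multi-dimensional array
*/
public final ByteNdArray validationLabels;
/**
* Test images as a byte multi-dimensional array
*/
private final ByteNdArray testImages;
/**
* Test labels as a byte multi-dimensional array
*/
private final ByteNdArray testLabels;
/**
* Number of pixels per image
*/
private final long imageSize;
/**
* MNIST dataset constructor
*/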
private MnistDataset(
ByteNdArray trainingImages,
ByteNdArray trainingLabels,
ByteNdArray validationImages,
ByteNdArray validationLabels,
ByteNdArray testImages,
ByteNdArray testLabels
) {
this.trainingImages = trainingImages;
this.trainingLabels = trainingLabels;
this.validationImages = validationImages;
this.validationLabels = validationLabels;
this.testImages = testImages;
this.testLabels = testLabels;
// Number of elements in the first image's shape: each image has 28x28 pixels, so this is 784
this.imageSize = trainingImages.get(0).shape().size();
// System.out.println("imageSize = " + imageSize);
// System.out.println(String.format("train_x:%s,train_y:%s, test_x:%s, test_y:%s", trainingImages.shape(), trainingLabels.shape(), testImages.shape(), testLabels.shape()));
// System.out.println("Rank of the dataset: " + trainingImages.rank());
// System.out.println("Shape of the dataset = " + trainingImages.shape());
// System.out.println("Size of the dataset = " + trainingImages.shape().get(0));
// System.out.println("Dataset = ");
// for (int i = 0; i < trainingImages.shape().get(0); i++) {
// for (int j = 0; j < trainingImages.shape().get(1); j++) {
// for (int k = 0; k < trainingImages.shape().get(2); k++) {
// System.out.print(trainingImages.getObject(i, j, k) + " ");
// }
// System.out.println();
// }
// System.out.println();
// }
// System.out.println("Single sample = " + trainingImages.get(0));
// for (int j = 0; j < trainingImages.shape().get(1); j++) {
// for (int k = 0; k < trainingImages.shape().get(2); k++) {
// System.out.print(trainingImages.getObject(0, j, k) + " ");
// }
// System.out.println();
// }
// System.out.println("Size of a single sample = " + trainingImages.get(0).size());
// System.out.println("Size of row 1 of the first image in trainingImages = " + trainingImages.get(0).get(1).size());
// System.out.println("Pixel at row 6, column 8 (zero-based) of the first image in trainingImages = " + trainingImages.getObject(0, 6, 8));
// System.out.println("trainingLabels = " + trainingLabels.getObject(1));
}
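/**
* @param validationSize        number of images at the start of the training set to hold out for validation
* @param trainingImagesArchive path of the training images
* @param trainingLabelsArchive path of the training labels
* @param testImagesArchive     path of the test images
* @param testLabelsArchive     path of the test labels
*/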
public static MnistDataset create(int validationSize, String trainingImagesArchive, String trainingLabelsArchive,
String testImagesArchive, String testLabelsArchive) {
try {
ByteNdArray trainingImages = readArchive(trainingImagesArchive);
ByteNdArray trainingLabels = readArchive(trainingLabelsArchive);
ByteNdArray testImages = readArchive(testImagesArchive);
ByteNdArray testLabels = readArchive(testLabelsArchive);
if (validationSize > 0) {
return new MnistDataset(
trainingImages.slice(sliceFrom(validationSize)),
trainingLabels.slice(sliceFrom(validationSize)),
trainingImages.slice(sliceTo(validationSize)),
trainingLabels.slice(sliceTo(validationSize)),
testImages,
testLabels
);
}
return new MnistDataset(trainingImages, trainingLabels, null, null, testImages, testLabels);
} catch (IOException e) {
throw new AssertionError(e);
}
}
/**
* Loads the full dataset and keeps the training image at {@code index} (with its label) as a one-image
* validation set. The image also stays in the training set.
*
* @param index                 index of the training image to use for validation
* @param trainingImagesArchive path of the training images
* @param trainingLabelsArchive path of the training labels
* @param testImagesArchive     path of the test images
* @param testLabelsArchive     path of the test labels
* @return the dataset, or null if index is negative
*/
public static MnistDataset getOneValidationImage(int index, String trainingImagesArchive, String trainingLabelsArchive,
String testImagesArchive, String testLabelsArchive) {
if (index < 0) {
return null;
}
try {
ByteNdArray trainingImages = readArchive(trainingImagesArchive);
ByteNdArray trainingLabels = readArchive(trainingLabelsArchive);
ByteNdArray testImages = readArchive(testImagesArchive);
ByteNdArray testLabels = readArchive(testLabelsArchive);
// Slice [index, index + 1): one image and its label, keeping a leading batch dimension of size 1
Index range = Indices.range(index, index + 1);
ByteNdArray validationImage = trainingImages.slice(range);
ByteNdArray validationLabel = trainingLabels.slice(range);
return new MnistDataset(
trainingImages,
trainingLabels,
validationImage,
validationLabel,
testImages,
testLabels
);
} catch (IOException e) {
throw new AssertionError(e);
}
}
private static ByteNdArray readArchive(String archiveName) throws IOException {
System.out.println("archiveName = " + archiveName);
java.io.InputStream resource = MnistDataset.class.getClassLoader().getResourceAsStream(archiveName);
// Alternative: read from the file system instead of the classpath
// java.io.InputStream resource = new java.io.FileInputStream("src/main/resources/" + archiveName);
if (resource == null) {
throw new IOException("\"" + archiveName + "\" not found on the classpath");
}
// IDX layout: 2 zero bytes, 1 type byte (0x08 = unsigned byte), 1 byte holding the number of dimensions,
// one big-endian 4-byte size per dimension, then the data itself
try (DataInputStream archiveStream = new DataInputStream(new GZIPInputStream(resource))) {
archiveStream.readShort(); // first two bytes are always 0
byte magic = archiveStream.readByte();
if (magic != TYPE_UBYTE) {
throw new IllegalArgumentException("\"" + archiveName + "\" is not a valid archive");
}
int numDims = archiveStream.readByte();
long[] dimSizes = new long[numDims];
int size = 1; // for simplicity, we assume that the total size does not exceed Integer.MAX_VALUE
for (int i = 0; i < dimSizes.length; ++i) {
dimSizes[i] = archiveStream.readInt();
size *= dimSizes[i];
}
byte[] bytes = new byte[size];
archiveStream.readFully(bytes);
return NdArrays.wrap(Shape.of(dimSizes), DataBuffers.of(bytes, false, false));
}
}
public Iterable<ImageBatch> trainingBatches(int batchSize) {
return () -> new ImageBatchIterator(batchSize, trainingImages, trainingLabels);
}
public Iterable<ImageBatch> validationBatches(int batchSize) {
return () -> new ImageBatchIterator(batchSize, validationImages, validationLabels);
}
public Iterable<ImageBatch> testBatches(int batchSize) {
return () -> new ImageBatchIterator(batchSize, testImages, testLabels);
}
public ImageBatch testBatch() {
return new ImageBatch(testImages, testLabels);
}
public long imageSize() {
return imageSize;
}
public long numTrainingExamples() {
return trainingLabels.shape().size(0);
}
public long numTestingExamples() {
return testLabels.shape().size(0);
}
public long numValidationExamples() {
return validationLabels.shape().size(0);
}
public ByteNdArray getValidationImages() {
return validationImages;
}
public ByteNdArray getValidationLabels() {
return validationLabels;
}
}
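`create` splits the training archive with `sliceTo(validationSize)` for the validation set and `sliceFrom(validationSize)` for the remaining training set. Assuming these behave like Python's `[:n]` and `[n:]` (end exclusive, start inclusive), which is what makes the two parts disjoint, the following plain-Java sketch (hypothetical class `ValidationSplit`) performs the same split on flat byte arrays. Unlike NdArray slices, which are views over the same buffer, `Arrays.copyOfRange` copies the data.

```java
import java.util.Arrays;

/** Hypothetical plain-Java version of the split in MnistDataset.create. */
final class ValidationSplit {
    record Split(byte[] trainingImages, byte[] trainingLabels,
                 byte[] validationImages, byte[] validationLabels) {}

    /** images holds (numImages * imageSize) bytes, labels one byte per image. */
    static Split split(byte[] images, byte[] labels, int imageSize, int validationSize) {
        int cut = validationSize * imageSize; // byte position where image number validationSize starts
        return new Split(
                Arrays.copyOfRange(images, cut, images.length),         // like slice(sliceFrom(validationSize))
                Arrays.copyOfRange(labels, validationSize, labels.length),
                Arrays.copyOfRange(images, 0, cut),                     // like slice(sliceTo(validationSize))
                Arrays.copyOfRange(labels, 0, validationSize));
    }

    public static void main(String[] args) {
        byte[] images = {0, 1, 2, 3, 4, 5}; // three 2-pixel images
        byte[] labels = {7, 8, 9};
        Split s = split(images, labels, 2, 1);
        System.out.println(Arrays.toString(s.validationLabels())); // prints [7]
        System.out.println(Arrays.toString(s.trainingLabels()));   // prints [8, 9]
    }
}
```

With `VALIDATION_SIZE = 5` as in `SimpleMnist`, this would keep the first five training images for validation and train on the remaining 59,995.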
SimpleMnist.java
package org.example.tensorDemo.dense;
import org.example.tensorDemo.datasets.ImageBatch;
import org.example.tensorDemo.datasets.mnist.MnistDataset;
import org.tensorflow.*;
import org.tensorflow.framework.optimizers.GradientDescent;
import org.tensorflow.framework.optimizers.Optimizer;
import org.tensorflow.ndarray.ByteNdArray;
import org.tensorflow.ndarray.Shape;
import org.tensorflow.op.Op;
import org.tensorflow.op.Ops;
import org.tensorflow.op.core.Placeholder;
import org.tensorflow.op.core.Variable;
import org.tensorflow.op.linalg.MatMul;
import org.tensorflow.op.math.Add;
import org.tensorflow.op.math.Mean;
import org.tensorflow.op.nn.Softmax;
import org.tensorflow.proto.framework.MetaGraphDef;
import org.tensorflow.proto.framework.SignatureDef;
import org.tensorflow.proto.framework.TensorInfo;
import org.tensorflow.types.TFloat32;
import org.tensorflow.types.TInt64;
import java.io.IOException;
import java.util.Map;
import java.util.Optional;
public class SimpleMnist implements Runnable {
private static final String TRAINING_IMAGES_ARCHIVE = "mnist/train-images-idx3-ubyte.gz";
private static final String TRAINING_LABELS_ARCHIVE = "mnist/train-labels-idx1-ubyte.gz";
private static final String TEST_IMAGES_ARCHIVE = "mnist/t10k-images-idx3-ubyte.gz";
private static final String TEST_LABELS_ARCHIVE = "mnist/t10k-labels-idx1-ubyte.gz";
public static void main(String[] args) {
// Load the dataset
// MnistDataset dataset = MnistDataset.create(VALIDATION_SIZE, TRAINING_IMAGES_ARCHIVE, TRAINING_LABELS_ARCHIVE,
// TEST_IMAGES_ARCHIVE, TEST_LABELS_ARCHIVE);
MnistDataset validationDataset = MnistDataset.getOneValidationImage(3, TRAINING_IMAGES_ARCHIVE, TRAINING_LABELS_ARCHIVE,
TEST_IMAGES_ARCHIVE, TEST_LABELS_ARCHIVE);
// Create the graph object named graph
try (Graph graph = new Graph()) {
SimpleMnist mnist = new SimpleMnist(graph, validationDataset);
mnist.run(); // build, train, evaluate and export the model
mnist.loadModel("D:\\ai\\ai-demo"); // reload the exported model and predict the validation image
}
}
@Override
public void run() {
Ops tf = Ops.create(graph);
// Create placeholders and variables, which should fit batches of an unknown number of images
Placeholder<TFloat32> images = tf.placeholder(TFloat32.class);
Placeholder<TFloat32> labels = tf.placeholder(TFloat32.class);
// Create weights with an initial value of 0
Shape weightShape = Shape.of(dataset.imageSize(), MnistDataset.NUM_CLASSES);
Variable<TFloat32> weights = tf.variable(tf.zeros(tf.constant(weightShape), TFloat32.class));
// Create biases with an initial value of 0
Shape biasShape = Shape.of(MnistDataset.NUM_CLASSES);
Variable<TFloat32> biases = tf.variable(tf.zeros(tf.constant(biasShape), TFloat32.class));
// Predict the class of each image in the batch and compute the loss
// tf.linalg.matMul multiplies the image matrix by the weight matrix, then the biases are added: wx + b
MatMul<TFloat32> matMul = tf.linalg.matMul(images, weights);
Add<TFloat32> add = tf.math.add(matMul, biases);
// Activation function: softmax exponentiates each score and divides by the sum of all exponentials,
// turning the scores into one probability per class (the usual output for multi-class problems)
Softmax<TFloat32> softmax = tf.nn.softmax(add);
// Loss function: cross-entropy, the batch mean of -sum(labels * log(softmax))
Mean<TFloat32> crossEntropy =
tf.math.mean( // average over the batch
tf.math.neg( // negate
tf.reduceSum( // sum over the classes
tf.math.mul(labels, tf.math.log(softmax)), // label times log of the predicted probability
tf.array(1) // axis 1 = classes
)
),
tf.array(0) // axis 0 = batch
);
// Back-propagate gradients to variables for training
// A gradient-descent optimizer computes the gradients of crossEntropy and updates the weights and biases to minimize it
// Gradient descent: https://www.cnblogs.com/guoyaohua/p/8542554.html
Optimizer optimizer = new GradientDescent(graph, LEARNING_RATE);
Op minimize = optimizer.minimize(crossEntropy);
// Compute the accuracy of the model
// argMax returns the index of the largest value along the given axis (axis 1 = classes)
Operand<TInt64> predicted = tf.math.argMax(softmax, tf.constant(1));
Operand<TInt64> expected = tf.math.argMax(labels, tf.constant(1));
// equal compares predicted and expected labels, cast turns the booleans into floats, and mean gives the accuracy
Operand<TFloat32> accuracy = tf.math.mean(tf.dtypes.cast(tf.math.equal(predicted, expected), TFloat32.class), tf.array(0));
// Run the graph
try (Session session = new Session(graph)) {
// Train the model
for (ImageBatch trainingBatch : dataset.trainingBatches(TRAINING_BATCH_SIZE)) {
try (TFloat32 batchImages = preprocessImages(trainingBatch.images());
TFloat32 batchLabels = preprocessLabels(trainingBatch.labels())) {
System.out.println("batchImages = " + batchImages.shape());
System.out.println("batchLabels = " + batchLabels.shape());
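// Create a session runner
session.runner()
// Add the training op (minimize) as the target to run
.addTarget(minimize)
// Feed the image batch into the images placeholder
.feed(images.asOutput(), batchImages)
// Feed the label batch into the labels placeholder
.feed(labels.asOutput(), batchLabels)
// Run one training step
.run();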
}
}
// Test the model
ImageBatch testBatch = dataset.testBatch();
try (TFloat32 testImages = preprocessImages(testBatch.images());
TFloat32 testLabels = preprocessLabels(testBatch.labels());
// accuracyValue (a TFloat32) holds the computed accuracy
TFloat32 accuracyValue = (TFloat32) session.runner()
// Fetch the accuracy (result 0), plus the predicted and expected labels
.fetch(accuracy)
.fetch(predicted)
.fetch(expected)
// Feed testImages into the images placeholder
.feed(images.asOutput(), testImages)
// Feed testLabels into the labels placeholder
.feed(labels.asOutput(), testLabels)
// Run the session
.run()
// Take the first fetched result, i.e. the accuracy
.get(0)) {
System.out.println("Accuracy: " + accuracyValue.getFloat());
}
// Save the model
SavedModelBundle.Exporter exporter = SavedModelBundle.exporter("D:\\ai\\ai-demo").withSession(session);
Signature.Builder builder = Signature.builder();
builder.input("images", images);
builder.input("labels", labels);
builder.output("accuracy", accuracy);
builder.output("expected", expected);
builder.output("predicted", predicted);
Signature signature = builder.build();
SessionFunction sessionFunction = SessionFunction.create(signature, session);
exporter.withFunction(sessionFunction);
exporter.export();
} catch (IOException e) {
throw new RuntimeException(e);
}
}
private static final int VALIDATION_SIZE = 5;
private static final int TRAINING_BATCH_SIZE = 100;
private static final float LEARNING_RATE = 0.2f;
private static TFloat32 preprocessImages(ByteNdArray rawImages) {
Ops tf = Ops.create();
// Flatten images in a single dimension and normalize their pixels as floats.
long imageSize = rawImages.get(0).shape().size();
return tf.math.div(
tf.reshape(
tf.dtypes.cast(tf.constant(rawImages), TFloat32.class),
tf.array(-1L, imageSize)
),
tf.constant(255.0f)
).asTensor();
}
private static TFloat32 preprocessLabels(ByteNdArray rawLabels) {
Ops tf = Ops.create();
// Map labels to one-hot vectors where only the expected class has a value of 1.0
return tf.oneHot(
tf.constant(rawLabels),
tf.constant(MnistDataset.NUM_CLASSES),
tf.constant(1.0f),
tf.constant(0.0f)
).asTensor();
}
private final Graph graph;
private final MnistDataset dataset;
private SimpleMnist(Graph graph, MnistDataset dataset) {
this.graph = graph;
this.dataset = dataset;
}
public void loadModel(String exportDir) {
// load saved model
SavedModelBundle model = SavedModelBundle.load(exportDir, "serve");
try {
printSignature(model);
} catch (Exception e) {
throw new RuntimeException(e);
}
ByteNdArray validationImages = dataset.getValidationImages();
ByteNdArray validationLabels = dataset.getValidationLabels();
TFloat32 testImages = preprocessImages(validationImages);
System.out.println("testImages = " + testImages.shape());
TFloat32 testLabels = preprocessLabels(validationLabels);
System.out.println("testLabels = " + testLabels.shape());
Result run = model.session().runner()
.feed("Placeholder:0", testImages)
.feed("Placeholder_1:0", testLabels)
.fetch("ArgMax:0")
.fetch("ArgMax_1:0")
.fetch("Mean_1:0")
.run();
// Process the outputs
Optional<Tensor> tensor1 = run.get("ArgMax:0");
Optional<Tensor> tensor2 = run.get("ArgMax_1:0");
Optional<Tensor> tensor3 = run.get("Mean_1:0");
TInt64 predicted = (TInt64) tensor1.get();
Long predictedValue = predicted.getObject(0);
System.out.println("predictedValue = " + predictedValue);
TInt64 expected = (TInt64) tensor2.get();
Long expectedValue = expected.getObject(0);
System.out.println("expectedValue = " + expectedValue);
TFloat32 accuracy = (TFloat32) tensor3.get();
System.out.println("accuracy = " + accuracy.getFloat());
}
private static void printSignature(SavedModelBundle model) throws Exception {
MetaGraphDef m = model.metaGraphDef();
SignatureDef sig = m.getSignatureDefOrThrow("serving_default");
int numInputs = sig.getInputsCount();
int i = 1;
System.out.println("MODEL SIGNATURE");
System.out.println("Inputs:");
for (Map.Entry<String, TensorInfo> entry : sig.getInputsMap().entrySet()) {
TensorInfo t = entry.getValue();
System.out.printf(
"%d of %d: %-20s (Node name in graph: %-20s, type: %s)\n",
i++, numInputs, entry.getKey(), t.getName(), t.getDtype());
}
int numOutputs = sig.getOutputsCount();
i = 1;
System.out.println("Outputs:");
for (Map.Entry<String, TensorInfo> entry : sig.getOutputsMap().entrySet()) {
TensorInfo t = entry.getValue();
System.out.printf(
"%d of %d: %-20s (Node name in graph: %-20s, type: %s)\n",
i++, numOutputs, entry.getKey(), t.getName(), t.getDtype());
}
System.out.println("-----------------------------------------------");
}
}
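To make the graph above less of a black box, here is a minimal plain-Java sketch of the same math with no TensorFlow dependency: one-hot labels, softmax(xW + b), mean cross-entropy, a gradient descent step and argmax accuracy. It is an illustration of softmax regression under stated assumptions, not the TensorFlow implementation: the class and method names (`SoftmaxRegressionSketch`, `gradientStep`, etc.) are hypothetical, the 4-sample, 2-feature data set is made up, and only the learning rate 0.2 mirrors `LEARNING_RATE`. The gradient used is the standard one for softmax plus cross-entropy, (p - y) / N with respect to the logits.

```java
/**
 * Hypothetical plain-Java sketch of what the TensorFlow graph computes:
 * one-hot labels, p = softmax(xW + b), mean cross-entropy, one gradient
 * descent step, and argmax accuracy. Toy sizes only, no TensorFlow.
 */
final class SoftmaxRegressionSketch {

    // One-hot encode: label k -> vector with 1.0 at index k, 0.0 elsewhere (like tf.oneHot).
    static float[][] oneHot(int[] labels, int numClasses) {
        float[][] out = new float[labels.length][numClasses];
        for (int i = 0; i < labels.length; i++) {
            out[i][labels[i]] = 1.0f;
        }
        return out;
    }

    // Row-wise softmax of the logits z = xW + b; subtracting the row max keeps exp from overflowing.
    static float[][] softmax(float[][] x, float[][] w, float[] b) {
        int n = x.length, c = b.length;
        float[][] p = new float[n][c];
        for (int i = 0; i < n; i++) {
            double[] z = new double[c];
            double max = Double.NEGATIVE_INFINITY;
            for (int k = 0; k < c; k++) {
                z[k] = b[k];
                for (int j = 0; j < x[i].length; j++) z[k] += x[i][j] * w[j][k];
                max = Math.max(max, z[k]);
            }
            double sum = 0;
            for (int k = 0; k < c; k++) { z[k] = Math.exp(z[k] - max); sum += z[k]; }
            for (int k = 0; k < c; k++) p[i][k] = (float) (z[k] / sum);
        }
        return p;
    }

    // Batch mean of -sum_k y_k * log(p_k), the same quantity as the crossEntropy op.
    static double crossEntropy(float[][] y, float[][] p) {
        double total = 0;
        for (int i = 0; i < y.length; i++)
            for (int k = 0; k < y[i].length; k++)
                total -= y[i][k] * Math.log(p[i][k]);
        return total / y.length;
    }

    // One gradient descent step: dL/dz = (p - y) / N, so dW = x^T (p - y) / N and db = column sums of (p - y) / N.
    // p is computed once before any update, so the accumulated updates equal one full-batch step.
    static void gradientStep(float[][] x, float[][] y, float[][] w, float[] b, float lr) {
        float[][] p = softmax(x, w, b);
        int n = x.length;
        for (int i = 0; i < n; i++) {
            for (int k = 0; k < b.length; k++) {
                float g = (p[i][k] - y[i][k]) / n;
                b[k] -= lr * g;
                for (int j = 0; j < x[i].length; j++) w[j][k] -= lr * g * x[i][j];
            }
        }
    }

    // Index of the largest element (like tf.math.argMax on one row).
    static int argMax(float[] v) {
        int best = 0;
        for (int k = 1; k < v.length; k++) if (v[k] > v[best]) best = k;
        return best;
    }

    // Fraction of rows whose predicted class equals the labelled class.
    static float accuracy(float[][] y, float[][] p) {
        int correct = 0;
        for (int i = 0; i < y.length; i++) if (argMax(p[i]) == argMax(y[i])) correct++;
        return (float) correct / y.length;
    }

    public static void main(String[] args) {
        // Made-up toy data: 4 "images" with 2 features; class 1 when the second feature is larger.
        float[][] x = {{1f, 0f}, {0.9f, 0.1f}, {0f, 1f}, {0.2f, 0.8f}};
        float[][] y = oneHot(new int[] {0, 0, 1, 1}, 2);
        float[][] w = new float[2][2]; // zero-initialised, like tf.zeros in the graph
        float[] b = new float[2];
        System.out.println("initial loss = " + crossEntropy(y, softmax(x, w, b)));
        for (int step = 0; step < 100; step++) gradientStep(x, y, w, b, 0.2f);
        float[][] p = softmax(x, w, b);
        System.out.println("final loss = " + crossEntropy(y, p));
        System.out.println("accuracy = " + accuracy(y, p));
    }
}
```

It can be run with `java SoftmaxRegressionSketch.java` (Java 11+ single-file launch). With zero weights every class gets probability 0.5, so the initial loss is ln 2 ≈ 0.693, which is also why training starts from a uniform prediction in the TensorFlow version. Note that `log(p)` becomes `-Infinity` if a probability underflows to 0; the same caveat applies to `tf.math.log(softmax)` in the graph.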
5. Run results of the two versions of the code
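6. Areas for improvement
1. There is no web service for submitting input images, and no image preprocessing such as binarization. Interested readers can try adding these themselves.
2. No convolutional neural network is used: the model is only wx+b followed by an activation function (softmax), trained with gradient descent on a cross-entropy loss.
3. There is no deeper, multi-layer design.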
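For the first open point, a minimal sketch of binarizing an input image, assuming the uploaded picture has already been decoded and resized to a 28x28 grayscale `int` array with values 0-255 (decoding, resizing and the web endpoint are not shown). The class name `BinarizeSketch`, the threshold of 128 and the `invert` flag are hypothetical choices. MNIST digits are light strokes on a dark background, so a dark-on-white drawing should be inverted first. The output is flattened row by row, matching the reshape to `(-1, imageSize)` in `preprocessImages`.

```java
/**
 * Hypothetical helper: converts a 28x28 grayscale image (0-255) into the flat
 * 784-element float vector shape used by the "images" placeholder, with each
 * pixel thresholded to 0.0 or 1.0.
 */
final class BinarizeSketch {

    // invert = true for dark digits on a light background, since MNIST digits are bright on dark.
    static float[] binarizeAndFlatten(int[][] gray, int threshold, boolean invert) {
        int rows = gray.length, cols = gray[0].length;
        float[] out = new float[rows * cols];
        for (int r = 0; r < rows; r++) {
            for (int c = 0; c < cols; c++) {
                int v = invert ? 255 - gray[r][c] : gray[r][c];
                out[r * cols + c] = v >= threshold ? 1.0f : 0.0f; // row-major flattening
            }
        }
        return out;
    }

    public static void main(String[] args) {
        int[][] img = new int[28][28];
        img[5][7] = 200; // bright stroke pixel
        img[5][8] = 90;  // faint pixel, below the threshold
        float[] v = binarizeAndFlatten(img, 128, false);
        System.out.println(v.length + " " + v[5 * 28 + 7] + " " + v[5 * 28 + 8]); // prints "784 1.0 0.0"
    }
}
```

One design caveat: the model above was trained on grayscale pixels divided by 255, not on 0/1 pixels, so feeding binarized images only at inference time may change accuracy. Applying the same preprocessing during both training and inference is the safer choice.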