發現用C++讀.mat文件需要matlab的依賴,本機沒有裝matlab,只能用最樸素的方式來存儲數據核讀入數據了。
1.講tensorflow的參數存入txt文件
存的模型與上一篇博客一致,只是這次改用txt。numpy自帶的寫入文件函數支持一維數組和二維數組的寫入,但是卷積核這個是四維的,且shape爲[卷積核高,卷積核寬,輸入通道數,輸出通道數],在這裏我把shape轉爲[輸出通道數,輸入通道數,卷積核高,卷積核寬]存儲,並將數據放大了1000倍,存爲整數。
代碼如下:
import numpy as np
import tensorflow as tf
def store_4d_array(kernel, filename):
# store the kernel
f = open(filename, 'w+')
shape = kernel.shape
num_out_channel = shape[3]
num_in_channel = shape[2]
num_width = shape[0]
f.write(str(num_out_channel) + ',' + str(num_in_channel) + ',' + str(num_width) + ',' + str(num_width) + '\n')
for index_out_channel in range(num_out_channel):
for index_in_channel in range(num_in_channel):
for index_row in range(num_width):
for index_col in range(num_width):
f.write(str(int(kernel[index_row][index_col][index_in_channel][index_out_channel] * 1000)))
if index_col == num_width - 1:
f.write('\n')
else:
f.write(',')
f.close()
def store_1d_2d_array(bias, filename):
# store the bias
bias = bias * 1000
bias = bias.astype(int)
np.savetxt(filename, bias, delimiter=',', fmt="%d")
if __name__ == "__main__":
with tf.Session() as sess:
# load the meta graph and weights
saver = tf.train.import_meta_graph('model_2\minist.ckpt-70.meta')
saver.restore(sess, tf.train.latest_checkpoint('model_2/'))
# get weighs
graph = tf.get_default_graph()
conv1_w = sess.run(graph.get_tensor_by_name('conv1/w:0'))
np.save("conv1_w", conv1_w)
store_4d_array(conv1_w, "weights/conv1_w.txt")
conv1_b = sess.run(graph.get_tensor_by_name('conv1/b:0'))
store_1d_2d_array(conv1_b, "weights/conv1_b.txt")
conv2_w = sess.run(graph.get_tensor_by_name('conv2/w:0'))
store_4d_array(conv2_w, "weights/conv2_w.txt")
conv2_b = sess.run(graph.get_tensor_by_name('conv2/b:0'))
store_1d_2d_array(conv2_b, "weights/conv2_b.txt")
fc1_w = sess.run(graph.get_tensor_by_name('fc1/w:0'))
store_1d_2d_array(fc1_w, "weights/fc1_w.txt")
fc1_b = sess.run(graph.get_tensor_by_name('fc1/b:0'))
store_1d_2d_array(fc1_b, "weights/fc1_b.txt")
fc2_w = sess.run(graph.get_tensor_by_name('fc2/w:0'))
store_1d_2d_array(fc2_w, "weights/fc2_w.txt")
fc2_b = sess.run(graph.get_tensor_by_name('fc2/b:0'))
store_1d_2d_array(fc2_b, "weights/fc2_b.txt")
2.C++讀取txt中的數據
上述過程的逆過程,因爲都是按行存儲,所以思路很簡單,上代碼:
#include <iostream> #include <vector> #include <string> #include <fstream> #include <sstream> using namespace std; void read_4d_array(string filename, vector<vector<vector<vector<int>>>>& kernel) { // 讀取卷積核 // 讀文件 ifstream infile(filename, ios::in); string line; getline(infile, line); cout << "卷積核結構:" << line << endl; // 存儲卷積核的信息 stringstream ss_line(line); string str; vector<int> line_array; // 按照逗號分割,存成整型的數據 while (getline(ss_line, str, ',')) { stringstream str_temp(str); int int_temp; str_temp >> int_temp; line_array.push_back(int_temp); } int num_out_channel = line_array[0]; int num_in_channel = line_array[1]; int num_width = line_array[2]; // 逐行讀取文件的信息,並存儲 for (int index_out_channel = 0; index_out_channel < num_out_channel; index_out_channel++) { // 用來存儲一個in_channel 卷積核 vector<vector<vector<int>>> one_in_kernel; for (int index_in_channel = 0; index_in_channel < num_in_channel; index_in_channel++) { // 用來存儲一個二維的卷積核 vector<vector<int>> one_kernel; for (int index_row = 0; index_row < num_width; index_row++) { getline(infile, line); stringstream tmp_line(line); // 用來存儲卷積核的一行 vector<int> tmp_int_line; while (getline(tmp_line, str, ',')) { stringstream str_temp(str); int int_temp; str_temp >> int_temp; tmp_int_line.push_back(int_temp); } one_kernel.push_back(tmp_int_line); } one_in_kernel.push_back(one_kernel); } kernel.push_back(one_in_kernel); } } void read_1d_array(string filename, vector<int>& bias) { // 讀取偏置項,偏置項一個數字佔一行 // 打開文件 ifstream infile(filename, ios::in); string line; while (getline(infile, line)) { stringstream tmp_line(line); int int_tmp; tmp_line >> int_tmp; bias.push_back(int_tmp); } } void read_2d_array(string filename, vector<vector<int>>& weights) { // 讀取二維存儲的文件,主要是矩陣乘的權重向量、 ifstream infile(filename, ios::in); string line; while (getline(infile, line)) { stringstream ss_line(line); string str; vector<int> line_array; // 按照逗號分割,存成整型的數據 while (getline(ss_line, str, ',')) { stringstream str_temp(str); int int_temp; str_temp >> int_temp; line_array.push_back(int_temp); } weights.push_back(line_array); } }