Tensorflow TFRecord圖片

代碼:

import tensorflow as tf
import numpy as np
import cv2

img_width = 4
img_height = 3
img_channel = 3

img_path = "E:\\opencvPhoto\\photo\\count.png"
tfrecord_path = "F:\\tensorflow\\tfrecord\\test_tfrecord.tfrecords"

def create_img():
    img = np.zeros((img_height, img_width, img_channel), dtype=np.uint8)

    count = 0
    for row in range(img_height):
        for col in range(img_width):
            for channel in range(img_channel):
                if channel == 2:
                    # img[row, col, channel] = 255 # 生成紅色圖片
                    img[row, col, channel] = count
                    count += 1

    print(img)
    cv2.imwrite(img_path, img);

def create_TfRecord():
    writer = tf.python_io.TFRecordWriter(tfrecord_path)
    img = cv2.imread(img_path)
    # opencv修改圖片大小
    img_resize = cv2.resize(img, (img_width, img_height))
    # 將圖片轉化爲原生bytes
    img_raw = img_resize.tobytes()
    # 生成example
    example = tf.train.Example(features=tf.train.Features(feature={
        "label": tf.train.Feature(int64_list=tf.train.Int64List(value=[int(0)])),
        'img_raw': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw]))
    }))

    # 序列化字符串並寫入
    writer.write(example.SerializeToString())

def read_TfRecord(flie_list):
    # 根據文件名生成一個隊列
    filename_queue = tf.train.string_input_producer(flie_list)
    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(filename_queue)

    # 解析TFRecords單個內容
    features = tf.parse_single_example(serialized_example, features={
        'label': tf.FixedLenFeature([], tf.int64),
        'img_raw': tf.FixedLenFeature([], tf.string)
    })

    # 獲得TFRecords中的img_raw並decode
    img = tf.decode_raw(features['img_raw'], tf.uint8)
    # Tensor進行reshape
    img = tf.reshape(img, [img_height, img_width, img_channel])
    # 轉換成float32類型
    img = tf.cast(img, tf.float32)
    # 獲得TFRecords中的label並轉換成int64類型
    label = tf.cast(features['label'], tf.int64)
    # print(img, label)
    return img, label

if __name__ == "__main__":
    # 1.生成測試圖片
    # create_img()

    # 2.生成TFRecord
    # create_TfRecord()

    # 3.讀取TFrecord
    read_img, read_label = read_TfRecord([tfrecord_path])
    with tf.Session() as sess:
        # 定義一個線程協調器
        coord = tf.train.Coordinator()
        # 開啓讀文件線程
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)

        img, label = sess.run([read_img, read_label])
        print(img)
        print(label)
        # 回收子線程
        coord.request_stop()
        coord.join(threads)



輸出:

# 1.生成測試圖片輸出
[[[ 0  0  0]
  [ 0  0  1]
  [ 0  0  2]
  [ 0  0  3]]

 [[ 0  0  4]
  [ 0  0  5]
  [ 0  0  6]
  [ 0  0  7]]

 [[ 0  0  8]
  [ 0  0  9]
  [ 0  0 10]
  [ 0  0 11]]]

# 3.讀取TFrecord輸出
[[[ 0.  0.  0.]
  [ 0.  0.  1.]
  [ 0.  0.  2.]
  [ 0.  0.  3.]]

 [[ 0.  0.  4.]
  [ 0.  0.  5.]
  [ 0.  0.  6.]
  [ 0.  0.  7.]]

 [[ 0.  0.  8.]
  [ 0.  0.  9.]
  [ 0.  0. 10.]
  [ 0.  0. 11.]]]
0
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章