#coding:utf-8
import tensorflow as tf
import numpy as np
from PIL import Image
# Ground-truth segmentation mask used as the label for the loss below.
mask_path = './tmp/masks/PAAPhoto_70S3A970B3FAV_PAAPhoto20191021184516_1001_-1_-1_2448_3264.png'
# Open via a context manager so the file handle is released as soon as the
# pixel data has been copied into a NumPy array.
with Image.open(mask_path) as img_data:
    annotation = np.array(img_data)
# The loss squeezes a single trailing channel axis, so the mask must be a
# single-channel (H, W) image; fail early with a clear message otherwise
# (an RGB/LA mask would otherwise surface as an obscure tf.squeeze error).
if annotation.ndim != 2:
    raise ValueError(
        "expected a single-channel (2-D) mask, got array of shape {}".format(annotation.shape))
# (H, W) -> (1, H, W) -> (1, H, W, 1): add batch and channel axes.
annotation = np.expand_dims(annotation, axis=0)
annotation = np.expand_dims(annotation, axis=3)
annotation = tf.cast(annotation, tf.int32)
print(annotation.shape)
def weight_variable(shape, stddev=0.02, name=None):
    """Create a weight variable initialised from a truncated normal distribution.

    Args:
        shape: shape of the variable to create.
        stddev: standard deviation passed to ``tf.truncated_normal``.
        name: optional variable name. When given, the variable is created with
            ``tf.get_variable`` (so it is registered under that name); when
            omitted, an anonymous ``tf.Variable`` is created instead.

    Returns:
        The newly created TensorFlow variable.
    """
    init_value = tf.truncated_normal(shape, stddev=stddev)
    if name is not None:
        return tf.get_variable(name, initializer=init_value)
    return tf.Variable(init_value)
def bias_variable(shape, name=None):
    """Create a bias variable filled with zeros.

    Args:
        shape: shape of the variable to create.
        name: optional variable name. When given, ``tf.get_variable`` is used;
            otherwise an anonymous ``tf.Variable`` is created.

    Returns:
        The newly created TensorFlow variable.
    """
    zeros = tf.constant(0.0, shape=shape)
    return tf.Variable(zeros) if name is None else tf.get_variable(name, initializer=zeros)
def conv2d_transpose_strided(x, W, b, output_shape=None, stride=2):
    """Apply a strided 2-D transposed convolution ("deconvolution") plus a bias.

    Args:
        x: input feature map; the strides ``[1, stride, stride, 1]`` imply the
            default NHWC layout.
        W: filter in ``tf.nn.conv2d_transpose`` layout
            ``[height, width, out_channels, in_channels]``.
        b: bias added to every output channel (length ``out_channels``).
        output_shape: explicit output shape (list or 1-D int tensor). When
            None, it is inferred from ``x``'s static shape: height and width
            are multiplied by ``stride`` and the channel count is taken from
            ``W``.
        stride: spatial upsampling factor.

    Returns:
        ``conv2d_transpose(x, W) + b``.
    """
    if output_shape is None:
        # NOTE(review): inference needs x's static shape to be fully known
        # (including the batch dimension); None entries would make
        # conv2d_transpose fail.
        output_shape = x.get_shape().as_list()
        # Scale by the actual stride; with SAME padding the output spatial size
        # is input_size * stride (previously hard-coded to 2).
        output_shape[1] *= stride
        output_shape[2] *= stride
        output_shape[3] = W.get_shape().as_list()[2]
    conv = tf.nn.conv2d_transpose(x, W, output_shape, strides=[1, stride, stride, 1], padding="SAME")
    return tf.nn.bias_add(conv, b)
# ---------------------------------------------------------------------------
# Decoder: upsample the deepest feature map back up, fusing a skip connection
# at each stage (19 -> 14 -> 7 -> 4 -> 2 -> output logits).
# NOTE(review): layer_2/4/7/14/19 and shape_2/4/7/14/19 are not defined in this
# file view; presumably they are encoder feature maps and their static
# TensorShapes (NHWC, since channels are read from index 3) — confirm.
# Kernels use the tf.nn.conv2d_transpose filter layout
# [height, width, out_channels, in_channels], so each bias has out_channels
# entries. Every stage uses the helper's default stride of 2, i.e. it assumes
# each skip layer is about twice the spatial size of the layer upsampled.
# ---------------------------------------------------------------------------

# Stage 19 -> 14: upsample layer_19 to layer_14's runtime shape, add the skip,
# then batch-norm + ReLU.
kernel_19 = weight_variable(shape=[3, 3, shape_14[3].value, shape_19[3].value], name="kernel_19")
bais_19 = bias_variable(shape=[shape_14[3].value], name="bais_19")
up_19 = conv2d_transpose_strided(layer_19, kernel_19, bais_19, output_shape=tf.shape(layer_14))
add_19_14 = tf.add(up_19, layer_14, name="add_19_14")
# NOTE(review): tf.layers.batch_normalization defaults to training=False, so it
# normalises with the (initial) moving statistics rather than batch statistics.
# Acceptable for this shape test; revisit before using the graph for training.
bn_19_14 = tf.layers.batch_normalization(add_19_14, name="bn_19_14")
out_19_14 = tf.nn.relu(bn_19_14)

# Stage 14 -> 7. Variable names follow the "kernel_<n>" / "bais_<n>" pattern of
# the other stages (previously "kernel14" and the copy-paste typo "bais_17").
kernel_14 = weight_variable(shape=[3, 3, shape_7[3].value, shape_14[3].value], name="kernel_14")
bais_14 = bias_variable(shape=[shape_7[3].value], name="bais_14")
up_14 = conv2d_transpose_strided(out_19_14, kernel_14, bais_14, output_shape=tf.shape(layer_7))
add_14_7 = tf.add(up_14, layer_7, name="add_14_7")
bn_14_7 = tf.layers.batch_normalization(add_14_7, name="bn_14_7")
out_14_7 = tf.nn.relu(bn_14_7)

# Stage 7 -> 4.
kernel_7 = weight_variable(shape=[3, 3, shape_4[3].value, shape_7[3].value], name="kernel_7")
bais_7 = bias_variable(shape=[shape_4[3].value], name="bais_7")
up_7 = conv2d_transpose_strided(out_14_7, kernel_7, bais_7, output_shape=tf.shape(layer_4))
add_7_4 = tf.add(up_7, layer_4, name="add_7_4")
bn_7_4 = tf.layers.batch_normalization(add_7_4, name="bn_7_4")
out_7_4 = tf.nn.relu(bn_7_4)

# Stage 4 -> 2.
kernel_4 = weight_variable(shape=[3, 3, shape_2[3].value, shape_4[3].value], name="kernel_4")
bais_4 = bias_variable(shape=[shape_2[3].value], name="bais_4")
up_4 = conv2d_transpose_strided(out_7_4, kernel_4, bais_4, output_shape=tf.shape(layer_2))
add_4_2 = tf.add(up_4, layer_2, name="add_4_2")
bn_4_2 = tf.layers.batch_normalization(add_4_2, name="bn_4_2")
out_4_2 = tf.nn.relu(bn_4_2)

# Final stage: upsample to twice layer_2's spatial size with 2 output channels,
# used below as per-pixel logits for two classes.
# NOTE(review): this static output shape requires shape_2's batch, height and
# width to be known at graph-construction time.
kernel_2 = weight_variable(shape=[3, 3, 2, shape_2[3].value], name="kernel_2")
bais_2 = bias_variable(shape=[2], name="bais_2")
up_2 = conv2d_transpose_strided(out_4_2, kernel_2, bais_2, output_shape=[shape_2[0].value, shape_2[1].value * 2, shape_2[2].value * 2, 2])
# Predicted class index per pixel (argmax over the channel axis).
masks = tf.argmax(up_2, axis=3, name="mask")

# NOTE(review): neither global_step nor update_ops is consumed anywhere in this
# file view; update_ops (batch-norm moving-average updates) would normally be
# attached as control dependencies of a training op.
global_step = tf.train.get_or_create_global_step()
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)

# Mean per-pixel cross-entropy between the 2-class logits and the mask labels.
# For a single-channel mask, annotation is (1, H, W, 1); squeezing axis 3 gives
# (1, H, W) integer labels, which must match up_2's spatial size.
# NOTE(review): every mask pixel must be a class index in [0, 2); a mask stored
# as 0/255 would need remapping first — confirm the PNG's value range.
loss = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(
    logits=up_2,
    labels=tf.squeeze(annotation, axis=[3]),
    name="entropy"))
def _fetch_and_print(sess, tensor, template):
    """Evaluate ``tensor`` in ``sess``, print its shape through ``template``, return the value."""
    value = sess.run(tensor)
    print(template.format(value.shape))
    return value


# Evaluate the graph one tensor at a time (a separate sess.run per stage) so
# that, if some stage breaks with a shape mismatch, the shapes of all earlier
# stages have already been printed.
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print("layer_19.shape: {}".format(layer_19.get_shape()))

    up_19_value = _fetch_and_print(sess, up_19, "up_19.shape: {}")
    add_19_14_value = _fetch_and_print(sess, add_19_14, "add_19_14_value.shape: {}")
    out_19_14_value = _fetch_and_print(sess, out_19_14, "out_19_14_value.shape: {}")

    up_14_value = _fetch_and_print(sess, up_14, "up_14.shape: {}")
    add_14_7_value = _fetch_and_print(sess, add_14_7, "add_14_7_value.shape: {}")
    out_14_7_value = _fetch_and_print(sess, out_14_7, "out_14_7_value.shape: {}")

    up_7_value = _fetch_and_print(sess, up_7, "up_7.shape: {}")
    add_7_4_value = _fetch_and_print(sess, add_7_4, "add_7_4_value.shape: {}")
    out_7_4_value = _fetch_and_print(sess, out_7_4, "out_7_4_value.shape: {}")

    up_4_value = _fetch_and_print(sess, up_4, "up_4.shape: {}")
    add_4_2_value = _fetch_and_print(sess, add_4_2, "add_4_2_value.shape: {}")
    out_4_2_value = _fetch_and_print(sess, out_4_2, "out_4_2_value.shape: {}")

    up_2_value = _fetch_and_print(sess, up_2, "up_2.shape: {}")
    mask_value = _fetch_and_print(sess, masks, "mask_value.shape: {}")
    # Row 50 of the first image's predicted mask (assumes height > 50).
    print(mask_value[0, 50, :])

    loss_value = sess.run(loss)
    print("loss_value: {}".format(loss_value))
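# conv2d_transpose()測試 (conv2d_transpose() test)
# 發表評論 (Post a comment)
# 所有評論 (All comments)
# 還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
# (No comments yet — want to be the first? Type in the comment box above and click Publish.)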