import tensorflow as tf
import pandas as pd
import time
# Module-level TensorFlow session; create_dataset() evaluates its iterator
# tensors through this global via sess.run(...).
sess = tf.InteractiveSession()
def gen(path='../dataset/train.csv'):
    """Yield one ``(id, tokens, label)`` tuple per row of a CSV file.

    The CSV is expected to have exactly three columns, in the order
    id, tokens, label; a different column count raises ``ValueError``
    when the first row is unpacked.

    Args:
        path: Location of the CSV file. Defaults to the training set so the
            function can still be called with no arguments (as
            ``tf.data.Dataset.from_generator`` does).

    Yields:
        Tuples ``(row_id, tokens, label)`` in file order.
    """
    csv_data = pd.read_csv(path)
    # itertuples walks the frame once instead of building a Series per row
    # through iloc; name=None gives plain tuples, index=False drops the index.
    for row_id, tokens, label in csv_data.itertuples(index=False, name=None):
        yield (row_id, tokens, label)
def create_dataset(batch_size=32):
    """Yield batched ``(ids, tokens, labels)`` triples produced by ``gen``.

    Wraps ``gen`` in a ``tf.data`` pipeline (batch, then prefetch), evaluates
    it through the module-level ``sess`` and hands each evaluated batch back
    to the caller until the pipeline is exhausted.

    Args:
        batch_size: Number of rows per batch. Defaults to 32.

    Yields:
        ``(ids, tokens, labels)`` exactly as returned by ``sess.run`` for the
        three iterator outputs. The final batch may hold fewer than
        ``batch_size`` rows because ``batch`` is called without dropping the
        remainder.
    """
    dataset = tf.data.Dataset.from_generator(gen, (tf.int32, tf.string, tf.string))
    dataset = dataset.batch(batch_size)
    dataset = dataset.prefetch(32)
    iterator = dataset.make_one_shot_iterator()
    ids_op, tokens_op, labels_op = iterator.get_next()
    while True:
        try:
            ids, tokens, labels = sess.run([ids_op, tokens_op, labels_op])
        except tf.errors.OutOfRangeError:
            # End of the one-shot iterator. Any other failure (bad CSV path,
            # dtype mismatch, Ctrl-C) is deliberately allowed to propagate
            # instead of being mistaken for end-of-data.
            return
        yield ids, tokens, labels
# Drain the batched dataset and report how many rows each batch contains.
# NOTE(review): the original indentation was lost; the separator line is
# assumed to be printed once per batch (inside the loop) -- confirm.
data = create_dataset()
for ids, tokens, label in data:
    # The leading dimension of the id array is the batch's row count.
    print(ids.shape[0])
    print('----------------')