level 1
楼主在学习 TFRecords 时,先按照《实战Google深度学习框架》的写法将 MNIST 数据集导成 TFRecords,其中 label 对应的 feature 映射为 int64(即 Int64List)。生成完整的 TFRecords 后再去读这个文件,发现用了很长时间。然后试着将 label 对应的 feature 改成映射为 bytes(即 BytesList),重新生成 TFRecords 再去读新文件,用时不到原来的一半。进一步实验,加入更多 int64 与 bytes 字段做对比,发现时间差距越来越大。那么是不是说 tf 读取含 Int64List 的 TFRecord 比读取含 BytesList 的慢?如果是这样,Int64List 不就没什么用了?还是说哪里弄错了?
2017年03月30日 13点03分
1
level 1
tfwrite.py
#!/usr/bin/python
# -*- coding:utf-8 -*-
# This code generates the TFRecords files for the MNIST set.
# We divide it into three parts: train, validation and test.
import tensorflow as tf
import numpy as np
from tensorflow.examples.tutorials.mnist import input_data
# Load the MNIST data sets from data/MNIST/ as uint8 with one-hot labels.
mnist = input_data.read_data_sets(
    "data/MNIST/",
    one_hot=True,
    dtype=tf.uint8,
)
# Helpers that build the features stored in a TFRecord
def _int64_feature(value):
    """Wrap a single value in a tf.train.Feature backed by an Int64List."""
    int64_list = tf.train.Int64List(value=[value])
    return tf.train.Feature(int64_list=int64_list)
def _bytes_feature(value):
    """Wrap a single value in a tf.train.Feature backed by a BytesList."""
    bytes_list = tf.train.BytesList(value=[value])
    return tf.train.Feature(bytes_list=bytes_list)
def oper(value):
    """Write one MNIST split to ./data/MNIST/<value>.tfrecords.

    Each example stores three features: 'pixels' (int64, images.shape[1]),
    'label' (int64, the argmax of the one-hot label row) and 'image_raw'
    (the raw bytes of the image array).

    Args:
        value: Name of the attribute on the module-level ``mnist`` object
            that holds the split; main() passes "train", "validation"
            and "test".
    """
    # getattr performs the same attribute lookup as the former eval() calls
    # without executing a string built from the argument as code.
    dataset = getattr(mnist, value)
    images = dataset.images
    labels = dataset.labels
    pixels = images.shape[1]
    num_examples = dataset.num_examples
    filename = "./data/MNIST/" + value + ".tfrecords"

    # The with block closes the writer even if building an example fails.
    with tf.python_io.TFRecordWriter(filename) as writer:
        for index in range(num_examples):
            # tobytes() is the current name of the deprecated tostring().
            image_raw = images[index].tobytes()
            example = tf.train.Example(features=tf.train.Features(feature={
                'pixels': _int64_feature(pixels),
                'label': _int64_feature(labels[index].argmax()),
                'image_raw': _bytes_feature(image_raw),
            }))
            writer.write(example.SerializeToString())
def main(argv=None):
    """Write the train, validation and test splits to TFRecords files.

    Args:
        argv: Unused.
    """
    for split in ("train", "validation", "test"):
        oper(split)
# Script entry point.
if __name__ == "__main__":
    tf.app.run()
**************
tfread.py
#!/usr/bin/python
# -*- coding: utf-8 -*-
# Read the information in the TFRecords
import tensorflow as tf
import numpy as np
# num_epochs = 10
# Input pipeline: queue the matching TFRecords files (shuffled, one epoch)
# and read one serialized example at a time.
files = tf.train.match_filenames_once("./data/MNIST/*.tfrecords")
filename_queue = tf.train.string_input_producer(
    files,
    num_epochs=1,
    shuffle=True,
)
reader = tf.TFRecordReader()
_, serialized_example = reader.read(filename_queue)

# Parse the three scalar features of each example.
features = tf.parse_single_example(
    serialized_example,
    features={
        'label': tf.FixedLenFeature([], tf.int64),
        'pixels': tf.FixedLenFeature([], tf.int64),
        'image_raw': tf.FixedLenFeature([], tf.string),
    },
)
label = features['label']
pixel = features['pixels']
# Decode the raw image bytes as uint8 values.
image = tf.decode_raw(features['image_raw'], tf.uint8)
def main(argc=None):
    """Read examples from the input pipeline into in-memory lists.

    Fetches the module-level image, label and pixel tensors one example per
    sess.run() call until the coordinator asks to stop or
    tf.errors.OutOfRangeError is raised, then stops and joins the
    queue-runner threads. The collected values are not returned.

    Args:
        argc: Unused.
    """
    with tf.Session() as sess:
        # tf.train.match_filenames_once needs variables to be initialized.
        tf.global_variables_initializer().run()
        tf.local_variables_initializer().run()

        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)

        # Running count of the examples read so far.
        num_read = 0
        label_values = []
        pixel_values = []
        image_values = []
        try:
            while not coord.should_stop():
                fetched_image, fetched_label, fetched_pixel = sess.run(
                    [[image], label, pixel])
                num_read += 1
                label_values.append(fetched_label)
                pixel_values.append(fetched_pixel)
                image_values.append(fetched_image)
        except tf.errors.OutOfRangeError:
            print("Done input data")
        finally:
            coord.request_stop()
            coord.join(threads)
# Script entry point.
if __name__ == "__main__":
    tf.app.run()
***********
setup.py
#!/usr/bin/python
# -*- coding:utf-8 -*-
import tfread
#import tfwrite
# Script entry point: run the TFRecords reader.
if __name__ == "__main__":
    tfread.main()
2017年03月30日 13点03分
2