简介
本文介绍如何将文本文件转换成 TFRecord 格式文件(并按比例随机划分为训练集和测试集),以及如何从 TFRecord 文件中按批读取样本。
文件格式
每行包含 output_size + input_size 个以制表符(\t)分隔的浮点数:前 output_size 个浮点数表示输出(标签),后 input_size 个浮点数表示输入(特征)。
参数说明
- input_file: 输入的文本文件路径
- train_file: 输出的训练集 TFRecord 文件路径
- test_file: 输出的测试集 TFRecord 文件路径
- input_size: 输入层大小
- output_size: 输出层大小
- test_data_ratio: 测试集占比(默认 0.2),每个样本以该概率被随机写入测试文件,否则写入训练文件
示例代码
def _parse_sample_line(line, input_size, output_size):
    """Parse one tab-separated line into ``(label, features)`` float lists.

    Returns ``None`` when the line should be skipped: it is blank, contains
    a field that is not a float, holds fewer than
    ``input_size + output_size`` values, or its values sum to NaN.
    """
    stripped = line.strip()
    if not stripped:
        return None
    try:
        values = [float(field) for field in stripped.split("\t")]
    except ValueError as e:
        # A malformed line must not abort the whole conversion.
        print(e)
        return None
    if len(values) < input_size + output_size:
        return None
    # Drop samples that contain NaN (the sum of the row is NaN in that case).
    if math.isnan(sum(values)):
        return None
    # NOTE: values beyond input_size + output_size are not truncated; they
    # stay part of the feature vector.
    return values[:output_size], values[output_size:]


def encode_to_tfrecords(input_file, train_file, test_file, input_size, output_size, test_data_ratio=0.2):
    """Convert a tab-separated text file into train/test TFRecord files.

    Every usable line starts with ``output_size`` label values followed by
    the feature values.  Each sample is stored as a ``tf.train.Example``
    with two float-list features, ``'label'`` and ``'features'``, and is
    written to ``test_file`` when a fresh ``random.random()`` draw is below
    ``test_data_ratio``, otherwise to ``train_file``.

    Args:
        input_file: Path of the text file to read.
        train_file: Path of the TFRecord file receiving training samples.
        test_file: Path of the TFRecord file receiving test samples.
        input_size: Number of feature values expected per line.
        output_size: Number of label values at the start of each line.
        test_data_ratio: Probability of routing a sample to the test file.
    """
    # Context managers close both writers and the input file even when
    # reading or writing fails part-way through.
    with tf.python_io.TFRecordWriter(train_file) as train_writer, \
            tf.python_io.TFRecordWriter(test_file) as test_writer, \
            open(input_file, "r") as reader:
        for line in reader:
            sample = _parse_sample_line(line, input_size, output_size)
            if sample is None:
                continue
            y, x = sample
            try:
                label = tf.train.Feature(float_list=tf.train.FloatList(value=y))
                features = tf.train.Feature(float_list=tf.train.FloatList(value=x))
                example = tf.train.Example(
                    features=tf.train.Features(
                        feature={
                            'label': label,
                            'features': features
                        }
                    )
                )
                ratio = random.random()
                if ratio < test_data_ratio:
                    test_writer.write(example.SerializeToString())
                else:
                    train_writer.write(example.SerializeToString())
            except Exception as e:
                # Best effort: report the failing sample and keep converting.
                print(e)
# Write every (label, features) pair to the training TFRecord file.
# 'label' is stored as a float list and 'features' as an int64 list.
# The context manager closes the writer even if a sample fails to serialize.
with tf.python_io.TFRecordWriter("input/train.tfr") as train_writer:
    # Indexed access (rather than zip) so that a train[0] shorter than
    # labels raises instead of being silently truncated.
    for index in range(len(labels)):
        label = tf.train.Feature(float_list=tf.train.FloatList(value=labels[index]))
        features = tf.train.Feature(int64_list=tf.train.Int64List(value=train[0][index]))
        example = tf.train.Example(
            features=tf.train.Features(
                feature={
                    'label': label,
                    'features': features
                }
            )
        )
        train_writer.write(example.SerializeToString())
def get_batch_sample(filename, batch_size, input_size=25, output_size=7):
    """Yield shuffled ``(features, label)`` batch tensors from a TFRecord file.

    The input pipeline (filename queue, record reader, example parser and
    shuffle-batch op) is built once, on the first ``next()`` call; every
    iteration then yields the same pair of tensors.

    Args:
        filename: Path of the TFRecord file to read.
        batch_size: Number of samples per batch.
        input_size: Length of the int64 ``'features'`` vector of each record.
        output_size: Length of the float32 ``'label'`` vector of each record.

    Yields:
        ``(features, label)`` tensors whose static shapes are set to
        ``[batch_size, input_size]`` and ``[batch_size, output_size]``.

    NOTE(review): ``'features'`` is parsed as int64; a file whose
    ``'features'`` were written as a float list would presumably fail to
    parse with this spec — confirm against the writer that produced the file.
    NOTE(review): these are queue-based input ops, so the caller presumably
    has to start the queue runners before evaluating the tensors — confirm.
    """
    filename_queue = tf.train.string_input_producer([filename])
    reader = tf.TFRecordReader()
    _, serialized_sample = reader.read(filename_queue)
    sample = tf.parse_single_example(
        serialized_sample,
        features={
            'label': tf.FixedLenFeature([output_size], tf.float32),
            'features': tf.FixedLenFeature([input_size], tf.int64),
        })

    # Build the batching op exactly once. Rebuilding it inside the loop from
    # its own output would feed already-batched tensors back into
    # shuffle_batch on the second iteration, and would add new queue ops to
    # the graph on every iteration.
    batch_features, batch_label = tf.train.shuffle_batch(
        [sample['features'], sample['label']],
        batch_size=batch_size,
        capacity=5000,
        min_after_dequeue=1000,
        allow_smaller_final_batch=True,
        num_threads=10)
    batch_features.set_shape([batch_size, input_size])
    batch_label.set_shape([batch_size, output_size])

    while True:
        yield batch_features, batch_label