JobPlus知识库 IT 工业智能4.0 文章
TensorFlow的数据导入方法。

简介

本文介绍TensorFlow的第二种数据导入方法。

为了保持高效,这种方法稍显繁琐。分为如下几个步骤: 
- 把所有样本写入二进制文件(只执行一次) 
- 创建Tensor,从二进制文件读取一个样本 
- 创建Tensor,从二进制文件随机读取一个mini-batch 
- 把mini-batchTensor传入网络作为输入节点。

二进制文件

使用tf.python_io.TFRecordWriter创建一个专门存储tensorflow数据的writer,扩展名为’.tfrecord’。 
该文件中依次存储着序列化的tf.train.Example类型的样本。

writer = tf.python_io.TFRecordWriter('/tmp/data.tfrecord')

for i in range(0, 10):

    # 创建样本example

    # ...

    serialized = example.SerializeToString()   # 序列化

    writer.write(serialized)    # 写入文件

writer.close()


每一个example的feature成员变量是一个dict,存储一个样本的不同部分(例如图像像素+类标)。以下例子的样本中包含三个键a,b,c:

   # 创建样本example

    a_data = 0.618 + i         # float

    b_data = [2016 + i, 2017+i]     # int64

    c_data = numpy.array([[0, 1, 2],[3, 4, 5]]) + i    # bytes

    c_data = c_data.astype(numpy.uint8)

    c_raw = c.tostring()             # 转化成字符串

    example = tf.train.Example(

        features=tf.train.Features(

            feature={

                'a': tf.train.Feature(

                    float_list=tf.train.FloatList(value=[a_data])   # 方括号表示输入为list

                ),

                'b': tf.train.Feature(

                    int64_list=tf.train.Int64List(value=b_data)    # b_data本身就是列表

                ),

                'c': tf.train.Feature(

                    bytes_list=tf.train.BytesList(value=[c_raw])

                )

            }

        )

    )


dict成员的值部分接受三种类型数据: 
- tf.train.FloatList:列表每个元素为float。例如a。 
- tf.train.Int64List:列表每个元素为int64。例如b。 
- tf.train.BytesList:列表每个元素为string。例如c。

第三种类型尤其适合图像样本。注意在转成字符串之前要设定为uint8类型。

读取一个样本

接下来,我们定义一个函数,创建“从文件中读一个样本”操作,返回结果Tensor。

def read_single_sample(filename):

    # 读取样本example的每个成员a,b,c

    # ...

    return a, b, c


首先创建读文件队列,使用tf.TFRecordReader从文件队列读入一个序列化的样本。

   # 读取样本example的每个成员a,b,c

    filename_queue = tf.train.string_input_producer([filename], num_epochs=None)    # 不限定读取数量

    reader = tf.TFRecordReader()

    _, serialized_example = reader.read(filename_queue)


如果样本量很大,可以分成若干文件,把文件名列表传入tf.train.string_input_producer。 
和刚才的writer不同,这个reader是符号化的,只有在sess中run才会执行。

接下来解析符号化的样本

   # get feature from serialized example

    features = tf.parse_single_example(

        serialized_example,

        features={

            'a': tf.FixedLenFeature([], tf.float32),    #0D, 标量

            'b': tf.FixedLenFeature([2], tf.int64),   # 1D,长度为2

            'c': tf.FixedLenFeature([], tf.string)  # 0D, 标量

        }

    )

    a = features['a']

    b = features['b']

    c_raw = features['c']

    c = tf.decode_raw(c_raw, tf.uint8)

    c = tf.reshape(c, [2, 3])


对于BytesList,要重新进行解码,把string类型的0维Tensor变成uint8类型的1维Tensor。

读取mini-batch

使用tf.train.shuffle_batch将前述a,b,c随机化,获得mini-batchTensor:

a_batch, b_batch, c_batch = tf.train.shuffle_batch([a, b, c], batch_size=2, capacity=200, min_after_dequeue=100, num_threads=2)


使用

创建一个session并初始化:

# sess

sess = tf.Session()

init = tf.initialize_all_variables() 

sess.run(init) 

tf.train.start_queue_runners(sess=sess)


由于使用了读文件队列,所以要start_queue_runners。

每一次运行,会随机生成一个mini-batch样本:

a_val, b_val, c_val = sess.run([a_batch, b_batch, c_batch]) 

a_val, b_val, c_val = sess.run([a_batch, b_batch, c_batch])


这样的mini-batch可以作为网络的输入节点使用。

总结

完整代码如下:

import tensorflow as tfimport numpydef write_binary():

    writer = tf.python_io.TFRecordWriter('/tmp/data.tfrecord')

    for i in range(0, 2):

        a = 0.618 + i

        b = [2016 + i, 2017+i]

        c = numpy.array([[0, 1, 2],[3, 4, 5]]) + i

        c = c.astype(numpy.uint8)

        c_raw = c.tostring()

        example = tf.train.Example(

            features=tf.train.Features(

                feature={

                    'a': tf.train.Feature(

                        float_list=tf.train.FloatList(value=[a])

                    ),

                    'b': tf.train.Feature(

                        int64_list=tf.train.Int64List(value=b)

                    ),

                    'c': tf.train.Feature(

                        bytes_list=tf.train.BytesList(value=[c_raw])

                    )

                }

            )

        )

        serialized = example.SerializeToString()

        writer.write(serialized)

    writer.close()def read_single_sample(filename):

    # output file name string to a queue

    filename_queue = tf.train.string_input_producer([filename], num_epochs=None)

    # create a reader from file queue

    reader = tf.TFRecordReader()

    _, serialized_example = reader.read(filename_queue)

    # get feature from serialized example

    features = tf.parse_single_example(

        serialized_example,

        features={

            'a': tf.FixedLenFeature([], tf.float32),

            'b': tf.FixedLenFeature([2], tf.int64),

            'c': tf.FixedLenFeature([], tf.string)

        }

    )

    a = features['a']

    b = features['b']

    c_raw = features['c']

    c = tf.decode_raw(c_raw, tf.uint8)

    c = tf.reshape(c, [2, 3])

    return a, b, c

#-----main function-----if 1:

    write_binary()else:

    # create tensor

    a, b, c = read_single_sample('/tmp/data.tfrecord')

    a_batch, b_batch, c_batch = tf.train.shuffle_batch([a, b, c], batch_size=3, capacity=200, min_after_dequeue=100,

 num_threads=2)

    queues = tf.get_collection(tf.GraphKeys.QUEUE_RUNNERS)

    # sess

    sess = tf.Session()

    init = tf.initialize_all_variables()

    sess.run(init)

    tf.train.start_queue_runners(sess=sess)

    a_val, b_val, c_val = sess.run([a_batch, b_batch, c_batch])

    print(a_val, b_val, c_val)

    a_val, b_val, c_val = sess.run([a_batch, b_batch, c_batch])

    print(a_val, b_val, c_val)


如果觉得我的文章对您有用,请随意打赏。您的支持将鼓励我继续创作!

¥ 打赏支持
368人赞 举报
分享到
用户评价(0)

暂无评价,你也可以发布评价哦:)

扫码APP

扫描使用APP

扫码使用

扫描使用小程序