JobPlus知识库 IT 工业智能4.0 文章
TensorFlow-Slim准备自己的数据集

TF-slim是一个新的TensorFlow轻量级高级API,可以用来搭建、训练和验证网络模型。最近由于项目需要,我在学习使用该库。其datasets目录中包含下载并转换标准数据集的代码,下面介绍如何在这些标准代码的基础上准备自己的数据进行训练。
1. 准备自己的数据 
将要训练和测试的数据分别放在train和test文件夹下,文件夹下是以标签名命名的各子类数据,如下图所示

2.在datasets下创建自己待训练数据的脚本,比如我这里命名为car,则相应的准备car.py 和 download_and_convert_car.py两个脚本。 
car.py的脚本和datasets文件夹下的flowers.py等标准数据集脚本基本一样,只需要更改文件名模式(_FILE_PATTERN)、各划分的样本数(SPLITS_TO_SIZES)和类别数(_NUM_CLASSES)。

# Pattern of the TFRecord shard file names; %s is replaced by the split name
# ('train' or 'validation').
_FILE_PATTERN = 'car_%s_*.tfrecord'

# Number of samples in each split -- change these to the sizes of your own
# train / validation sets.
SPLITS_TO_SIZES = {'train': 12973, 'validation': 3200}

# Number of classes in the dataset -- change to match your own data.
_NUM_CLASSES = 3


download_and_convert_car.py和对应的标准脚本(例如download_and_convert_flowers.py)有区别:由于数据已经按train和test文件夹准备好,不需要下载和划分数据,只需要做数据转换即可,下面贴出代码。

#!/usr/bin/env python2

# -*- coding: utf-8 -*-

"""

Created on Wed May 30 09:53:21 2018 

 @author: liuli 

"""


from __future__ import absolute_import

from __future__ import division

from __future__ import print_function


import math

import os

import random

import sys


import tensorflow as tf


from datasets import dataset_utils


# Seed handed to random.seed() before the file lists are shuffled, so the
# shuffle is repeatable from run to run.

_RANDOM_SEED = 0


# Number of TFRecord shard files written for each dataset split.

_NUM_SHARDS = 5


class ImageReader(object):
  """Helper class that provides TensorFlow image coding utilities.

  The JPEG decode op is built once in the constructor and reused by every
  call to `decode_jpeg`.
  """

  def __init__(self):
    # Placeholder that is fed the raw (still encoded) image bytes.
    self._decode_jpeg_data = tf.placeholder(dtype=tf.string)
    # Op that decodes the fed bytes into a 3-channel image.
    self._decode_jpeg = tf.image.decode_jpeg(self._decode_jpeg_data, channels=3)

  def read_image_dims(self, sess, image_data):
    """Returns the first two dimensions of the decoded image.

    Args:
      sess: The session used to run the decode op.
      image_data: The encoded image bytes.

    Returns:
      A tuple `(image.shape[0], image.shape[1])`, which the caller treats as
      `(height, width)`.
    """
    image = self.decode_jpeg(sess, image_data)
    return image.shape[0], image.shape[1]

  def decode_jpeg(self, sess, image_data):
    """Decodes the encoded bytes into a rank-3 image with 3 channels.

    Args:
      sess: The session used to run the decode op.
      image_data: The encoded image bytes.

    Returns:
      The decoded image array as returned by `sess.run`.

    Raises:
      AssertionError: If the decoded result is not rank 3 with 3 channels.
    """
    image = sess.run(self._decode_jpeg,
                     feed_dict={self._decode_jpeg_data: image_data})
    assert len(image.shape) == 3
    assert image.shape[2] == 3
    return image


def _get_filenames_and_classes(dataset_dir):
  """Returns the image paths and the sorted class names found in a directory.

  Args:
    dataset_dir: A directory containing one sub-directory per class; every
      entry inside a class sub-directory is treated as an image of that class.

  Returns:
    A tuple `(photo_filenames, class_names)`: the paths of all entries inside
    the class sub-directories, and the sorted names of those sub-directories.
    Entries of `dataset_dir` that are not directories are ignored.
  """
  # os.listdir() returns entries in arbitrary order. Sorting both listings
  # makes the returned file list deterministic, so that a seeded shuffle of it
  # gives the same result on every run.
  class_names = sorted(
      name for name in os.listdir(dataset_dir)
      if os.path.isdir(os.path.join(dataset_dir, name)))

  photo_filenames = []
  for class_name in class_names:
    directory = os.path.join(dataset_dir, class_name)
    for filename in sorted(os.listdir(directory)):
      photo_filenames.append(os.path.join(directory, filename))

  return photo_filenames, class_names


def _get_dataset_filename(dataset_dir, split_name, shard_id):
  """Builds the path of one TFRecord shard file.

  Args:
    dataset_dir: The directory the shard file lives in.
    split_name: The split the shard belongs to, e.g. 'train'.
    shard_id: The zero-based index of the shard.

  Returns:
    `dataset_dir` joined with a name of the form
    'car_<split>_<shard>-of-<total>.tfrecord', both numbers zero-padded to
    five digits.
  """
  name_fields = (split_name, shard_id, _NUM_SHARDS)
  shard_basename = 'car_%s_%05d-of-%05d.tfrecord' % name_fields
  return os.path.join(dataset_dir, shard_basename)


def _convert_dataset(split_name, filenames, class_names_to_ids, dataset_dir):
  """Converts the given filenames to a TFRecord dataset.

  The files are split into `_NUM_SHARDS` consecutive chunks and each chunk is
  written to its own TFRecord file inside `dataset_dir`.

  Args:
    split_name: The name of the dataset, either 'train' or 'validation'.
    filenames: A list of absolute paths to png or jpg images. The name of each
      file's parent directory is used as its class name.
    class_names_to_ids: A dictionary from class names (strings) to ids
      (integers).
    dataset_dir: The directory where the converted datasets are stored.

  Raises:
    AssertionError: If `split_name` is not 'train' or 'validation'.
    KeyError: If a file's parent directory name is not a key of
      `class_names_to_ids`.
  """
  # NOTE(review): every image is decoded with the JPEG decoder and recorded
  # with format b'jpg' regardless of its extension -- confirm that png inputs
  # are really supported.
  assert split_name in ['train', 'validation']

  num_files = len(filenames)
  num_per_shard = int(math.ceil(num_files / float(_NUM_SHARDS)))

  with tf.Graph().as_default():
    image_reader = ImageReader()

    with tf.Session('') as sess:
      for shard_id in range(_NUM_SHARDS):
        output_filename = _get_dataset_filename(
            dataset_dir, split_name, shard_id)

        with tf.python_io.TFRecordWriter(output_filename) as tfrecord_writer:
          start_ndx = shard_id * num_per_shard
          # The last shard may be shorter (or empty) when the files do not
          # divide evenly.
          end_ndx = min((shard_id + 1) * num_per_shard, num_files)
          for i in range(start_ndx, end_ndx):
            # '\r' rewrites the same console line to show progress.
            sys.stdout.write('\r>> Converting image %d/%d shard %d' % (
                i+1, num_files, shard_id))
            sys.stdout.flush()

            # Read the encoded bytes; the context manager closes the file
            # handle instead of leaving one open per image.
            with tf.gfile.FastGFile(filenames[i], 'rb') as image_file:
              image_data = image_file.read()
            height, width = image_reader.read_image_dims(sess, image_data)

            # The class label is the name of the directory holding the image.
            class_name = os.path.basename(os.path.dirname(filenames[i]))
            class_id = class_names_to_ids[class_name]

            example = dataset_utils.image_to_tfexample(
                image_data, b'jpg', height, width, class_id)
            tfrecord_writer.write(example.SerializeToString())

  sys.stdout.write('\n')
  sys.stdout.flush()

# Hard-coded locations of the raw images; each is expected to contain one
# sub-directory per class. Edit these to point at your own data.
# NOTE(review): the paths mention 'flower_data' although the dataset is named
# 'car' -- confirm they point at the intended images.
train_data_dir = '/home/liuli/work/Tensorflow/flower_data/raw-data/train'

test_data_dir = '/home/liuli/work/Tensorflow/flower_data/raw-data/validation'


def run(dataset_dir):
  """Runs the conversion operation.

  Reads the images under the module-level `train_data_dir` and
  `test_data_dir`, writes them as 'train' and 'validation' TFRecord shards
  into `dataset_dir`, and writes the label file next to them.

  Args:
    dataset_dir: The dataset directory where the dataset is stored.
  """
  # Create the output directory up front so the shard files can be written
  # into it.
  if not tf.gfile.Exists(dataset_dir):
    tf.gfile.MakeDirs(dataset_dir)

  training_filenames, class_names = _get_filenames_and_classes(train_data_dir)
  # Class ids follow the order of the sorted training class names.
  class_names_to_ids = dict(zip(class_names, range(len(class_names))))

  random.seed(_RANDOM_SEED)
  random.shuffle(training_filenames)

  # Only the training class names define the label mapping; the class names
  # found under the validation directory are discarded.
  validation_filenames, _ = _get_filenames_and_classes(test_data_dir)
  random.shuffle(validation_filenames)

  _convert_dataset('train', training_filenames, class_names_to_ids,
                   dataset_dir)
  _convert_dataset('validation', validation_filenames, class_names_to_ids,
                   dataset_dir)

  # Write the id -> class name mapping alongside the TFRecord files.
  labels_to_class_names = dict(zip(range(len(class_names)), class_names))
  dataset_utils.write_label_file(labels_to_class_names, dataset_dir)

  print('\nFinished converting the car dataset!')


3.在download_and_convert_data.py 69行main函数中加入dataset_name选择代码

elif FLAGS.dataset_name == 'car':

    download_and_convert_car.run(FLAGS.dataset_dir)


4.在dataset_factory.py的datasets_map中相应插入自己训练数据的键值对

from datasets import car

import os

import tensorflow as tf 

slim = tf.contrib.slim 


# Maps a dataset name to the module that defines that dataset.
datasets_map = {
    'cifar10': cifar10,
    'flowers': flowers,
    'imagenet': imagenet,
    'mnist': mnist,
    # Entry for the custom dataset defined in datasets/car.py.
    'car': car,
}


5.创建生成数据的脚本

# Directory that receives the generated TFRecord files.
DATASET_DIR=/home/liuli/work/Tensorflow/cars

# Each trailing backslash must be followed directly by the next argument line;
# a blank line in between would end the command early.
python download_and_convert_data.py \
    --dataset_name=car \
    --dataset_dir="${DATASET_DIR}"


就可以在DATASET_DIR文件夹下生成对应的TFrecord格式的数据 


如果觉得我的文章对您有用,请随意打赏。您的支持将鼓励我继续创作!

¥ 打赏支持
21人赞 举报
分享到
用户评价(0)

暂无评价,你也可以发布评价哦:)

扫码APP

扫描使用APP

扫码使用

扫描使用小程序