Dataset and Iterator in TensorFlow Data Processing

The complete machine learning implementation code is on GitHub.
Reprints are welcome; please credit the source.

0. Contents


1. TensorFlow's efficient Pipeline

2. Dataset and Iterator in TensorFlow data processing

3. Generating TFRecord files with TensorFlow

4. The principle of TensorFlow's Estimator in practice


1. Preface

When we train a model, the first step is always data processing. In machine learning there is a saying that the quality of the data processing directly determines the quality of the model's results, so it is a crucial step.

Today we focus on another aspect of data processing: speed. In deep learning the data volume easily reaches the GB level, so processing speed matters for training as well. It is common for data processing to take up most of the total training time.

This post introduces TensorFlow's officially recommended data-processing approach: the Dataset API, which supports reading from both memory and disk. Compared with the two older input methods, its syntax is more concise and easier to understand.


2. Dataset principle

Google's official documentation gives a class diagram for the Dataset API (not reproduced here).

2.1 Dataset creation method

The Dataset API provides four ways to create a dataset:

  • tf.data.Dataset.from_tensor_slices: reads data directly from memory in the form of an array, matrix, dict, etc.
dataset = tf.data.Dataset.from_tensor_slices(np.array([1.0, 2.0, 3.0, 4.0, 5.0]))
# Instantiate a one-shot iterator, which can read through the data only once
iterator = dataset.make_one_shot_iterator()
# Take the next element from the iterator
one_element = iterator.get_next()
with tf.Session() as sess:
    for i in range(5):
        print(sess.run(one_element))
  • tf.data.TFRecordDataset: as the name suggests, reads TFRecord files. Each element in the dataset is a TFExample.
# Creates a dataset that reads all of the examples from two files.
filenames = ["/var/data/file1.tfrecord", "/var/data/file2.tfrecord"]
dataset = tf.data.TFRecordDataset(filenames)
  • tf.data.TextLineDataset: takes a list of files as input and outputs a dataset in which each element corresponds to one line of a file. You can use this function to read CSV files.
filenames = ["/var/data/file1.txt", "/var/data/file2.txt"]
dataset = tf.data.TextLineDataset(filenames)
  • tf.data.FixedLengthRecordDataset: takes a list of files and a record_bytes value; each element of the dataset is a fixed-length record of record_bytes bytes. It is usually used for files saved in binary form, such as the CIFAR-10 dataset.
dataset = tf.data.FixedLengthRecordDataset(filenames, record_bytes)
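Outside TensorFlow, the fixed-length-record idea is easy to model. The following plain-Python sketch (the 4-byte record size and sample bytes are made up for illustration) splits a byte string into records of `record_bytes` bytes each, dropping any incomplete tail, which is what a fixed-length record reader does with each file's contents:

```python
def read_fixed_length_records(data, record_bytes):
    """Split a byte string into fixed-size records, mimicking what a
    fixed-length record reader does with one file's contents."""
    return [data[i:i + record_bytes]
            for i in range(0, len(data) - record_bytes + 1, record_bytes)]

records = read_fixed_length_records(b"aaaabbbbccccdd", 4)
# -> [b'aaaa', b'bbbb', b'cccc']  (the incomplete 2-byte tail is dropped)
```

A real CIFAR-10 record, for comparison, is 1 label byte plus 32*32*3 image bytes, i.e. 3073 bytes per record.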

2.2 Transformation of Dataset data

A Dataset is turned into a new Dataset through a Transformation. Transformations cover operations such as converting data, shuffling it, forming batches, and generating epochs. The commonly used transformations are:

  • map: receives a function. Each element of the dataset is passed to the function, and the return values form the new dataset. For example, we can add 1 to each element:
dataset = tf.data.Dataset.from_tensor_slices(np.array([1.0, 2.0, 3.0, 4.0, 5.0]))
dataset = dataset.map(lambda x: x + 1)  # 2.0, 3.0, 4.0, 5.0, 6.0
  • apply: applies a transformation function to the dataset as a whole.
dataset = dataset.apply(tf.contrib.data.group_by_window(key_func, reduce_func, window_size))
  • batch: combines consecutive elements into batches of the given size. For example, the following combines the elements of the dataset into batches of size 32.
dataset = dataset.batch(32)
  • shuffle: shuffles the elements of the dataset. Its buffer_size parameter sets the size of the buffer used for shuffling.
dataset = dataset.shuffle(buffer_size=10000)
  • repeat: repeats the whole sequence, mainly used to produce multiple epochs. If the original data is one epoch, repeat(5) turns it into five epochs.
dataset = dataset.repeat(5)
# Without an argument, repeat cycles through the data indefinitely
dataset = dataset.repeat()
  • padded_batch: pads the elements of the dataset to a common length when batching; padded_shapes gives the target shape of each component, and padding_values the value to pad with.
dataset = dataset.padded_batch(
    batch_size,
    padded_shapes=(
        tf.TensorShape([None]),  # src: variable length, padded per batch
        tf.TensorShape([])),     # src_len: a scalar, no padding needed
    padding_values=(
        src_eos_id,  # src is padded with the EOS id
        0))          # src_len -- unused
  • shard: slices the dataset across multiple GPUs or workers; each shard sees every num_shards-th element, starting at shard_index.
dataset = dataset.shard(num_shards, shard_index)
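Two of the transformations above have semantics worth spelling out. The following plain-Python sketch (an analogy, not TensorFlow's implementation) models shuffle's buffer-based randomization and padded_batch's per-batch padding:

```python
import random

def buffered_shuffle(items, buffer_size, rng=random):
    """Model of Dataset.shuffle: maintain a buffer of `buffer_size`
    elements and emit a randomly chosen one as each new element arrives."""
    buffer, out = [], []
    for item in items:
        buffer.append(item)
        if len(buffer) > buffer_size:
            out.append(buffer.pop(rng.randrange(len(buffer))))
    rng.shuffle(buffer)  # drain the remaining buffered elements
    out.extend(buffer)
    return out

def padded_batch(sequences, batch_size, pad_value=0):
    """Model of Dataset.padded_batch: group sequences into batches and
    pad each one to the longest length within its batch."""
    batches = []
    for i in range(0, len(sequences), batch_size):
        batch = sequences[i:i + batch_size]
        max_len = max(len(s) for s in batch)
        batches.append([s + [pad_value] * (max_len - len(s)) for s in batch])
    return batches

shuffled = buffered_shuffle(list(range(10)), buffer_size=3, rng=random.Random(0))
# A small buffer gives only local shuffling, but the output is always a
# permutation of the input.
batches = padded_batch([[1, 2], [3], [4, 5, 6], [7]], batch_size=2)
# -> [[[1, 2], [3, 0]], [[4, 5, 6], [7, 0, 0]]]
```

This also shows why buffer_size matters: with a buffer much smaller than the dataset, elements can only move a short distance from their original position.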

A more complete example of building a dataset:

def parse_fn(example):
  "Parse TFExample records and perform simple data augmentation."
  example_fmt = {
    "image": tf.FixedLenFeature((), tf.string, ""),
    "label": tf.FixedLenFeature((), tf.int64, -1)
  }
  parsed = tf.parse_single_example(example, example_fmt)
  image = tf.image.decode_image(parsed["image"])
  image = _augment_helper(image)  # augments the image using slice, reshape, resize_bilinear
  return image, parsed["label"]

# A simple input_fn
def input_fn():
  files = tf.data.Dataset.list_files("/path/to/dataset/train-*.tfrecord")
  dataset = files.interleave(tf.data.TFRecordDataset, cycle_length=4)
  dataset = dataset.shuffle(buffer_size=FLAGS.shuffle_buffer_size)
  dataset = dataset.map(map_func=parse_fn)
  dataset = dataset.batch(batch_size=FLAGS.batch_size)
  return dataset
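The interleave step in input_fn mixes records from several files. Its round-robin semantics can be sketched in plain Python (a simplified model: TensorFlow's version also applies a map function to each source and can fetch from sources in parallel):

```python
from itertools import islice

def interleave(sources, cycle_length, block_length=1):
    """Round-robin model of Dataset.interleave: cycle over up to
    `cycle_length` open sources, taking `block_length` items from each,
    and open the next source when one is exhausted."""
    open_iters = [iter(s) for s in sources[:cycle_length]]
    pending = list(sources[cycle_length:])
    out = []
    while open_iters:
        for it in list(open_iters):      # iterate over a snapshot
            block = list(islice(it, block_length))
            if block:
                out.extend(block)
            else:                        # this source is exhausted
                open_iters.remove(it)
                if pending:
                    open_iters.append(iter(pending.pop(0)))
    return out

mixed = interleave([[1, 2], [10, 20], [100]], cycle_length=2)
# -> [1, 10, 2, 20, 100]
```

Interleaving file reads this way avoids feeding the model long runs of records from a single file, which matters when each file is internally ordered.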


3. Iterator principle

3.1 Iterator Init initialization

There are four types of iterators, in increasing order of complexity; mastering the first two is usually enough. One further point in favor of the simplest kind: the one-shot iterator is currently the only type that is easy to use with an Estimator.

  • One-shot Iterator: the simplest type of iterator. It supports iterating over the entire dataset exactly once, needs no explicit initialization, and does not support parameterization.
dataset = tf.data.Dataset.range(100)
iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()

with tf.Session() as sess:
    for i in range(100):
        value = sess.run(next_element)
        assert i == value

  • Initializable Iterator: requires explicit initialization by running the iterator.initializer op before use, which makes it possible to parameterize the dataset with tf.placeholder.
max_value = tf.placeholder(tf.int64, shape=[])
dataset = tf.data.Dataset.range(max_value)
iterator = dataset.make_initializable_iterator()
next_element = iterator.get_next()

with tf.Session() as sess:
    sess.run(iterator.initializer, feed_dict={max_value: 10})
    for i in range(10):
        value = sess.run(next_element)
        assert i == value

  • Reinitializable Iterator: can be initialized from different Dataset objects, for example a training set that is shuffled and a validation set that is not. Two dataset objects with the same structure are typically used in this case.
  • Feedable Iterator: combined with tf.placeholder through the feed_dict mechanism, so that each call to sess.run selects which iterator to use.
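As an analogy for the feedable case, one can picture a registry of live iterators selected by a string handle. The plain-Python sketch below (all names are made up) mirrors how feed_dict={handle: ...} picks an iterator per call without resetting its position:

```python
def make_feedable(datasets):
    """Analogy for a feedable Iterator: a registry of live iterators,
    selected per call by a string handle (like feed_dict={handle: h})."""
    registry = {name: iter(data) for name, data in datasets.items()}

    def get_next(handle):
        # Advance only the selected iterator; the others keep their state.
        return next(registry[handle])

    return get_next

get_next = make_feedable({"train": [1, 2, 3], "valid": [10, 20]})
a = get_next("train")   # 1
b = get_next("valid")   # 10
c = get_next("train")   # 2 -- each iterator keeps its own position
```

This is exactly the property that makes feedable iterators useful for alternating between training and evaluation without restarting either pipeline.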

3.2 Traversing data with Iterator.get_next

The Iterator.get_next() method returns a tf.Tensor object which, each time it is run, evaluates to the next element of the underlying dataset.

If the iterator reaches the end of the dataset, running iterator.get_next() raises tf.errors.OutOfRangeError. After that the iterator is unusable; to keep using it, you must reinitialize it.
sess.run(iterator.initializer)
while True:
  try:
    sess.run(next_element)
  except tf.errors.OutOfRangeError:
    break
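This behavior maps directly onto Python's own iterator protocol, where StopIteration plays the role of OutOfRangeError. A plain-Python analogy:

```python
def one_shot(data):
    """Generator analogy of a one-shot Iterator."""
    for element in data:
        yield element

iterator = one_shot([1, 2, 3])
values = []
while True:
    try:
        values.append(next(iterator))  # like sess.run(next_element)
    except StopIteration:              # plays the role of OutOfRangeError
        break
# values == [1, 2, 3]; further next() calls keep raising StopIteration,
# so the iterator must be recreated -- just as in TensorFlow.
```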

3.3 Saving an Iterator

The tf.contrib.data.make_saveable_from_iterator function creates a SaveableObject from an iterator, which can be used to save and restore the iterator's current state (in fact, the state of the whole input pipeline).

# Create a saveable object from the iterator.
saveable = tf.contrib.data.make_saveable_from_iterator(iterator)
# Save the iterator state by adding it to the saveable objects collection.
tf.add_to_collection(tf.GraphKeys.SAVEABLE_OBJECTS, saveable)
saver = tf.train.Saver()
with tf.Session() as sess:
  if should_checkpoint:
    saver.save(sess, path_to_checkpoint)

# Restore the iterator state.
with tf.Session() as sess:
  saver.restore(sess, path_to_checkpoint)

4. Summary

This post introduced the basics of creating the different kinds of Dataset and Iterator objects. Once you are familiar with this data-processing step, the resulting code is highly reusable and the pipeline's efficiency can improve substantially.

Tags: Python TensorFlow Deep Learning

Posted on Fri, 12 Nov 2021 03:17:47 -0500 by user___