Reading and writing training data under PyTorch: from many small files to one large file (with a comparison of several storage formats)


  • TensorFlow has a dedicated data format, TFRecord, which can efficiently read the data used to train neural network models and keep the GPU fully fed
  • Caffe uses LMDB to store its data, which can also be read efficiently
  • PyTorch reads data through DataLoader, but it is slow, especially when the dataset consists of many small files
  • How to read data efficiently under PyTorch and make full use of the GPU has therefore become a key problem


  • Can TFRecord from TensorFlow be borrowed? Yes, it is feasible
  • Some people have already implemented this. For details, see: tfrecord
  • On Kaggle, an expert has also implemented it by hand. For details, see: PyTorch TFRecord-Loader
tfrecord write code:
import cv2
import numpy as np
import tensorflow as tf
from tqdm import tqdm
from data_loader import TFRecordDataLoader
def read_txt(txt_path):
    """Return the lines of a UTF-8 text file with the trailing newline removed."""
    with open(txt_path, 'r', encoding='utf-8') as handle:
        return [line.rstrip('\n') for line in handle]
def bytes_to_numpy(image_bytes):
    """Decode an encoded image held in a bytes object with OpenCV.

    The bytes are viewed as a uint8 buffer and handed to ``cv2.imdecode``
    using the ``cv2.IMREAD_COLOR`` flag; its result is returned unchanged.
    """
    buffer = np.frombuffer(image_bytes, dtype=np.uint8)
    return cv2.imdecode(buffer, cv2.IMREAD_COLOR)
def list_record_features(tfrecords_path):
    """Inspect the feature structure of a TFRecord file.

    Only the first record of the file is parsed, so the result describes that
    record; the remaining records are assumed to share the same layout.

    Args:
        tfrecords_path (str): Path of the tfrecords file.

    Returns:
        dict: Maps each feature name to a ``(kind, size)`` tuple, where
        ``kind`` is the name of the populated ``tf.train.Feature`` field as
        reported by ``WhichOneof('kind')`` (``None`` if no field is set) and
        ``size`` is the number of values stored in that field.

    Raises:
        StopIteration: If the file contains no records.
    """
    dataset = tf.data.TFRecordDataset([str(tfrecords_path)])
    example_bytes = next(iter(dataset)).numpy()

    # The raw record has to be parsed into the protobuf message first; a
    # freshly constructed Example carries no features at all.
    example = tf.train.Example()
    example.ParseFromString(example_bytes)

    features = {}
    for key, value in example.features.feature.items():
        kind = value.WhichOneof('kind')
        size = len(getattr(value, kind).value) if kind is not None else 0
        features[key] = (kind, size)
    return features
class TFRecorder(object):
    """Write images with integer labels to a TFRecord file and read them back.

    Every record stores five features: ``height``, ``width``, ``depth`` and
    ``label`` as int64 values and ``image_raw`` as the encoded image bytes.
    """

    def __init__(self) -> None:
        # Placeholder schema; image_to_feature() and _parse_image_function()
        # replace it with the features / parsing spec they build.
        self.feature_dict = {
            'height': None,
            'width': None,
            'depth': None,
            'label': None,
            'image_raw': None
        }
        self.AUTO = tf.data.experimental.AUTOTUNE

    def image_to_feature(self, image_string, label):
        """Build a ``tf.train.Example`` from encoded image bytes and a label.

        Args:
            image_string (bytes): Content of an encoded image file.
            label (int): Label stored next to the image.

        Returns:
            tf.train.Example: Message holding the image, its shape and label.
        """
        # The image is decoded only to record its shape; the bytes that get
        # stored are the original encoded ones.
        height, width, channel = tf.image.decode_image(image_string).shape
        self.feature_dict = {
            'height': self._int64_feature(height),
            'width': self._int64_feature(width),
            'depth': self._int64_feature(channel),
            'label': self._int64_feature(label),
            'image_raw': self._bytes_feature(image_string)
        }
        return tf.train.Example(features=tf.train.Features(feature=self.feature_dict))

    def write(self, save_path, img_label_dict):
        """Serialize every image of ``img_label_dict`` into ``save_path``.

        Args:
            save_path (str): Path of the tfrecords file to create.
            img_label_dict (dict): Maps image file paths to integer labels.
        """
        with tf.io.TFRecordWriter(save_path) as writer:
            for file_name, label in tqdm(img_label_dict.items()):
                # Context manager so the image file is closed right away.
                with open(file_name, 'rb') as img_file:
                    img_string = img_file.read()
                feature = self.image_to_feature(img_string, label)
                writer.write(feature.SerializeToString())

    def read(self, tfrecord_path):
        """Return a dataset of parsed examples read from ``tfrecord_path``.

        Each element is the dict produced by ``_parse_image_function``.
        """
        reader = tf.data.TFRecordDataset(tfrecord_path)
        dataset = reader.map(self._parse_image_function,
                             num_parallel_calls=self.AUTO)
        return dataset

    def _parse_image_function(self, example_proto):
        """Parse one serialized example into a dict of tensors."""
        self.feature_dict = {
            'height': tf.io.FixedLenFeature([], tf.int64),
            'width': tf.io.FixedLenFeature([], tf.int64),
            'depth': tf.io.FixedLenFeature([], tf.int64),
            'label': tf.io.FixedLenFeature([], tf.int64),
            'image_raw': tf.io.FixedLenFeature([], tf.string)
        }
        example = tf.io.parse_single_example(example_proto, self.feature_dict)
        return example

    # The three helpers below take no ``self``; they must be static so that
    # calls such as ``self._int64_feature(height)`` do not pass the instance
    # as ``value``.
    @staticmethod
    def _bytes_feature(value):
        """Returns a bytes_list from a string / byte."""
        if isinstance(value, type(tf.constant(0))):
            # BytesList won't unpack a string from an EagerTensor.
            value = value.numpy()
        return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

    @staticmethod
    def _float_feature(value):
        """Returns a float_list from a float / double."""
        return tf.train.Feature(float_list=tf.train.FloatList(value=[value]))

    @staticmethod
    def _int64_feature(value):
        """Returns an int64_list from a bool / enum / int / uint."""
        return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))
if __name__ == '__main__':
    # Path is not imported at the top of this snippet, so import it here.
    from pathlib import Path

    tfrecorder = TFRecorder()

    # val.txt stores the relative path of the image
    img_path = read_txt('dataset/val.txt')

    # The name of the directory containing an image is used as its label
    img_label_dict = {v: int(Path(v).parent.name) for v in img_path}
    save_path = 'temp/val.tfrecords'
    tfrecorder.write(save_path, img_label_dict)

    # Read back the file that was just written.
    # NOTE(review): the original snippet read 'dataset/val.tfrecords' here
    # although it wrote to 'temp/val.tfrecords' - confirm the intended path.
    dataset = tfrecorder.read(save_path)
    for v in dataset:
        # Each element is a dict of parsed features, not an (img, label) pair.
        img, label = v['image_raw'], v['label']

    # View unknown tfrecords structure information
    print(list_record_features(save_path))
Read code based on tfrecord under PyTorch:
import cv2
import numpy as np
import tensorflow as tf
import tensorflow_datasets as tfds


def bytes_to_numpy(image_bytes):
    """Turn encoded image bytes into the array returned by ``cv2.imdecode``.

    The input is wrapped as a uint8 buffer and decoded with the
    ``cv2.IMREAD_COLOR`` flag.
    """
    raw = np.frombuffer(image_bytes, dtype=np.uint8)
    decoded = cv2.imdecode(raw, cv2.IMREAD_COLOR)
    return decoded

def read_labeled_tfrecord(example_proto, target_height=28, target_width=28):
    """Parse one serialized example into an ``(image, label)`` pair.

    Args:
        example_proto: Scalar string tensor holding a serialized
            ``tf.train.Example`` with the features ``height``, ``width``,
            ``depth``, ``label`` (int64) and ``image_raw`` (bytes).
        target_height (int): Height the decoded image is cropped/padded to.
        target_width (int): Width the decoded image is cropped/padded to.

    Returns:
        tuple: The decoded 3-channel image tensor and the label tensor.

    NOTE(review): the target size was lost from the original snippet. The
    28x28 default is a guess based on the MNIST-style paths of the usage
    example - confirm before relying on it.
    """
    feature_dict = {
        'height': tf.io.FixedLenFeature([], tf.int64),
        'width': tf.io.FixedLenFeature([], tf.int64),
        'depth': tf.io.FixedLenFeature([], tf.int64),
        'label': tf.io.FixedLenFeature([], tf.int64),
        'image_raw': tf.io.FixedLenFeature([], tf.string)
    }
    example = tf.io.parse_single_example(example_proto, feature_dict)
    # expand_animations=False keeps the result a rank-3 tensor with a known
    # rank, which resize_with_crop_or_pad needs inside Dataset.map.
    img = tf.io.decode_image(example['image_raw'], channels=3,
                             expand_animations=False)
    img = tf.image.resize_with_crop_or_pad(img, target_height, target_width)
    return img, example['label']

def get_dataset(files, batch_size=16, repeat=False,
                cache=False, shuffle=False):
    """Build a batched input pipeline over TFRecord files.

    Args:
        files: One path or a list of paths of tfrecords files.
        batch_size (int): Number of examples per batch.
        repeat (bool): Repeat the dataset indefinitely.
        cache (bool): Cache the raw records after the first pass.
        shuffle (bool): Shuffle with a buffer of 2048 records and allow
            non-deterministic element order.

    Returns:
        An iterable of batches converted to numpy by ``tfds.as_numpy``; each
        batch is whatever ``read_labeled_tfrecord`` returns, batched.
    """
    # No module-level AUTO constant is visible in this snippet, so the
    # autotune value is resolved locally.
    autotune = tf.data.experimental.AUTOTUNE

    ds = tf.data.TFRecordDataset(files, num_parallel_reads=autotune)
    if cache:
        ds = ds.cache()

    if repeat:
        ds = ds.repeat()

    if shuffle:
        ds = ds.shuffle(1024 * 2)
        opt = tf.data.Options()
        # Order does not matter once shuffled; relaxing determinism lets the
        # parallel readers deliver records as soon as they are ready.
        opt.experimental_deterministic = False
        ds = ds.with_options(opt)

    ds = ds.map(read_labeled_tfrecord, num_parallel_calls=autotune)
    ds = ds.batch(batch_size)
    ds = ds.prefetch(autotune)
    return tfds.as_numpy(ds)

def count_data_items(file):
    """Count the records stored in one or more TFRecord files.

    Args:
        file: One path or a list of paths of tfrecords files.

    Returns:
        int: Number of records found.
    """
    # Counting only needs the raw records, so they are not parsed or decoded.
    records = tf.data.TFRecordDataset(
        file, num_parallel_reads=tf.data.experimental.AUTOTUNE)
    return sum(1 for _ in records)

class TFRecordDataLoader:
    """Iterable wrapper that serves TFRecord batches as numpy arrays.

    Args:
        files: One path or a list of paths of tfrecords files.
        batch_size (int): Number of examples per batch.
        cache (bool): Forwarded to ``get_dataset``.
        train (bool): When true the records are counted up front so that
            ``len()`` works; otherwise the length is unknown.
        repeat (bool): Forwarded to ``get_dataset``.
        shuffle (bool): Forwarded to ``get_dataset``.
        labeled (bool): Stored on the instance only.
        return_image_ids (bool): Stored on the instance only.
            NOTE(review): the default was lost from the original snippet;
            ``True`` is an assumption - confirm.
    """

    def __init__(self, files, batch_size=32, cache=False, train=True,
                 repeat=False, shuffle=False, labeled=True,
                 return_image_ids=True):
        self.ds = get_dataset(
            files,
            batch_size=batch_size,
            cache=cache,
            repeat=repeat,
            shuffle=shuffle)

        # Always define the attribute so __len__ can report "unknown" cleanly
        # instead of failing with AttributeError.
        self.num_examples = count_data_items(files) if train else None

        self.batch_size = batch_size
        self.labeled = labeled
        self.return_image_ids = return_image_ids
        self._iterator = None

    def __iter__(self):
        # Start a fresh pass on every iteration request; otherwise a second
        # ``for`` loop would reuse the exhausted iterator and yield nothing.
        if self._iterator is None:
            self._iterator = iter(self.ds)
        else:
            self._reset()
        return self._iterator

    def _reset(self):
        """Restart iteration from the beginning of the dataset."""
        self._iterator = iter(self.ds)

    def __next__(self):
        if self._iterator is None:
            self._iterator = iter(self.ds)
        return next(self._iterator)

    def __len__(self):
        """Return the number of batches per pass (the last may be partial).

        Raises:
            TypeError: If the loader was created with ``train=False`` and the
                number of examples is therefore unknown. TypeError is used
                because ``list(loader)`` and similar consumers treat it as
                "no length available" and carry on.
        """
        if self.num_examples is None:
            raise TypeError('length is unknown for a loader created with '
                            'train=False')
        # Ceiling division
        return -(-self.num_examples // self.batch_size)

# use
train_txt_path = 'dataset/minist/train.tfrecords'
# NOTE(review): the keyword arguments were lost from the original snippet;
# batch_size/shuffle below are illustrative values - adjust as needed.
train_dataloader = TFRecordDataLoader(train_txt_path,
                                      batch_size=32,
                                      shuffle=True)
for v in train_dataloader:
    # v is one batch produced by the loader; the training step goes here.
    pass


  • Across the major forums, whenever improving the read speed of small files under PyTorch comes up, LMDB (Lightning Memory-Mapped Database) is mentioned. I have made some attempts with it as well; the final conclusion is given at the end
Write to LMDB
 import os
 import pickle
 from pathlib import Path

 import cv2
 import lmdb
 import numpy as np
 from PIL import Image
 from torch.utils.data import DataLoader, Dataset
 from torchvision import transforms
 from tqdm import tqdm

 import utils
 class SimpleDataset(Dataset):
     """Dataset over the image paths listed in a text file.

     The label of an image is the name of the directory that contains it,
     converted to ``int``.

     Args:
         txt_path: Text file with one image path per line.
         transform: Optional callable applied to the PIL image.
     """

     def __init__(self, txt_path, transform=None) -> None:
         self.img_paths = utils.read_txt(txt_path)
         self.transform = transform

     def __getitem__(self, index: int):
         """Return ``(image, label)`` with the image as a numpy array."""
         img_path = self.img_paths[index]
         label = int(Path(img_path).parent.name)
         try:
             img = Image.open(img_path)
             img = img.convert('RGB')
         except (OSError, ValueError):
             # Fall back to OpenCV for files that Pillow cannot read.
             img = cv2.imread(img_path)
             if img is None:
                 raise OSError('cannot read image: %s' % img_path)
             img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
             img = Image.fromarray(img)
         if self.transform:
             img = self.transform(img)
         img = np.array(img)
         return img, label

     def __len__(self) -> int:
         return len(self.img_paths)
 class LMDB_Image:
     """Container for one image stored as raw bytes together with its label."""

     def __init__(self, image, label):
         # The spatial size and channel count are recorded so the flat byte
         # buffer can be reshaped again later; this matters mostly for
         # datasets whose images vary in size.
         shape = image.shape
         self.channels = shape[2]
         self.size = shape[:2]
         self.image = image.tobytes()
         self.label = label

     def get_image(self):
         """Rebuild the image as a uint8 numpy array from the stored bytes."""
         flat = np.frombuffer(self.image, dtype=np.uint8)
         target_shape = (*self.size, self.channels)
         return flat.reshape(target_shape)
 def _identity_collate(batch):
     """Return the batch unchanged (picklable replacement for a lambda)."""
     return batch


 def data2lmdb(dpath, name="train", txt_path=None,
               write_frequency=10, num_workers=4):
     """Convert the images listed in ``txt_path`` into an LMDB database.

     Every sample is stored as a pickled ``LMDB_Image`` under its running
     index (ASCII encoded). Two extra entries are written at the end:
     ``__keys__`` (pickled list of all keys) and ``__len__`` (pickled count).

     Args:
         dpath (str): Directory in which ``<name>.lmdb`` is created.
         name (str): Base name of the database.
         txt_path (str): Text file listing the image paths.
         write_frequency (int): Commit the write transaction every this many
             samples.
         num_workers (int): Worker processes used to load the images.
     """
     dataset = SimpleDataset(txt_path=txt_path)
     # A module-level function is used instead of a lambda because worker
     # processes started with "spawn" must be able to pickle collate_fn.
     data_loader = DataLoader(dataset, num_workers=num_workers,
                              collate_fn=_identity_collate)
     lmdb_path = os.path.join(dpath, "%s.lmdb" % name)
     isdir = os.path.isdir(lmdb_path)
     print("Generate LMDB to %s" % lmdb_path)
     db = lmdb.open(lmdb_path, subdir=isdir,
                    map_size=1099511627776,  # Unit byte
                    readonly=False, meminit=False, map_async=True)
     total = len(data_loader)
     count = 0
     txn = db.begin(write=True)
     for idx, data in enumerate(tqdm(data_loader)):
         image, label = data[0]
         temp = LMDB_Image(image, label)
         txn.put(u'{}'.format(idx).encode('ascii'), pickle.dumps(temp))
         count = idx + 1
         if idx % write_frequency == 0:
             print("[%d/%d]" % (idx, total))
             # Without the commit the pending puts never reach the database.
             txn.commit()
             txn = db.begin(write=True)
     # finish iterating through dataset
     txn.commit()
     # ``count`` instead of ``idx + 1`` so an empty dataset does not raise.
     keys = [u'{}'.format(k).encode('ascii') for k in range(count)]
     with db.begin(write=True) as txn:
         txn.put(b'__keys__', pickle.dumps(keys))
         txn.put(b'__len__', pickle.dumps(len(keys)))
     print("Flushing database ...")
     db.sync()
     db.close()
 if __name__ == '__main__':
     # Build the LMDB database for the validation split.
     save_dir = 'dataset/minist'
     data2lmdb(
         save_dir,
         name='val',
         txt_path='dataset/minist/val.txt',
     )
class DatasetLMDB(Dataset):
    """Dataset reading pickled ``LMDB_Image`` records from an LMDB database.

    The database must contain the ``__len__`` and ``__keys__`` entries next
    to the samples.

    Args:
        db_path (str): Path of the LMDB database (file or directory).
        transform: Optional callable applied to the RGB PIL image.
    """

    def __init__(self, db_path, transform=None):
        self.db_path = db_path
        self.transform = transform
        # The environment used for reading samples is opened lazily in
        # __getitem__. py-lmdb environments must not be shared across fork(),
        # so each DataLoader worker has to open its own.
        self.env = None

        env = self._open_env()
        try:
            with env.begin() as txn:
                # NOTE: pickle.loads must only be used on trusted databases.
                self.length = pickle.loads(txn.get(b'__len__'))
                self.keys = pickle.loads(txn.get(b'__keys__'))
        finally:
            env.close()

    def _open_env(self):
        """Open the database read-only."""
        return lmdb.open(self.db_path, subdir=os.path.isdir(self.db_path),
                         readonly=True, lock=False,
                         readahead=False, meminit=False)

    def __getstate__(self):
        # An open environment cannot be pickled (spawned workers); drop it
        # and let the worker reopen on first access.
        state = self.__dict__.copy()
        state['env'] = None
        return state

    def __getitem__(self, index):
        """Return ``(image, label)`` for the sample stored at ``index``."""
        if self.env is None:
            self.env = self._open_env()
        with self.env.begin() as txn:
            byteflow = txn.get(self.keys[index])
        # NOTE: pickle.loads must only be used on trusted databases.
        IMAGE = pickle.loads(byteflow)
        img, label = IMAGE.get_image(), IMAGE.label
        img = Image.fromarray(img).convert('RGB')
        # The transform passed to the constructor was previously ignored.
        if self.transform:
            img = self.transform(img)
        return img, label

    def __len__(self):
        return self.length
# use
# NOTE(review): parts of this snippet were lost; the RandomChoice wrapper,
# ToTensor, the database path and the DataLoader arguments are
# reconstructions - confirm against your own setup.
train_transforms = transforms.Compose([
    transforms.Resize((388, 270)),
    transforms.RandomChoice([
        transforms.RandomPerspective(distortion_scale=0.6, p=0.5),
        transforms.ColorJitter(brightness=.5, hue=.3),
    ]),
    # Tensors are required for the default DataLoader collation.
    transforms.ToTensor(),
])
# Must point to the database produced by data2lmdb, i.e. <dpath>/<name>.lmdb
train_lmdb_path = 'dataset/minist/train.lmdb'
train_dataset = DatasetLMDB(train_lmdb_path, train_transforms)
train_dataloader = DataLoader(train_dataset,
                              batch_size=32,
                              shuffle=True,
                              num_workers=4)
# do other things

Binary large file

  • Another option is to read the existing dataset files as raw bytes and store them all in one large binary (bin) file
Write bins
import cv2
import numpy as np
from tqdm import tqdm
def write_bin(save_bin_path, save_index_path, data):
    """Pack an existing file-based dataset into one large bin file.

    The raw bytes of every image are appended to ``save_bin_path``. For each
    image one line is written to ``save_index_path`` holding the start
    offset, the byte length and the label, separated by ``\\t``.

    Args:
        save_bin_path (str): Where to save the bin file.
        save_index_path (str): Where to save the index/label text file.
        data (list): Image paths with their labels,
                    e.g. [['xxx/1.jpg', 'cat'], ['xxx/2.jpg', 'dog']]
    """
    with open(save_bin_path, 'wb') as f_w, \
            open(save_index_path, 'w') as f_index:
        start_index = 0
        for img_path, label in tqdm(data):
            with open(img_path, 'rb') as f:
                img_bin = f.read()
            len_bin = len(img_bin)
            f_w.write(img_bin)
            f_index.write(f'{start_index}\t{len_bin}\t{label}\n')
            # The next image starts right after the bytes just written.
            start_index += len_bin
def read_bin(bin_path, index_path):
    """Read every image back from a large bin file using its index file.

    Each non-empty index line must start with the tab-separated start offset
    and byte length of an image; further fields (such as the label) are
    ignored here.

    Args:
        bin_path (str): Path of the bin file.
        index_path (str): Path of the index/label text file.

    Returns:
        list: The decoded images in index order, as returned by
        ``cv2.imdecode`` (``None`` for entries that could not be decoded).
        All images are kept in memory, so use this on small files only.
    """
    images = []
    with open(bin_path, 'rb') as f_bin, open(index_path, 'r') as f_index:
        for line in f_index:
            fields = line.strip().split('\t')
            if len(fields) < 2:
                # Skip blank or malformed lines.
                continue
            start_index = int(fields[0])
            length = int(fields[1])
            # Navigate to the current pointer position to start_index
            f_bin.seek(start_index)
            # Read byte value of length
            img_bytes = f_bin.read(length)
            img = np.frombuffer(img_bytes, dtype='uint8')
            img = cv2.imdecode(img, -1)  # -1: cv.IMREAD_UNCHANGED
            images.append(img)
    return images
Reading bins
from io import BytesIO
from PIL import Image
import cv2
import numpy as np
class SimpleDataset(Dataset):
    """Dataset reading images from one large bin file via an index file.

    Every index line holds three tab-separated integers: start offset, byte
    length and label.

    Args:
        txt_path: Path of the index text file.
        bin_path: Path of the bin file.
        transform: Optional callable applied to the RGB PIL image.
    """

    # Number of samples tried before a run of undecodable images is reported.
    _MAX_ATTEMPTS = 10

    def __init__(self, txt_path, bin_path, transform=None) -> None:
        self.index_info = [line.split('\t')
                           for line in utils.read_txt(txt_path)]
        self.bin_path = bin_path
        # Opened lazily, once per process. A handle opened here would be
        # shared by all forked DataLoader workers, whose seek()/read() calls
        # then race on one file offset and return bytes of the wrong image.
        self.f_bin = None
        self._owner_pid = None
        self.transform = transform

    def __getstate__(self):
        # File objects cannot be pickled (spawned workers); reopen on demand.
        state = self.__dict__.copy()
        state['f_bin'] = None
        state['_owner_pid'] = None
        return state

    def _bin_file(self):
        """Return a file handle that belongs to the current process."""
        import os

        pid = os.getpid()
        if self.f_bin is None or self._owner_pid != pid:
            self.f_bin = open(self.bin_path, 'rb')
            self._owner_pid = pid
        return self.f_bin

    def _load(self, index):
        """Return ``(image, label)`` or ``None`` if decoding fails."""
        start_index, length, label = map(int, self.index_info[index])
        f_bin = self._bin_file()
        f_bin.seek(start_index)
        img_bytes = f_bin.read(length)
        # Only the Pillow path ("Scheme II") is kept: in the original its
        # result replaced the OpenCV one ("Scheme I") anyway.
        try:
            img = Image.open(BytesIO(img_bytes))
            img = img.convert('RGB')
        except (OSError, ValueError):
            return None
        return img, label

    def __getitem__(self, index: int):
        """Return ``(image, label)``; substitutes a random sample on failure.

        Raises:
            RuntimeError: If ``_MAX_ATTEMPTS`` samples in a row fail to decode.
        """
        import random

        for _ in range(self._MAX_ATTEMPTS):
            sample = self._load(index)
            if sample is not None:
                break
            index = random.randint(0, self.__len__() - 1)
        else:
            raise RuntimeError('could not decode any image after %d attempts'
                               % self._MAX_ATTEMPTS)
        img, label = sample
        if self.transform:
            img = self.transform(img)
        return img, label

    def __len__(self) -> int:
        return len(self.index_info)


  • Using sqlite3, which is built into Python, as the storage format is also a good choice
Write to sqlite database
import sqlite3
from pathlib import Path
from tqdm import tqdm
def read_txt(txt_path):
    """Return the lines of a text file without their trailing newline.

    The file is decoded as ``utf-8-sig``.
    """
    lines = []
    with open(txt_path, 'r', encoding='utf-8-sig') as src:
        for raw in src:
            lines.append(raw.rstrip('\n'))
    return lines
def img_to_bytes(img_path):
    """Return the complete content of the file at ``img_path`` as bytes."""
    with open(img_path, 'rb') as f:
        return f.read()
class SQLiteWriter(object):
    """Small context manager around a SQLite connection for bulk writes.

    Statements run through ``execute`` are committed together when the
    ``with`` block ends without an exception, and rolled back otherwise.
    When the object is used without ``with``, call ``commit()`` and
    ``close()`` explicitly.

    Args:
        db_path (str): Path of the database file.
    """

    def __init__(self, db_path):
        self.conn = sqlite3.connect(db_path)
        self.cursor = self.conn.cursor()

    def execute(self, sql, value=None):
        """Run one statement, binding ``value`` to its placeholders if given."""
        if value:
            self.cursor.execute(sql, value)
        else:
            # Statements without parameters (e.g. ``create table``) must run
            # as well.
            self.cursor.execute(sql)

    def commit(self):
        """Commit the pending transaction."""
        self.conn.commit()

    def close(self):
        """Close the cursor and the connection."""
        self.cursor.close()
        self.conn.close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        try:
            if exc_type is None:
                # One commit for the whole block keeps bulk inserts fast.
                self.conn.commit()
            else:
                self.conn.rollback()
        finally:
            self.close()
if __name__ == '__main__':
    dataset_dir = Path('datasets/minist')
    save_db_dir = dataset_dir / 'sqlite'
    # sqlite3.connect() does not create missing parent directories.
    save_db_dir.mkdir(parents=True, exist_ok=True)
    save_db_path = str(save_db_dir / 'val.db')

    # Each line in val.txt: image path \t corresponding text value,
    # e.g. xxxx.jpg\txxxxxx
    img_paths = read_txt(str(dataset_dir / 'val.txt'))

    with SQLiteWriter(save_db_path) as db_writer:
        # Create table
        table_name = 'minist'

        # Note that the fields of the table should be defined according to
        # your own dataset. For the available column types refer to the
        # SQLite documentation on datatypes.
        # The dataset in this demo is a text recognition dataset: the sample
        # is the image and the label is the corresponding text.
        # e.g. img_path: str(xxxx.jpg), img_data: image data in bytes format,
        #      img_label: str(xxxxx)
        create_table_sql = f'create table {table_name} (img_path TEXT primary key, img_data BLOB, img_label TEXT)'
        # The statement was built but never run, so every insert failed with
        # "no such table".
        db_writer.execute(create_table_sql)

        # Insert data into the table with placeholders for the value part
        insert_sql = f'insert into {table_name} (img_path, img_data, img_label) values(?, ?, ?)'
        for img_info in tqdm(img_paths):
            if not img_info:
                # Skip blank lines.
                continue
            # Split on the first tab only so labels may contain tabs.
            img_path, label = img_info.split('\t', 1)
            img_full_path = str(dataset_dir / 'images' / img_path)
            img_data = img_to_bytes(img_full_path)
            db_writer.execute(insert_sql, (img_path, img_data, label))
Read database
class SimpleDataset(Dataset):
    """Dataset reading ``(img_path, img_data, img_label)`` rows from SQLite.

    Rows are looked up by ``rowid``; sample ``i`` is the row with
    ``rowid == i + 1``, and the dataset length is ``max(rowid)``.

    Args:
        db_path: Path of the SQLite database file.
        transform: Optional callable applied to the resized RGB PIL image.
        table_name: Name of the table to read from.
    """

    def __init__(self, db_path, transform=None,
                 table_name='Synthetic_chinese_dataset') -> None:
        self.db_path = db_path
        self.conn = None
        self.cursor = None
        # Table name in database
        self.table_name = table_name

        self.establish_conn()
        self.cursor.execute(f'select max(rowid) from {self.table_name}')
        # max() yields NULL for an empty table; report length 0 in that case.
        self.nums = self.cursor.fetchall()[0][0] or 0
        # Close again so DataLoader worker processes do not inherit this
        # connection; each one opens its own on first access.
        self.close_conn()

        self.transform = transform

    def __getitem__(self, index: int):
        """Return ``(image, label)`` for the row with ``rowid == index + 1``."""
        self.establish_conn()

        # query
        search_sql = f'select * from {self.table_name} where rowid=?'
        self.cursor.execute(search_sql, (index+1, ))
        img_path, img_bytes, label = self.cursor.fetchone()

        # Restore images and labels
        img = Image.open(BytesIO(img_bytes))
        img = img.convert('RGB')
        img = scale_resize_pillow(img, (320, 32))
        if self.transform:
            img = self.transform(img)
        return img, label

    def __len__(self) -> int:
        return self.nums

    def establish_conn(self):
        """Open the connection and cursor if they are not open yet."""
        if self.conn is None:
            self.conn = sqlite3.connect(self.db_path,
                                        check_same_thread=False,
                                        cached_statements=1024)
            self.cursor = self.conn.cursor()
        return self

    def close_conn(self):
        """Close the cursor and connection if they are open."""
        if self.conn is not None:
            self.cursor.close()
            self.conn.close()
            self.cursor = None
            self.conn = None
        return self
# --------------------------------------------------
train_dataset = SimpleDataset(train_db_path, train_transforms)
# NOTE(review): the DataLoader arguments were lost from the original
# snippet; the values below are illustrative.
train_dataloader = DataLoader(train_dataset,
                              batch_size=32,
                              shuffle=True,
                              num_workers=4)
# ... training loop ...
# ✧✧ After use, the database connection has to be closed manually
train_dataset.close_conn()

Final conclusion

  • TFRecord

    • The storage size of the data stays the same after conversion, and the GPU can be fully utilized
    • With tfrecord, other data augmentation libraries (imgaug, OpenCV) cannot be plugged in, so the available augmentation methods are very limited
  • LMDB

    • The storage size of the data becomes very large after conversion (originally 4.2 GB → 96 GB after conversion)
    • When PyTorch reads the data with multiple processes, the images cannot be restored to the original images. No solution has been found for the time being
    • The reading efficiency is high enough to make full use of the GPU
  • Binary large file

    • The storage size of the data stays the same after conversion
    • Likewise, when PyTorch reads the data with multiple processes, the images cannot be restored correctly, and no solution has been found for the time being
  • ✧ SQLite (recommended)

    • The data storage size remains unchanged before and after conversion
    • It can be read by multiple processes normally

reference material

Tags: Pytorch Deep Learning

Posted on Sat, 30 Oct 2021 21:51:01 -0400 by burn1337