Introduction
- TensorFlow has a dedicated data format and reading pipeline, TFRecord, which can read training data efficiently enough to keep the GPU fully fed
- Caffe uses LMDB to read data, which is also efficient
- PyTorch ships with a DataLoader for reading data, but it is slow, especially when the dataset consists of many small files
- So the key question is: how can we read data efficiently in PyTorch and make full use of the GPU? A minimal sketch of the naive baseline follows this list for reference
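To make the comparison concrete, here is a minimal sketch of that naive baseline, assuming a hypothetical folder layout `dataset/train/<label>/xxx.jpg`: every `__getitem__` opens one small image file, which is exactly the access pattern that becomes slow at scale.

```python
from pathlib import Path

from PIL import Image
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms


class FolderDataset(Dataset):
    """Naive baseline: one small image file is opened per __getitem__ call."""

    def __init__(self, root, transform=None):
        # Assumed layout: root/<label>/<image>.jpg
        self.img_paths = sorted(Path(root).rglob('*.jpg'))
        self.transform = transform

    def __getitem__(self, index):
        img_path = self.img_paths[index]
        label = int(img_path.parent.name)
        img = Image.open(img_path).convert('RGB')
        if self.transform:
            img = self.transform(img)
        return img, label

    def __len__(self):
        return len(self.img_paths)


if __name__ == '__main__':
    dataset = FolderDataset('dataset/train', transforms.ToTensor())
    # Every sample triggers a separate small-file read, which is what
    # limits throughput when there are many small images.
    loader = DataLoader(dataset, batch_size=32, shuffle=True, num_workers=4)
    for imgs, labels in loader:
        pass
```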
TFRecord
- Can TensorFlow's TFRecord format be borrowed for PyTorch? It can, with some caveats
- A ready-made implementation already exists; for details, see: tfrecord (a rough usage sketch follows this list)
- There is also a hand-rolled implementation on Kaggle; for details, see: PyTorch TFRecord-Loader
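For orientation, the snippet below sketches how the `tfrecord` package linked above is typically used, based on that project's README; the exact API and the field names/types are assumptions to verify against the package's documentation (the names here match the features written in the code that follows).

```python
# Rough sketch of the `tfrecord` pip package (https://github.com/vahidk/tfrecord);
# verify the exact API against the package's own documentation.
import torch
from tfrecord.torch.dataset import TFRecordDataset

tfrecord_path = 'dataset/val.tfrecords'
index_path = None  # an index file is recommended when using multiple workers
# Field names/types assumed to match the features written in the code below
description = {'image_raw': 'byte', 'label': 'int'}

dataset = TFRecordDataset(tfrecord_path, index_path, description)
# Encoded images have variable length, so batch_size=1 (or a custom
# transform/collate that decodes them) keeps the default collate happy.
loader = torch.utils.data.DataLoader(dataset, batch_size=1)

batch = next(iter(loader))  # dict: {'image_raw': ..., 'label': ...}
```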
Code for writing TFRecords:
```python
from pathlib import Path

import cv2
import numpy as np
import tensorflow as tf
from tqdm import tqdm


def read_txt(txt_path):
    with open(txt_path, 'r', encoding='utf-8') as f:
        data = f.readlines()
    data = list(map(lambda x: x.rstrip('\n'), data))
    return data


def bytes_to_numpy(image_bytes):
    image_np = np.frombuffer(image_bytes, dtype=np.uint8)
    image_np2 = cv2.imdecode(image_np, cv2.IMREAD_COLOR)
    return image_np2


def list_record_features(tfrecords_path):
    """Inspect the structure of a tfrecords file.
    https://stackoverflow.com/questions/63562691/reading-a-tfrecord-file-where-features-that-were-used-to-encode-is-not-known

    Args:
        tfrecords_path (str): path to the tfrecords file

    Returns:
        dict: structural information
    """
    features = {}
    dataset = tf.data.TFRecordDataset([str(tfrecords_path)])
    data = next(iter(dataset))

    example = tf.train.Example()
    example_bytes = data.numpy()
    example.ParseFromString(example_bytes)

    for key, value in example.features.feature.items():
        kind = value.WhichOneof('kind')
        size = len(getattr(value, kind).value)
        if key in features:
            kind2, size2 = features[key]
            if kind != kind2:
                kind = None
            if size != size2:
                size = None
        features[key] = (kind, size)
    return features


class TFRecorder(object):
    def __init__(self) -> None:
        super().__init__()
        self.feature_dict = {
            'height': None,
            'width': None,
            'depth': None,
            'label': None,
            'image_raw': None
        }
        self.AUTO = tf.data.experimental.AUTOTUNE

    def image_to_feature(self, image_string, label):
        height, width, channel = tf.image.decode_image(image_string).shape
        self.feature_dict = {
            'height': self._int64_feature(height),
            'width': self._int64_feature(width),
            'depth': self._int64_feature(channel),
            'label': self._int64_feature(label),
            'image_raw': self._bytes_feature(image_string)
        }
        return tf.train.Example(features=tf.train.Features(feature=self.feature_dict))

    def write(self, save_path, img_label_dict):
        with tf.io.TFRecordWriter(save_path) as writer:
            for file_name, label in tqdm(img_label_dict.items()):
                img_string = open(file_name, 'rb').read()
                feature = self.image_to_feature(img_string, label)
                writer.write(feature.SerializeToString())

    def read(self, tfrecord_path):
        reader = tf.data.TFRecordDataset(tfrecord_path)
        dataset = reader.map(self._parse_image_function,
                             num_parallel_calls=self.AUTO)
        return dataset

    def _parse_image_function(self, example_proto):
        self.feature_dict = {
            'height': tf.io.FixedLenFeature([], tf.int64),
            'width': tf.io.FixedLenFeature([], tf.int64),
            'depth': tf.io.FixedLenFeature([], tf.int64),
            'label': tf.io.FixedLenFeature([], tf.int64),
            'image_raw': tf.io.FixedLenFeature([], tf.string)
        }
        example = tf.io.parse_single_example(example_proto, self.feature_dict)
        return example

    @staticmethod
    def _bytes_feature(value):
        """Returns a bytes_list from a string / byte."""
        if isinstance(value, type(tf.constant(0))):
            # BytesList won't unpack a string from an EagerTensor.
            value = value.numpy()
        return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

    @staticmethod
    def _float_feature(value):
        """Returns a float_list from a float / double."""
        return tf.train.Feature(float_list=tf.train.FloatList(value=[value]))

    @staticmethod
    def _int64_feature(value):
        """Returns an int64_list from a bool / enum / int / uint."""
        return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))


if __name__ == '__main__':
    tfrecorder = TFRecorder()

    # val.txt stores the relative paths of the images
    img_path = read_txt('dataset/val.txt')
    # Path(v).parent.name: label of the image
    img_label_dict = {v: int(Path(v).parent.name) for v in img_path}

    save_path = 'temp/val.tfrecords'
    tfrecorder.write(save_path, img_label_dict)

    dataset = tfrecorder.read(save_path)
    for v in dataset:
        img, label = v['image_raw'], v['label']
    print('ok')

    # Inspect the structure of an unknown tfrecords file
    list_record_features('xxxx.tfrecords')
```
Code for reading TFRecords into PyTorch:
```python
import cv2
import numpy as np
import tensorflow as tf
import tensorflow_datasets as tfds

AUTO = tf.data.experimental.AUTOTUNE


def bytes_to_numpy(image_bytes):
    image_np = np.frombuffer(image_bytes, dtype=np.uint8)
    image_np2 = cv2.imdecode(image_np, cv2.IMREAD_COLOR)
    return image_np2


def read_labeled_tfrecord(example_proto):
    feature_dict = {
        'height': tf.io.FixedLenFeature([], tf.int64),
        'width': tf.io.FixedLenFeature([], tf.int64),
        'depth': tf.io.FixedLenFeature([], tf.int64),
        'label': tf.io.FixedLenFeature([], tf.int64),
        'image_raw': tf.io.FixedLenFeature([], tf.string)
    }
    example = tf.io.parse_single_example(example_proto, feature_dict)
    img = tf.io.decode_image(example['image_raw'], channels=3,
                             expand_animations=False)
    img = tf.image.resize_with_crop_or_pad(img, target_height=388,
                                           target_width=270)
    return img, example['label']


def get_dataset(files, batch_size=16, repeat=False, cache=False, shuffle=False):
    ds = tf.data.TFRecordDataset(files, num_parallel_reads=AUTO)
    if cache:
        ds = ds.cache()
    if repeat:
        ds = ds.repeat()
    if shuffle:
        ds = ds.shuffle(1024 * 2)
        opt = tf.data.Options()
        opt.experimental_deterministic = False
        ds = ds.with_options(opt)
    ds = ds.map(read_labeled_tfrecord, num_parallel_calls=AUTO)
    ds = ds.batch(batch_size)
    ds = ds.prefetch(AUTO)
    return tfds.as_numpy(ds)


def count_data_items(file):
    num_ds = tf.data.TFRecordDataset(file, num_parallel_reads=AUTO)
    num_ds = num_ds.map(read_labeled_tfrecord, num_parallel_calls=AUTO)
    num_ds = num_ds.repeat(1)
    num_ds = num_ds.batch(1)
    c = 0
    for _ in num_ds:
        c += 1
    del num_ds
    return c


class TFRecordDataLoader:
    def __init__(self, files, batch_size=32, cache=False, train=True,
                 repeat=False, shuffle=False, labeled=True,
                 return_image_ids=True):
        self.ds = get_dataset(
            files,
            batch_size=batch_size,
            cache=cache,
            repeat=repeat,
            shuffle=shuffle,)
        if train:
            self.num_examples = count_data_items(files)
        self.batch_size = batch_size
        self.labeled = labeled
        self.return_image_ids = return_image_ids
        self._iterator = None

    def __iter__(self):
        if self._iterator is None:
            self._iterator = iter(self.ds)
        else:
            self._reset()
        return self._iterator

    def _reset(self):
        self._iterator = iter(self.ds)

    def __next__(self):
        batch = next(self._iterator)
        return batch

    def __len__(self):
        n_batches = self.num_examples // self.batch_size
        if self.num_examples % self.batch_size == 0:
            return n_batches
        else:
            return n_batches + 1


# Usage
batch_size = 32
train_txt_path = 'dataset/minist/train.tfrecords'
train_dataloader = TFRecordDataLoader(train_txt_path, batch_size=batch_size,
                                      shuffle=True)
for v in train_dataloader:
    pass
```
LMDB
- Across the major forums, whenever speeding up the reading of many small files in PyTorch comes up, LMDB (Lightning Memory-Mapped Database) is mentioned. I made some attempts with it as well; the final verdict is given at the end
Write to LMDB
```python
import os
import pickle
from pathlib import Path

import cv2
import lmdb
import numpy as np
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from tqdm import tqdm

import utils


class SimpleDataset(Dataset):
    def __init__(self, txt_path, transform=None) -> None:
        self.img_paths = utils.read_txt(txt_path)
        self.transform = transform

    def __getitem__(self, index: int):
        img_path = self.img_paths[index]
        label = int(Path(img_path).parent.name)
        try:
            img = Image.open(img_path)
            img = img.convert('RGB')
        except:
            img = cv2.imread(img_path)
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
            img = Image.fromarray(img)

        if self.transform:
            img = self.transform(img)
        img = np.array(img)
        return img, label

    def __len__(self) -> int:
        return len(self.img_paths)


class LMDB_Image:
    def __init__(self, image, label):
        # Dimensions of image for reconstruction - not really necessary
        # for this dataset, but some datasets may include images of
        # varying sizes
        self.channels = image.shape[2]
        self.size = image.shape[:2]
        self.image = image.tobytes()
        self.label = label

    def get_image(self):
        """Returns the image as a numpy array."""
        image = np.frombuffer(self.image, dtype=np.uint8)
        return image.reshape(*self.size, self.channels)


def data2lmdb(dpath, name="train", txt_path=None,
              write_frequency=10, num_workers=4):
    dataset = SimpleDataset(txt_path=txt_path)
    data_loader = DataLoader(dataset, num_workers=num_workers,
                             collate_fn=lambda x: x)

    lmdb_path = os.path.join(dpath, "%s.lmdb" % name)
    isdir = os.path.isdir(lmdb_path)
    print("Generate LMDB to %s" % lmdb_path)
    db = lmdb.open(lmdb_path, subdir=isdir,
                   map_size=1099511627776,  # unit: bytes (1 TB)
                   readonly=False, meminit=False, map_async=True)

    txn = db.begin(write=True)
    for idx, data in enumerate(tqdm(data_loader)):
        image, label = data[0]
        temp = LMDB_Image(image, label)
        txn.put(u'{}'.format(idx).encode('ascii'), pickle.dumps(temp))
        if idx % write_frequency == 0:
            print("[%d/%d]" % (idx, len(data_loader)))
            txn.commit()
            txn = db.begin(write=True)

    # finish iterating through dataset
    txn.commit()
    keys = [u'{}'.format(k).encode('ascii') for k in range(idx + 1)]
    with db.begin(write=True) as txn:
        txn.put(b'__keys__', pickle.dumps(keys))
        txn.put(b'__len__', pickle.dumps(len(keys)))

    print("Flushing database ...")
    db.sync()
    db.close()


if __name__ == '__main__':
    save_dir = 'dataset/minist'
    data2lmdb(save_dir, name='val', txt_path='dataset/minist/val.txt')
```
Read LMDB
```python
class DatasetLMDB(Dataset):
    def __init__(self, db_path, transform=None):
        self.db_path = db_path
        self.env = lmdb.open(db_path,
                             subdir=os.path.isdir(db_path),
                             readonly=True, lock=False,
                             readahead=False, meminit=False)
        with self.env.begin() as txn:
            self.length = pickle.loads(txn.get(b'__len__'))
            self.keys = pickle.loads(txn.get(b'__keys__'))
        self.transform = transform

    def __getitem__(self, index):
        with self.env.begin() as txn:
            byteflow = txn.get(self.keys[index])
        IMAGE = pickle.loads(byteflow)
        img, label = IMAGE.get_image(), IMAGE.label

        img = Image.fromarray(img).convert('RGB')
        if self.transform:
            img = self.transform(img)
        return img, label

    def __len__(self):
        return self.length


# Usage
train_transforms = transforms.Compose([
    transforms.Resize((388, 270)),
    transforms.RandomChoice([
        transforms.RandomRotation(10),
        transforms.RandomHorizontalFlip(0.5),
        transforms.RandomGrayscale(p=0.3),
        transforms.RandomPerspective(distortion_scale=0.6, p=0.5),
        transforms.ColorJitter(brightness=.5, hue=.3),
    ]),
    transforms.ToTensor(),
    normalize,  # e.g. transforms.Normalize(...) defined elsewhere
    transforms.RandomErasing(),
])

train_dataset = DatasetLMDB(train_db_path, train_transforms)
train_dataloader = DataLoader(train_dataset,
                              batch_size=batch_size,
                              shuffle=True,
                              num_workers=n_worker,
                              pin_memory=True)
# do other things
```
Binary large file
- Another option is to concatenate the raw bytes of the existing dataset into one large bin file, together with a small index file recording each image's offset, length, and label (an example of the index format is shown below)
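For illustration only (the numbers are hypothetical), the index file written next to the bin file contains one tab-separated line per image: start offset, byte length, and label.

```
0       152433  3
152433  98721   7
251154  120554  1
```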
Write bins
```python
import cv2
import numpy as np
from tqdm import tqdm


def write_bin(save_bin_path, save_index_path, data):
    """Write an existing file-based dataset into one large bin file.

    The start offset, byte length and label of each image are written to
    save_index_path, separated by tabs.

    Args:
        save_bin_path (str): where to save the bin file
        save_index_path (str): where to save the index/label txt file
        data (list): list of [image path, label] pairs,
            e.g. [['xxx/1.jpg', 'cat'], ['xxx/2.jpg', 'dog']]
    """
    with open(save_bin_path, 'wb') as f_w, \
            open(save_index_path, 'w') as f_index:
        start_index = 0
        for img_path, label in tqdm(data):
            with open(img_path, 'rb') as f:
                img_bin = f.read()
            f_w.write(img_bin)

            len_bin = len(img_bin)
            f_index.write(f'{start_index}\t{len_bin}\t{label}\n')
            start_index += len_bin


def read_bin(bin_path, index_path):
    """Read the bin file together with its index/label txt.

    Args:
        bin_path (str): path of the bin file
        index_path (str): path of the txt file storing index and label
    """
    with open(bin_path, 'rb') as f_bin, open(index_path, 'r') as f_index:
        index_lines = list(map(lambda x: x.strip(), f_index.readlines()))
        index_lines = list(map(lambda x: x.split('\t'), index_lines))
        for i, (start_index, length, label) in enumerate(index_lines):
            start_index = int(start_index)
            length = int(length.strip())

            # Move the file pointer to start_index
            f_bin.seek(start_index)
            # Read `length` bytes
            img_bytes = f_bin.read(length)

            img = np.frombuffer(img_bytes, dtype='uint8')
            img = cv2.imdecode(img, -1)  # -1: cv2.IMREAD_UNCHANGED

            # Convert to PIL
            # img = Image.fromarray(img)
            # img = img.convert('RGB')

            # Save image
            # cv2.imwrite(f'temp/images/{i}.jpg', img)
```
Read bins
```python
import random
from io import BytesIO

import cv2
import numpy as np
from PIL import Image
from torch.utils.data import Dataset

import utils


class SimpleDataset(Dataset):
    def __init__(self, txt_path, bin_path, transform=None) -> None:
        self.index_info = utils.read_txt(txt_path)
        self.index_info = list(map(lambda x: x.split('\t'), self.index_info))
        self.f_bin = open(bin_path, 'rb')
        self.transform = transform

    def __getitem__(self, index: int):
        start_index, length, label = list(map(int, self.index_info[index]))
        self.f_bin.seek(start_index)
        img_bytes = self.f_bin.read(length)

        # Either scheme alone is enough; both are shown here.
        # Scheme I: decode with OpenCV, then convert to PIL
        img = np.frombuffer(img_bytes, dtype='uint8')
        img = cv2.imdecode(img, -1)
        if img is None:
            return self.__getitem__(random.randint(0, self.__len__() - 1))
        img = Image.fromarray(img)
        img = img.convert('RGB')

        # Scheme II: open directly from the bytes with PIL
        try:
            img = Image.open(BytesIO(img_bytes))
            img = img.convert('RGB')
        except:
            return self.__getitem__(random.randint(0, self.__len__() - 1))

        if self.transform:
            img = self.transform(img)
        return img, label

    def __len__(self) -> int:
        return len(self.index_info)
```
SQLite
- Python's built-in sqlite3 module is also a good choice as the storage format
Write to the SQLite database
```python
import sqlite3
from pathlib import Path

from tqdm import tqdm


def read_txt(txt_path):
    with open(txt_path, 'r', encoding='utf-8-sig') as f:
        data = list(map(lambda x: x.rstrip('\n'), f))
    return data


def img_to_bytes(img_path):
    with open(img_path, 'rb') as f:
        img_bytes = f.read()
    return img_bytes


class SQLiteWriter(object):
    def __init__(self, db_path):
        self.conn = sqlite3.connect(db_path)
        self.cursor = self.conn.cursor()

    def execute(self, sql, value=None):
        if value:
            self.cursor.execute(sql, value)
        else:
            self.cursor.execute(sql)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.cursor.close()
        self.conn.commit()
        self.conn.close()


if __name__ == '__main__':
    dataset_dir = Path('datasets/minist')
    save_db_dir = dataset_dir / 'sqlite'
    save_db_path = str(save_db_dir / 'val.db')

    # Each line of val.txt: image path and its label text, separated by a tab,
    # e.g. xxxx.jpg\txxxxxx
    img_paths = read_txt(str(dataset_dir / 'val.txt'))

    with SQLiteWriter(save_db_path) as db_writer:
        # Create the table
        table_name = 'minist'
        # Define the table columns according to your own dataset.
        # For the mapping between Python types and SQLite types, see:
        # https://docs.python.org/zh-cn/3/library/sqlite3.html#sqlite-and-python-types
        # The dataset in this demo is a text recognition dataset: the sample is
        # an image and the label is the corresponding text, so the columns are
        # img_path: str (xxxx.jpg), img_data: image data in bytes, img_label: str (xxxxx)
        create_table_sql = f'create table {table_name} (img_path TEXT primary key, img_data BLOB, img_label TEXT)'
        db_writer.execute(create_table_sql)

        # Insert data into the table; placeholders are used for the values
        insert_sql = f'insert into {table_name} (img_path, img_data, img_label) values(?, ?, ?)'
        for img_info in tqdm(img_paths):
            img_path, label = img_info.split('\t')
            img_full_path = str(dataset_dir / 'images' / img_path)
            img_data = img_to_bytes(img_full_path)
            db_writer.execute(insert_sql, (img_path, img_data, label))
```
Read from the SQLite database
```python
import sqlite3
from io import BytesIO

from PIL import Image
from torch.utils.data import DataLoader, Dataset


class SimpleDataset(Dataset):
    def __init__(self, db_path, transform=None) -> None:
        self.db_path = db_path
        self.conn = None
        self.establish_conn()

        # Table name in the database (must match the one used when writing)
        self.table_name = 'Synthetic_chinese_dataset'

        self.cursor.execute(f'select max(rowid) from {self.table_name}')
        self.nums = self.cursor.fetchall()[0][0]

        self.transform = transform

    def __getitem__(self, index: int):
        self.establish_conn()

        # Query by rowid
        search_sql = f'select * from {self.table_name} where rowid=?'
        self.cursor.execute(search_sql, (index + 1, ))
        img_path, img_bytes, label = self.cursor.fetchone()

        # Restore the image and label
        img = Image.open(BytesIO(img_bytes))
        img = img.convert('RGB')
        img = scale_resize_pillow(img, (320, 32))

        if self.transform:
            img = self.transform(img)
        return img, label

    def __len__(self) -> int:
        return self.nums

    def establish_conn(self):
        if self.conn is None:
            self.conn = sqlite3.connect(self.db_path,
                                        check_same_thread=False,
                                        cached_statements=1024)
            self.cursor = self.conn.cursor()
        return self

    def close_conn(self):
        if self.conn is not None:
            self.cursor.close()
            self.conn.close()

            del self.conn
            self.conn = None
        return self


# --------------------------------------------------
train_dataset = SimpleDataset(train_db_path, train_transforms)
# ✧✧ Close the database connection manually before handing the dataset to the
# DataLoader, so that each worker process re-opens its own connection
train_dataset.close_conn()
train_dataloader = DataLoader(train_dataset,
                              batch_size=batch_size,
                              num_workers=n_worker,
                              pin_memory=True,
                              sampler=train_sampler)
```
Final conclusion
- TFRecord
  - The storage size is the same before and after conversion, and the GPU can be kept fully utilized
  - TFRecord cannot be combined with other data augmentation libraries (imgaug, OpenCV), so the available augmentations are very limited
- LMDB
  - The storage size grows dramatically after conversion (4.2 GB originally → 96 GB after conversion)
  - With multi-process reading in PyTorch, the images cannot be restored to the originals; I have not found a solution yet
  - Read throughput is high enough to keep the GPU fully utilized
- Binary large file
  - The storage size is the same before and after conversion
  - Likewise, with multi-process reading in PyTorch, the images cannot be restored correctly; I have not found a solution yet
- ✧ SQLite (recommended)
  - The storage size is the same before and after conversion
  - It works normally with multi-process reading