Image classification dataset
Fashion-MNIST dataset
```python
%matplotlib inline
import torch
import torchvision  # torchvision is PyTorch's companion library for computer vision
from torch.utils import data  # data-loading utilities such as DataLoader
from torchvision import transforms  # image transformations, e.g. conversion to tensors, resizing and normalization of PIL images
from d2l import torch as d2l

d2l.use_svg_display()  # render figures as SVG
```
- The dataset is downloaded and read into memory with the framework's built-in functions.
- ToTensor converts a PIL image or NumPy array of shape H×W×C with uint8 values in [0, 255] into a float32 tensor of shape C×H×W with values in [0, 1]. ToPILImage goes the other way and converts a tensor or array back into a PIL image.
- Fashion-MNIST contains images from 10 categories. Each category contributes 6,000 images to the training set and 1,000 images to the test set, so the two sets contain 60,000 and 10,000 images respectively. The test set is never used for training; it serves only to evaluate the model's performance.
- Each input image is 28 pixels high and 28 pixels wide. All images are grayscale, so the number of channels is 1.
- A grayscale image stores a single intensity value per pixel and is usually represented by a two-dimensional matrix (height × width). A color (multi-channel) image is represented by a three-dimensional array.
- If each pixel is described by three values (red, green and blue), the image has three channels, and the combination of those values determines each pixel's color. An image with a single channel is a grayscale image.
```python
# ToTensor converts each image from a PIL image to a 32-bit floating-point tensor
# and divides by 255, so that all pixel values lie between 0 and 1
trans = transforms.ToTensor()
mnist_train = torchvision.datasets.FashionMNIST(
    root="../data", train=True, transform=trans, download=True)
mnist_test = torchvision.datasets.FashionMNIST(
    root="../data", train=False, transform=trans, download=True)
```
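The sketch below illustrates what ToTensor and ToPILImage do to pixel values and layout. It uses plain Python lists instead of torchvision, and the helpers `to_tensor_like` and `to_pil_like` are hypothetical stand-ins rather than the library's implementation. It assumes uint8 input, which is the case where ToTensor divides by 255. Like ToPILImage on float input, the reverse helper multiplies by 255 and truncates to integers.

```python
def to_tensor_like(image_hwc):
    """Hypothetical stand-in for ToTensor: H x W x C uint8 values (0..255)
    -> C x H x W floats in [0.0, 1.0]."""
    height, width, channels = len(image_hwc), len(image_hwc[0]), len(image_hwc[0][0])
    return [
        [[image_hwc[h][w][c] / 255.0 for w in range(width)] for h in range(height)]
        for c in range(channels)
    ]


def to_pil_like(tensor_chw):
    """Hypothetical stand-in for ToPILImage on float input: scale by 255,
    truncate to int, and move channels back to the last axis (H x W x C)."""
    channels, height, width = len(tensor_chw), len(tensor_chw[0]), len(tensor_chw[0][0])
    return [
        [[int(tensor_chw[c][h][w] * 255) for c in range(channels)] for w in range(width)]
        for h in range(height)
    ]


# A 2 x 2 grayscale image: one value per pixel, so one channel.
gray = [[[0], [255]], [[51], [102]]]
t = to_tensor_like(gray)
print(len(t), len(t[0]), len(t[0][0]))  # 1 2 2  (channels, height, width)
print(t[0])                              # [[0.0, 1.0], [0.2, 0.4]]
print(to_pil_like(t) == gray)            # True

# A single RGB pixel (pure red): three values per pixel, so three channels.
print(to_tensor_like([[[255, 0, 0]]]))  # [[[1.0]], [[0.0]], [[0.0]]]
```

The only differences between the grayscale and RGB cases are the size of the channel axis and the fact that ToTensor moves that axis to the front.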
- The 10 categories in Fashion-MNIST are t-shirt, trouser, pullover, dress, coat, sandal, shirt, sneaker, bag and ankle boot. The following function converts numeric label indices into their text names.
```python
def get_fashion_mnist_labels(labels):
    """Return the text labels of the Fashion-MNIST dataset."""
    text_labels = ['t-shirt', 'trouser', 'pullover', 'dress', 'coat',
                   'sandal', 'shirt', 'sneaker', 'bag', 'ankle boot']
    return [text_labels[int(i)] for i in labels]
```
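Here is a short usage sketch. It repeats the function above so that it runs on its own. It also adds `get_fashion_mnist_indices`, a hypothetical inverse (not part of the original code) that maps text names back to indices. Because the function calls `int(i)`, `labels` can be any iterable of integer-like values, for example a tensor of labels `y` taken from a minibatch.

```python
def get_fashion_mnist_labels(labels):
    """Return the text labels of the Fashion-MNIST dataset."""
    text_labels = ['t-shirt', 'trouser', 'pullover', 'dress', 'coat',
                   'sandal', 'shirt', 'sneaker', 'bag', 'ankle boot']
    return [text_labels[int(i)] for i in labels]


def get_fashion_mnist_indices(names):
    """Hypothetical inverse of get_fashion_mnist_labels: text name -> index."""
    name_to_index = {name: i for i, name in enumerate(get_fashion_mnist_labels(range(10)))}
    return [name_to_index[name] for name in names]


print(get_fashion_mnist_labels([0, 9, 5]))         # ['t-shirt', 'ankle boot', 'sandal']
print(get_fashion_mnist_indices(['sandal', 'bag']))  # [5, 8]
```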
Reading a minibatch
- To make reading the training and test sets easier, we use the built-in data loader instead of writing one from scratch. Recall that in each iteration the data loader reads a minibatch of size batch_size. We also shuffle all samples in the training data iterator (shuffle=True).
```python
batch_size = 256

def get_dataloader_workers():  #@save
    """Use 4 worker processes to read the data."""
    return 4

train_iter = data.DataLoader(mnist_train, batch_size, shuffle=True,
                             num_workers=get_dataloader_workers())
```
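To show what "shuffle, then read minibatches of size batch_size" means for the 60,000 training images, here is a sketch in plain Python. `iterate_minibatches` is a hypothetical helper that only imitates DataLoader's default batching (no `drop_last`), so the final batch may be smaller. Batch contents are reduced to example indices. With `shuffle=True`, DataLoader draws a fresh permutation every epoch. The fixed `seed` here exists only to make the sketch reproducible.

```python
import math
import random


def iterate_minibatches(num_examples, batch_size, shuffle=True, seed=None):
    """Yield lists of example indices, imitating DataLoader batching without drop_last."""
    indices = list(range(num_examples))
    if shuffle:
        random.Random(seed).shuffle(indices)  # one random permutation per epoch
    for start in range(0, num_examples, batch_size):
        yield indices[start:start + batch_size]


batches = list(iterate_minibatches(60000, 256, seed=0))
print(len(batches), math.ceil(60000 / 256))  # 235 235
print(len(batches[0]), len(batches[-1]))     # 256 96
```

Because 60000 = 234 × 256 + 96, one epoch consists of 234 full minibatches followed by a final minibatch of 96 images.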
Putting all components together
- Define the load_data_fashion_mnist function, which downloads and reads the Fashion-MNIST dataset and returns data iterators for the training set and the test set. The optional resize parameter resizes the images to another shape.
```python
def load_data_fashion_mnist(batch_size, resize=None):
    """Download the Fashion-MNIST dataset and then load it into memory."""
    trans = [transforms.ToTensor()]
    if resize:
        # Put Resize first so it is applied to the PIL image before ToTensor
        trans.insert(0, transforms.Resize(resize))
    # Compose chains several image transformations into a single transform
    trans = transforms.Compose(trans)
    mnist_train = torchvision.datasets.FashionMNIST(
        root="../data", train=True, transform=trans, download=True)
    mnist_test = torchvision.datasets.FashionMNIST(
        root="../data", train=False, transform=trans, download=True)
    return (data.DataLoader(mnist_train, batch_size, shuffle=True,
                            num_workers=get_dataloader_workers()),
            data.DataLoader(mnist_test, batch_size, shuffle=False,
                            num_workers=get_dataloader_workers()))
```
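Compose applies its transformations in list order, which is why `trans.insert(0, ...)` places Resize ahead of ToTensor. The sketch below imitates that ordering in plain Python. `compose`, `fake_resize` and `fake_to_tensor` are hypothetical stand-ins that append their names to a log so the order is visible. They are not torchvision objects.

```python
def compose(transforms):
    """Apply each callable in order, like torchvision.transforms.Compose."""
    def apply(x):
        for t in transforms:
            x = t(x)
        return x
    return apply


def fake_resize(size):
    """Hypothetical stand-in for transforms.Resize: records that it ran."""
    return lambda log: log + [f"Resize({size})"]


def fake_to_tensor(log):
    """Hypothetical stand-in for transforms.ToTensor: records that it ran."""
    return log + ["ToTensor"]


resize = 64
trans = [fake_to_tensor]
if resize:
    trans.insert(0, fake_resize(resize))  # same pattern as load_data_fashion_mnist
pipeline = compose(trans)
print(pipeline([]))  # ['Resize(64)', 'ToTensor']
```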
- Test the image-resizing feature of the load_data_fashion_mnist function by specifying the resize parameter.
```python
train_iter, test_iter = load_data_fashion_mnist(32, resize=64)
for X, y in train_iter:
    print(X.shape, X.dtype, y.shape, y.dtype)
    break
```
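With batch size 32 and `resize=64`, the loop should print `torch.Size([32, 1, 64, 64]) torch.float32 torch.Size([32]) torch.int64`. Each image keeps its single grayscale channel, and its 28×28 pixels are resized to 64×64. The labels are collated into a 1-D integer tensor. The hypothetical helper below shows how those shapes follow from the parameters. It assumes `resize` is an int, and since the images are square, both sides become `resize`. It also assumes a full batch, because the last batch of an epoch can be smaller.

```python
def expected_batch_shapes(batch_size, resize=None, native_size=28, channels=1):
    """Hypothetical helper: predict the (X, y) shapes of a full minibatch
    produced by load_data_fashion_mnist."""
    side = resize if resize else native_size  # Fashion-MNIST images are 28 x 28
    return (batch_size, channels, side, side), (batch_size,)


print(expected_batch_shapes(32, resize=64))  # ((32, 1, 64, 64), (32,))
print(expected_batch_shapes(256))            # ((256, 1, 28, 28), (256,))
```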