OCR character recognition with Transformer!


Authors: An Sheng, Yuan Mingkun, members of Datawhale

In the CV field, what else can a transformer do besides classification? This article uses a word recognition dataset to show how to build a simple OCR word recognition model with a transformer, and to illustrate how transformers are applied to CV tasks more complex than classification. The article is divided into four parts:

1, Data set introduction and acquisition

2, Data analysis and relationship construction

3, How to introduce transformer into OCR

4, Explanation of training framework code

Note: This article focuses on how to design the model and the training architecture for an OCR task. It contains a complete walkthrough and a lot of code, so you may want to bookmark it. Some familiarity with the transformer architecture is assumed.

The whole character recognition task mainly includes the following files:

  • analysis_recognition_dataset.py (dataset analysis script)
  • ocr_by_transformer.py (OCR task training script)
  • transformer.py (transformer model file)
  • train_utils.py (training-related helper functions: loss, optimizer, etc.)

ocr_by_transformer.py is the main training script. It relies on train_utils.py and transformer.py to build the transformer and train the character recognition model.

1, Data set introduction and acquisition

The dataset used in this article is based on Task 4.3: Word Recognition of the ICDAR 2015 Incidental Scene Text challenge, a well-known dataset for text recognition in natural scenes. Here it is used for the word recognition task. We removed some images to reduce the difficulty of the experiment, so the dataset in this article differs slightly from the original.

To make data sharing and version control easier, we host the simplified dataset on a dedicated data sharing platform and load it online. The dataset is available at: https://gas.graviti.cn/dataset/datawhale/ICDAR2015 . Questions can be raised directly in the dataset's discussion area.

The dataset consists of text regions taken from natural scene images. The training set contains 4326 images and the test set contains 1992 images. Each image is cropped from the original large image according to the bounding box of a text region, so the text is roughly centered in the image.

Images in the dataset are similar to the following styles:

word_104.png, "Optical"

Each sample is an image, and its label is stored as a CLASSIFICATION label. When the label is read in the code below, it is returned directly as a list of characters, a storage choice made to keep the labels easy to use.

The following is a quick guide to using the dataset:

  • Download and install tensorbay locally
pip3 install tensorbay
  • Open this article dataset link: https://gas.graviti.cn/dataset/datawhale/ICDAR2015
  • fork the dataset into your account
  • At the top of the web page, click Developer Tools --> AccessKey --> create a new AccessKey, then copy the key
from tensorbay import GAS
from tensorbay.dataset import Dataset

# GAS voucher
KEY = 'Accesskey-***************80a'  # Add your own AccessKey
gas = GAS(KEY)

# Get dataset
dataset = Dataset("ICDAR2015", gas)
# dataset.enable_cache('./data')  # Uncomment this line to keep a local cache of the data

# Training set and validation set
train_segment = dataset["train"]
valid_segment = dataset['valid']

# Data and labels
for data in train_segment:
    # image data
    img = data.open()
    # Image label
    label = data.label.classification.category

The image and label obtained through the above code are as follows:



['C', 'A', 'U', 'T', 'I', 'O', 'N']

With this short snippet you can quickly obtain the images and labels. Without a cache, however, the program downloads the data from the platform on every run, which takes a long time. It is recommended to enable the local cache so that the data is downloaded once and reused. When the data is no longer needed, you can delete the cache.

2, Data analysis and relationship construction

Before starting the experiment, we first make a simple analysis of the data. Only if we have a sufficient understanding of the characteristics of the data can we better build a baseline and avoid detours in training.

Run the following code to complete the simple analysis of the data set with one click:

python analysis_recognition_dataset.py

Specifically, this script computes label character statistics (which characters occur and how often), the length of the longest label, and image size statistics, and it builds the character mapping file lbl2id_map.txt.

Let's look at the code a little bit:

Note: the open source address of this Code:


First, complete the preparation, import the required libraries, and set the path of relevant directories or files

import os
from PIL import Image
import tqdm

from tensorbay import GAS
from tensorbay.dataset import Dataset

# GAS voucher
KEY = 'Accesskey-************************480a'  # Add your own AccessKey
gas = GAS(KEY)
# The dataset is fetched and cached locally
dataset = Dataset("ICDAR2015", gas)
dataset.enable_cache('./data')  # Data cache address

# Get training set and validation set
train_segment = dataset["train"]
valid_segment = dataset['valid']

# Path of the intermediate file that stores the mapping between label characters and their ids
base_data_dir = './'
lbl2id_map_path = os.path.join(base_data_dir, 'lbl2id_map.txt')

2.1 statistics of the longest label

First, count the number of characters in the longest label in the dataset. Both the training set and the validation set are scanned to find it.

def statistics_max_len_label(segment):
    """Return the number of characters in the longest label of a segment."""
    max_len = -1
    for data in segment:
        lbl_str = data.label.classification.category  # Get label
        lbl_len = len(lbl_str)
        max_len = max_len if max_len > lbl_len else lbl_len
    return max_len

train_max_label_len = statistics_max_len_label(train_segment)  # Longest label in the training set
valid_max_label_len = statistics_max_len_label(valid_segment)  # Longest label in the validation set
max_label_len = max(train_max_label_len, valid_max_label_len)  # Longest label in the whole dataset
print(f"The longest label in the dataset contains {max_label_len} characters")

The longest label in the dataset contains 21 characters. This serves as a reference when choosing the number of time steps for the transformer model later.

2.2 statistics of characters contained in labels

The following code is used to view all characters that have appeared in the dataset:

def statistics_label_cnt(segment, lbl_cnt_map):
    """
    Count which characters appear in the labels and how often.
    lbl_cnt_map : dictionary recording the number of occurrences of each character
    """
    for data in segment:
        lbl_str = data.label.classification.category  # Get label
        for lbl in lbl_str:
            if lbl not in lbl_cnt_map.keys():
                lbl_cnt_map[lbl] = 1
            else:
                lbl_cnt_map[lbl] += 1

lbl_cnt_map = dict()  # Dictionary that stores the number of occurrences of each character
statistics_label_cnt(train_segment, lbl_cnt_map)  # Count character occurrences in the training set
print("Characters appearing in the training set labels:")
print(lbl_cnt_map)
statistics_label_cnt(valid_segment, lbl_cnt_map)  # Add character occurrences from the validation set
print("Characters appearing in the training set + validation set labels:")
print(lbl_cnt_map)

The output is:

Characters appearing in the training set labels:
{'C': 593, 'A': 1189, 'U': 319, 'T': 896, 'I': 861, 'O': 965, 'N': 785, 'D': 383, 'W': 179, 'M': 367, 'E': 1423, 'X': 110, '$': 46, '2': 121, '4': 44, 'L': 745, 'F': 259, 'P': 389, 'R': 836, 'S': 1164, 'a': 843, 'v': 123, 'e': 1057, 'G': 345, "'": 51, 'r': 655, 'k': 96, 's': 557, 'i': 651, 'c': 318, 'V': 158, 'H': 391, '3': 50, '.': 95, '"': 8, '-': 68, ',': 19, 'Y': 229, 't': 563, 'y': 161, 'B': 332, 'u': 293, 'x': 27, 'n': 605, 'g': 171, 'o': 659, 'l': 408, 'd': 258, 'b': 88, 'p': 197, 'K': 163, 'J': 72, '5': 80, '0': 203, '1': 186, 'h': 299, '!': 51, ':': 19, 'f': 133, 'm': 202, '9': 66, '7': 45, 'j': 15, 'z': 12, '´': 3, 'Q': 19, 'Z': 29, '&': 9, ' ': 50, '8': 47, '/': 24, '#': 16, 'w': 97, '?': 5, '6': 40, '[': 2, ']': 2, 'É': 1, 'q': 3, ';': 3, '@': 4, '%': 28, '=': 1, '(': 6, ')': 5, '+': 1}
Characters appearing in the training set + validation set labels:
{'C': 893, 'A': 1827, 'U': 467, 'T': 1315, 'I': 1241, 'O': 1440, 'N': 1158, 'D': 548, 'W': 288, 'M': 536, 'E': 2215, 'X': 181, '$': 57, '2': 141, '4': 53, 'L': 1120, 'F': 402, 'P': 582, 'R': 1262, 'S': 1752, 'a': 1200, 'v': 169, 'e': 1536, 'G': 521, "'": 70, 'r': 935, 'k': 137, 's': 793, 'i': 924, 'c': 442, 'V': 224, 'H': 593, '3': 69, '.': 132, '"': 8, '-': 87, ',': 25, 'Y': 341, 't': 829, 'y': 231, 'B': 469, 'u': 415, 'x': 38, 'n': 880, 'g': 260, 'o': 955, 'l': 555, 'd': 368, 'b': 129, 'p': 317, 'K': 253, 'J': 100, '5': 105, '0': 258, '1': 231, 'h': 417, '!': 65, ':': 24, 'f': 203, 'm': 278, '9': 76, '7': 62, 'j': 19, 'z': 14, '´': 3, 'Q': 28, 'Z': 36, '&': 15, ' ': 82, '8': 58, '/': 29, '#': 24, 'w': 136, '?': 7, '6': 46, '[': 2, ']': 2, 'É': 2, 'q': 3, ';': 3, '@': 9, '%': 42, '=': 1, '(': 7, ')': 5, '+': 2, 'é': 1}

In the code above, lbl_cnt_map is a dictionary that counts character occurrences. It is also used later to build the mapping between characters and their ids.

The statistics show that the validation (test) set contains characters that never appear in the training set, for example a single 'é'. Such cases are rare, so they should not cause problems, and no extra processing is applied to the dataset here. It is still good practice to check deliberately for differences between the training set and the test set.
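
The check for characters that only occur outside the training set can be sketched in a few lines. This is a minimal illustration on toy label lists rather than the real segments; the helper names count_chars and unseen_chars are hypothetical and are not part of analysis_recognition_dataset.py.

```python
from collections import Counter


def count_chars(labels):
    """Count how often each character occurs across a list of labels."""
    counter = Counter()
    for label in labels:
        counter.update(label)  # works for strings and for lists of characters
    return counter


def unseen_chars(train_counts, valid_counts):
    """Return validation characters that never occur in training, with their counts."""
    return {ch: n for ch, n in valid_counts.items() if ch not in train_counts}


# Toy labels standing in for the real segments (hypothetical data).
train_labels = ["CAUTION", "Optical", "Share"]
valid_labels = ["Café", "Share"]

train_counts = count_chars(train_labels)
valid_counts = count_chars(valid_labels)
missing = unseen_chars(train_counts, valid_counts)
print(missing)
```

If the returned dictionary is large, or the counts in it are high, it may be worth cleaning the labels or extending the training data before training.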

2.3 construction of mapping dictionary between char and id

In the OCR task of this article, each character in the image has to be predicted. To do this, we first need a mapping between each character and its id, which converts the text into numbers the model can read. This step is similar to building a vocabulary in NLP.

When building the mapping, in addition to the characters that appear in the labels, three special symbols are reserved: a sentence start symbol, a sentence end symbol, and a padding identifier. They come up again in the explanation of the Dataset construction below.

After the script runs, the mapping of all characters is saved in the lbl2id_map.txt file.

# Construct the mapping between characters and ids in the labels
print("Building the character-to-id mapping for the labels:")

lbl2id_map = dict()
# Initialize three special symbols
lbl2id_map['☯'] = 0    # Padding identifier
lbl2id_map['■'] = 1    # Sentence start symbol
lbl2id_map['□'] = 2    # Sentence end symbol
# Generate id mappings for the remaining characters
cur_id = 3
for lbl in lbl_cnt_map.keys():
    lbl2id_map[lbl] = cur_id
    cur_id += 1
# Save the character-to-id mapping to a txt file
with open(lbl2id_map_path, 'w', encoding='utf-8') as writer:  # Set the encoding explicitly; some systems do not default to utf-8
    for lbl in lbl2id_map.keys():
        cur_id = lbl2id_map[lbl]
        print(lbl, cur_id)
        line = lbl + '\t' + str(cur_id) + '\n'
        writer.write(line)

The resulting mapping between characters and ids (middle entries omitted):

☯ 0
■ 1
□ 2
C 3
A 4
...
= 85
( 86
) 87
+ 88
é 89

In addition, the analysis_recognition_dataset.py file contains a function that loads the mapping. It reads the txt file with the mapping records and builds both the character-to-id and the id-to-character dictionaries. These are used in the subsequent transformer training to convert quickly between characters and ids.

def load_lbl2id_map(lbl2id_map_path):
    """
    Read the txt file that records the character-id mapping and return the lbl->id and id->lbl dictionaries.
    lbl2id_map_path : path of the txt file that records the character-id mapping
    """
    lbl2id_map = dict()
    id2lbl_map = dict()
    with open(lbl2id_map_path, 'r', encoding='utf-8') as reader:  # Same encoding as used when writing the file
        for line in reader:
            items = line.rstrip().split('\t')
            label = items[0]
            cur_id = int(items[1])
            lbl2id_map[label] = cur_id
            id2lbl_map[cur_id] = label
    return lbl2id_map, id2lbl_map
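
To see how the two dictionaries are used, here is a small self-contained round trip. It is only a sketch: it writes a toy mapping file in the same "character, tab, id" format to a temporary directory, and load_map is a simplified stand-in for load_lbl2id_map so that the snippet runs on its own. The vocabulary and file name are hypothetical.

```python
import os
import tempfile


def load_map(path):
    """Read 'char<TAB>id' lines and return (char->id, id->char) dictionaries."""
    char2id, id2char = {}, {}
    with open(path, "r", encoding="utf-8") as reader:
        for line in reader:
            char, idx = line.rstrip("\n").split("\t")
            char2id[char] = int(idx)
            id2char[int(idx)] = char
    return char2id, id2char


# Toy vocabulary: three special symbols followed by three characters.
vocab = ["☯", "■", "□", "C", "A", "T"]
with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, "toy_lbl2id_map.txt")
    with open(path, "w", encoding="utf-8") as writer:
        for idx, char in enumerate(vocab):
            writer.write(f"{char}\t{idx}\n")
    char2id, id2char = load_map(path)

ids = [char2id[ch] for ch in "CAT"]        # characters -> ids
text = "".join(id2char[i] for i in ids)    # ids -> characters
print(ids, text)
```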

2.4 data set image size analysis

For tasks such as image classification and detection, we usually check the size distribution of the images and then choose a suitable preprocessing method. In object detection, for example, we collect statistics on image sizes and bounding box sizes, analyze the aspect ratios, and then choose a suitable image cropping strategy and suitable initial anchors.

Here we analyze the width, height and aspect ratio of the images to understand the data and to inform the later experimental strategy.

def read_gas_image(data):
    with data.open() as fp:
        image = Image.open(fp)
    return image

# Analyze dataset picture size
print("Analyze dataset picture size:")

# Initialization parameters
min_h = 1e10
min_w = 1e10
max_h = -1
max_w = -1
min_ratio = 1e10
max_ratio = 0
# Traverse the dataset to calculate size information
for data in tqdm.tqdm(train_segment):
    img = read_gas_image(data)  # Read picture
    w, h = img.size  # Extract image width and height information
    ratio = w / h  # Aspect ratio
    min_h = min_h if min_h <= h else h  # Minimum picture height
    max_h = max_h if max_h >= h else h  # Maximum picture height
    min_w = min_w if min_w <= w else w  # Minimum picture width
    max_w = max_w if max_w >= w else w  # Maximum picture width
    min_ratio = min_ratio if min_ratio <= ratio else ratio  # Minimum aspect ratio
    max_ratio = max_ratio if max_ratio >= ratio else ratio  # Maximum aspect ratio
# Output information
print('min_h:', min_h)
print('max_h:', max_h)
print('min_w:', min_w)
print('max_w:', max_w)
print('min_ratio:', min_ratio)
print('max_ratio:', max_ratio)

The statistical results related to the image size of the dataset are as follows:

min_h: 9
max_h: 295
min_w: 16
max_w: 628
min_ratio: 0.6666666666666666
max_ratio: 8.619047619047619

From the above results, most of the images are wide horizontal strips, and the maximum aspect ratio exceeds 8, so some images are extremely elongated.

That concludes the simple analysis of the dataset, and the char2id mapping file needed for training is ready. Now for the main part: let's see how to introduce the transformer to complete the CV task of OCR word recognition.

3, How to introduce transformer into OCR

Many algorithms are not difficult in themselves; the hard part is framing the problem so that it maps onto a known solution. So before looking at the code, let's discuss why a transformer can solve the OCR problem and what the motivation is.

First of all, the transformer is widely used in NLP and can solve sequence-to-sequence problems such as machine translation, as shown in the following figure:

The OCR recognition task is shown in the figure below. We want to recognize the image as "Share". In essence this is also a sequence-to-sequence task, except that the input sequence is given in the form of an image.

Therefore, if OCR is treated as a sequence-to-sequence prediction problem, using a transformer to solve it is a natural idea. The remaining question is how to turn the image information into an input sequence analogous to word embeddings.

Back to our task: since the images to be predicted are long strips and the characters are basically arranged horizontally, we arrange the feature map into a sequence along the horizontal direction, so that each embedding can be regarded as the feature of one vertical slice of the image. We hand this feature sequence to the transformer and rely on its attention mechanism to make the prediction.

Based on this analysis, we define the pipeline of the model framework as shown in the following figure:

As the figure shows, the whole pipeline is basically the same as training a machine translation model with a transformer. The main difference is the extra step of extracting image features with a CNN backbone to obtain the input embeddings.

The design of the transformer's input embedding is the focus of this article and the key to making the whole algorithm work. The following sections explain the details shown in the diagram together with the code.

4, Explanation of training framework code

The training framework code is implemented in the ocr_by_transformer.py file.

Let's go through the code step by step. It mainly includes the following parts:

  • Build dataset → image preprocessing, label processing, etc.;
  • Model construction → backbone + transformer;
  • Model training
  • Inference → greedy decoding

Let's look at it step by step

4.1 preparation

First, import the library to be used later

import os
import time
import copy
from PIL import Image

# Online dataset related package
from tensorbay import GAS
from tensorbay.dataset import Dataset

# torch related package
import torch
import torch.nn as nn
from torch.autograd import Variable
import torchvision.models as models
import torchvision.transforms as transforms

# Import tool class package
from analysis_recognition_dataset import load_lbl2id_map, statistics_max_len_label
from transformer import *
from train_utils import *

Then set some basic parameters

device = torch.device('cuda')  # 'cpu' or 'cuda'
nrof_epochs = 1500  # Number of training epochs; adjust as needed
batch_size = 64     # Batch size; adjust as needed
model_save_path = './log/ex1_ocr_model.pth'

Obtain the image data online and read the mapping dictionary between the label characters and their ids. Both are needed later to create the Dataset.

# GAS credentials
KEY = 'Accesskey-************************480a'  # Add your own AccessKey
gas = GAS(KEY)
# Get the dataset online
dataset_online = Dataset("ICDAR2015", gas)
dataset_online.enable_cache('./data')  # Turn on local cache

# Get training set and validation set
train_data_online = dataset_online["train"]
valid_data_online = dataset_online['valid']

# Read the label ID mapping record file
lbl2id_map_path = os.path.join('./', 'lbl2id_map.txt')
lbl2id_map, id2lbl_map = load_lbl2id_map(lbl2id_map_path)

# Length of the longest label in the dataset; needed to build the gt (ground truth) sequences
train_max_label_len = statistics_max_len_label(train_data_online)
valid_max_label_len = statistics_max_len_label(valid_data_online)
# Use the longest label length in the dataset as the gt sequence length
sequence_len = max(train_max_label_len, valid_max_label_len)

4.2 Dataset construction

Next comes the Dataset construction. The first question is how to preprocess the images.

Image preprocessing scheme

Suppose the input image size is [batch_size, 3, H_i, W_i]

After the backbone network, the feature map size is [batch_size, C_f, H_f, W_f]

Based on the earlier analysis of the dataset, the images are basically horizontal strips, and their content is a word made of horizontally arranged characters. Any vertical slice of the image therefore contains essentially one character, so the vertical resolution does not need to be large and we take H_f = 1. The horizontal resolution needs to be larger, because we need different embeddings to encode the features of the different characters along the horizontal direction.

Here we use the classic resnet18 network as the backbone. Its downsampling factor is 32 and its last feature map has 512 channels, so:

H_i = H_f * 32 = 32
C_f = 512

So how to determine the width of the input picture? Here are two schemes, as shown in the figure below:

Method 1: set a fixed target size, resize the image while keeping its aspect ratio, and pad the empty area on the right;

Method 2: directly force the original image to resize to a preset fixed size.

Note: before reading on, think about which scheme you expect to work better.

The author chose method 1, because the aspect ratio of an image is roughly proportional to the number of characters in the word it contains. If the aspect ratio is preserved during preprocessing, each pixel of the feature map covers a roughly constant share of a character region in the original image, which is likely to make prediction easier.

There is one more detail. In the figure above, each square region (width : height = 1:1) contains roughly 2-3 characters. So in practice we do not keep the aspect ratio strictly unchanged: we first stretch the width of the original image by a factor of 3, and then resize it so that the height becomes 32 while keeping this new aspect ratio.

Note: it is worth pausing again to think about why this detail matters.

The purpose is to give every character in the image at least one corresponding pixel on the feature map, instead of having a single pixel along the width of the feature map encode several characters of the original image at once. In my view, the latter would make prediction unnecessarily difficult for the transformer (a personal opinion, discussion is welcome).
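
The effect of the stretch can be checked with a little arithmetic. The sketch below is an illustration under stated assumptions: feature_columns is a hypothetical helper that mirrors the rounding used in the resize logic, and the 96 x 32 crop with 6 characters is a made-up example.

```python
def feature_columns(width, height, stretch=1, target_h=32, downsample=32):
    """Number of feature-map columns produced for an image after resizing.

    The width/height ratio is multiplied by `stretch` and rounded; every
    32-pixel-wide block of the resized image becomes one feature column.
    """
    ratio = max(1, round(width / height * stretch))
    resized_w = target_h * ratio
    return resized_w // downsample


# Hypothetical crop: 6 characters in a 96 x 32 image (about 2 characters per square).
n_chars = 6
plain = feature_columns(96, 32, stretch=1)
stretched = feature_columns(96, 32, stretch=3)
print(plain, n_chars / plain)          # columns, characters per column
print(stretched, n_chars / stretched)  # columns, characters per column
```

Without the stretch each feature column has to describe two characters in this example; with the stretch there are more columns than characters.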

With the resize scheme settled, what should W_i be? Recall two key numbers from the dataset analysis: the longest label has 21 characters and the largest aspect ratio is about 8.6. We set the final aspect ratio to 24:1 (a maximum ratio of 8, stretched by a factor of 3). To summarize the parameter settings:

H_i = H_f * 32 = 32
W_i = 24 * H_i = 768
C_f = 512, H_f = 1, W_f = 24
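
These settings follow from a few constants, which the short calculation below reproduces. It is only a worked check; the constant names are chosen for this example and do not appear in the training script.

```python
DOWNSAMPLE = 32      # total downsampling factor of resnet18
H_F = 1              # desired feature-map height
MAX_RATIO = 8        # cap on the width/height ratio of the original image
STRETCH = 3          # horizontal stretch factor
MAX_LABEL_LEN = 21   # longest label found in the dataset analysis

h_i = H_F * DOWNSAMPLE               # input height
w_i = h_i * MAX_RATIO * STRETCH      # input width after stretching and padding
w_f = w_i // DOWNSAMPLE              # feature-map width = number of encoder time steps

print(h_i, w_i, w_f)
print(w_f >= MAX_LABEL_LEN)  # enough time steps for the longest label
```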

Relevant code implementation:

# ----------------
# Image preprocessing
# ----------------
# load image
with img_data.open() as fp:
    img = Image.open(fp).convert('RGB')

# Resize while approximately keeping the aspect ratio:
# the height becomes 32 and the width becomes a multiple of 32
w, h = img.size
ratio = round((w / h) * 3)   # Stretch the width by a factor of 3 and round
if ratio == 0:
    ratio = 1 
if ratio > self.max_ratio:
    ratio = self.max_ratio
h_new = 32
w_new = h_new * ratio
img_resize = img.resize((w_new, h_new), Image.BILINEAR)

# Pad the right side of the image so that the width / height ratio is fixed at self.max_ratio
img_padd = Image.new('RGB', (32*self.max_ratio, 32), (0,0,0))
img_padd.paste(img_resize, (0, 0)) 

Image augmentation

Image augmentation is not the focus here. Besides the resize scheme above, we only apply a conventional random color jitter and normalization to the image.

Complete code

The complete code for building Dataset is as follows:

class Recognition_Dataset(object):

    def __init__(self, segment, lbl2id_map, sequence_len, max_ratio, pad=0):
        self.data = segment
        self.lbl2id_map = lbl2id_map
        self.pad = pad   # The id of the padding identifier. The default is 0
        self.sequence_len = sequence_len    # Sequence length
        self.max_ratio = max_ratio * 3      # Lengthen the width by 3 times

        # Define random color transformations
        self.color_trans = transforms.ColorJitter(0.1, 0.1, 0.1)
        # Define Normalize
        self.trans_Normalize = transforms.Compose([
            transforms.ToTensor(),  # Convert the PIL image to a tensor before normalizing
            transforms.Normalize([0.485,0.456,0.406], [0.229,0.224,0.225]),
        ])

    def __getitem__(self, index):
        """Get the image and ground truth label for the given index, with data augmentation applied."""
        img_data = self.data[index]
        lbl_str = img_data.label.classification.category  # Image label

        # ----------------
        # Image preprocessing
        # ----------------
        # load image
        with img_data.open() as fp:
            img = Image.open(fp).convert('RGB')

        # Resize while approximately keeping the aspect ratio:
        # the height becomes 32 and the width becomes a multiple of 32
        w, h = img.size
        ratio = round((w / h) * 3)   # Stretch the width by a factor of 3 and round
        if ratio == 0:
            ratio = 1
        if ratio > self.max_ratio:
            ratio = self.max_ratio
        h_new = 32
        w_new = h_new * ratio
        img_resize = img.resize((w_new, h_new), Image.BILINEAR)

        # Pad the right side of the image so that the width / height ratio is fixed at self.max_ratio
        img_padd = Image.new('RGB', (32*self.max_ratio, 32), (0,0,0))
        img_padd.paste(img_resize, (0, 0))

        # Random color transformation
        img_input = self.color_trans(img_padd)
        # Normalize
        img_input = self.trans_Normalize(img_input)

        # ----------------
        # label processing
        # ----------------

        # Construct mask of encoder
        encode_mask = [1] * ratio + [0] * (self.max_ratio - ratio)
        encode_mask = torch.tensor(encode_mask)
        encode_mask = (encode_mask != 0).unsqueeze(0)

        # Construct ground truth label
        gt = []
        gt.append(1)    # Add the sentence start symbol first
        for lbl in lbl_str:
            gt.append(self.lbl2id_map[lbl])
        gt.append(2)    # Add the sentence end symbol
        for i in range(len(lbl_str), self.sequence_len):   # Pad the remaining positions (start and end symbols not counted)
            gt.append(0)
        # Truncate to the preset maximum sequence length
        gt = gt[:self.sequence_len]

        # Input of decoder
        decode_in = gt[:-1]
        decode_in = torch.tensor(decode_in)
        # Output of decoder
        decode_out = gt[1:]
        decode_out = torch.tensor(decode_out)
        # mask of decoder 
        decode_mask = self.make_std_mask(decode_in, self.pad)
        # Number of valid tokens
        ntokens = (decode_out != self.pad).data.sum()

        return img_input, encode_mask, decode_in, decode_out, decode_mask, ntokens

    @staticmethod
    def make_std_mask(tgt, pad):
        """
        Create a mask to hide padding and future words.
        Padding and future words are both marked as 0 (False) in the mask.
        """
        tgt_mask = (tgt != pad)
        tgt_mask = tgt_mask & Variable(subsequent_mask(tgt.size(-1)).type_as(tgt_mask.data))
        tgt_mask = tgt_mask.squeeze(0)   # subsequent_mask returns a tensor of shape (1, N, N)
        return tgt_mask

    def __len__(self):
        return len(self.data)

The code above also involves several details of label processing that belong to the Transformer training logic. A brief recap:


Because the image is resized and padded as needed, and the padded area carries no useful information, we build an encode_mask from the padding proportion so that the transformer ignores this meaningless area during computation.
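
As a framework-free sketch of this idea, the function below builds the mask as a plain Python list. build_encode_mask is a hypothetical helper; it assumes that max_ratio is already expressed in 32-pixel columns (8 * 3 = 24) and uses a made-up 120 x 40 image.

```python
def build_encode_mask(width, height, max_ratio=24, stretch=3):
    """Return (ratio, mask), where mask[i] is True for columns with real image content.

    `max_ratio` is the padded width in units of 32 pixels (stretch included),
    so the mask always has exactly `max_ratio` entries.
    """
    ratio = round(width / height * stretch)
    ratio = min(max(ratio, 1), max_ratio)  # clamp to [1, max_ratio]
    mask = [True] * ratio + [False] * (max_ratio - ratio)
    return ratio, mask


ratio, mask = build_encode_mask(width=120, height=40)
print(ratio, sum(mask), len(mask))
```

The first `ratio` positions correspond to the resized image and the remaining positions to the black padding on the right.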

label processing

The prediction labels used in this experiment are essentially the same as those used when training a machine translation model, so the processing is very similar.

During label processing, the characters in the label are converted to their ids, the start symbol is added at the beginning and the end symbol at the end, and when the sequence is shorter than sequence_len the remaining positions are padded with 0.
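
The label steps can be tried out without the Dataset class. The sketch below uses a toy vocabulary with hypothetical ids (the real ones come from lbl2id_map.txt); encode_label is an illustrative helper that pads in a slightly simpler way than the class but yields the same sequence after truncation.

```python
PAD_ID, START_ID, END_ID = 0, 1, 2


def encode_label(label, char2id, sequence_len):
    """Convert a label into the decoder input and output id sequences."""
    gt = [START_ID] + [char2id[ch] for ch in label] + [END_ID]
    gt += [PAD_ID] * (sequence_len - len(gt))   # pad up to sequence_len
    gt = gt[:sequence_len]                      # truncate over-long labels
    decode_in = gt[:-1]     # what the decoder receives
    decode_out = gt[1:]     # what it must predict, shifted by one step
    return decode_in, decode_out


char2id = {"C": 3, "A": 4, "T": 5}  # toy vocabulary
decode_in, decode_out = encode_label("CAT", char2id, sequence_len=8)
print(decode_in)
print(decode_out)
```

The decoder input and output are the same sequence shifted by one position, which is why both have length sequence_len - 1.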


In the decoder, we generate a mask in the form of a triangular matrix from the label's sequence_len. Each row of the mask controls one time step: the decoder may only access the characters up to the current step and is not allowed to see characters from future steps, which prevents cheating during training.

decode_mask is generated by a dedicated function, make_std_mask().

In addition, the padded part of the decoder label must also be masked, so decode_mask is set to False at the padded positions of the label as well.

The generated decode_mask is shown in the following figure:
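
The same mask can be reproduced with plain Python lists. This is a minimal sketch of the combination of padding mask and triangular mask, not the tensor-based make_std_mask itself; make_decode_mask and the example sequence are illustrative.

```python
def make_decode_mask(decode_in, pad=0):
    """Combine a padding mask with a triangular mask.

    mask[i][j] is True when step i may attend to position j: position j must
    not lie in the future (j <= i) and must not be padding.
    """
    n = len(decode_in)
    return [[(j <= i) and (decode_in[j] != pad) for j in range(n)]
            for i in range(n)]


decode_in = [1, 3, 4, 5, 0, 0]   # start symbol, three characters, two padding slots
decode_mask = make_decode_mask(decode_in)
for row in decode_mask:
    print("".join("1" if allowed else "0" for allowed in row))
```

Each printed row is one time step: the block of ones grows by one position per step until the padded positions are reached, which always stay at zero.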

The above covers all the details of building the Dataset. Next we can build the DataLoaders for training.

# Construct dataloader
max_ratio = 8    # Maximum width / height ratio used in preprocessing; images within it keep their proportions, wider ones are forcibly compressed
train_dataset = Recognition_Dataset(train_data_online, lbl2id_map, sequence_len, max_ratio, pad=0)
valid_dataset = Recognition_Dataset(valid_data_online, lbl2id_map, sequence_len, max_ratio, pad=0)
# loader size info:
# --> img_input: [batch_size, c, h, w] --> [64, 3, 32, 32*8*3]
# --> encode_mask: [batch_size, h/32, w/32] --> [64, 1, 24]  the backbone downsamples by 32, so divide by 32
# --> decode_in: [bs, sequence_len-1] --> [64, 20]
# --> decode_out: [bs, sequence_len-1] --> [64, 20]
# --> decode_mask: [bs, sequence_len-1, sequence_len-1] --> [64, 20, 20]
# --> ntokens: [bs] --> [64]
# The batch_size and shuffle settings below are assumptions; adjust as needed
train_loader = torch.utils.data.DataLoader(train_dataset,
                                           batch_size=batch_size,
                                           shuffle=True)
valid_loader = torch.utils.data.DataLoader(valid_dataset,
                                           batch_size=batch_size,
                                           shuffle=False)

4.3 model construction

The model structure is built by the make_ocr_model function and the OCR_EncoderDecoder class.

make_ocr_model first loads the pretrained ResNet-18 from PyTorch as the backbone to extract image features. You can swap in another network as needed, but pay attention to its downsampling factor and to the number of channels (channel_num) of the last feature map, and adjust the parameters of the related modules accordingly. It then calls the OCR_EncoderDecoder class to build the transformer, and finally initializes the model parameters.

The OCR_EncoderDecoder class acts as an assembly line for the basic components of a transformer, including the encoder and the decoder. Its constructor takes the ready-made basic components as arguments. The code of those components is in the transformer.py file and is not described in detail in this article.

Let's review how the image becomes the input of the Transformer after passing through the backbone:

After the image passes through the backbone, it outputs a feature map of dimension [batch_size, 512, 1, 24]. Leaving batch_size aside, each image yields a 1 × 24 feature map with 512 channels. As shown in the red box in the figure, the feature values at the same position across the channels are concatenated into a new vector, which serves as the input of one time step. This gives an input of dimension [batch_size, 24, 512], which meets the input requirements of the Transformer.
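
The rearrangement from feature map to sequence can be illustrated on a tiny example with nested lists. This is a framework-free sketch for a single image, using 4 channels and 3 columns instead of 512 and 24; feature_map_to_sequence is a hypothetical helper, and the model itself does the same thing on tensors with squeeze and permute.

```python
def feature_map_to_sequence(feature_map):
    """Turn a [C][1][W] feature map into a [W][C] sequence of embeddings.

    The height axis (size 1) is dropped and the channel and width axes are
    swapped, so each time step holds the values of all channels at one column.
    """
    channels = len(feature_map)
    width = len(feature_map[0][0])
    return [[feature_map[c][0][w] for c in range(channels)]
            for w in range(width)]


# Tiny stand-in: the value c * 10 + w encodes channel c and column w.
C, W = 4, 3
feature_map = [[[c * 10 + w for w in range(W)]] for c in range(C)]
sequence = feature_map_to_sequence(feature_map)
print(len(sequence), len(sequence[0]))  # time steps, embedding size
print(sequence[1])                      # all channel values at column 1
```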

Let's take a look at the complete code for constructing the model:

# model structure 
class OCR_EncoderDecoder(nn.Module):
    """
    A standard Encoder-Decoder architecture.
    Base for this and many other models.
    """
    def __init__(self, encoder, decoder, src_embed, src_position, tgt_embed, generator):
        super(OCR_EncoderDecoder, self).__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.src_embed = src_embed    # input embedding module
        self.src_position = src_position
        self.tgt_embed = tgt_embed    # output embedding module
        self.generator = generator    # output generation module

    def forward(self, src, tgt, src_mask, tgt_mask):
        "Take in and process masked src and target sequences."
        # src --> [bs, 3, 32, 768]  [bs, c, h, w]
        # src_mask --> [bs, 1, 24]  [bs, h/32, w/32]
        memory = self.encode(src, src_mask)
        # memory --> [bs, 24, 512]
        # tgt --> decode_in [bs, 20]  [bs, sequence_len-1]
        # tgt_mask --> decode_mask [bs, 20, 20]  [bs, sequence_len-1, sequence_len-1]
        res = self.decode(memory, src_mask, tgt, tgt_mask)  # [bs, 20, 512]
        return res

    def encode(self, src, src_mask):
        # feature extract
        # src --> [bs, 3, 32, 768]
        src_embedds = self.src_embed(src)
        # resnet18 is used as the backbone here; output --> [bs, c, h, w] --> [bs, 512, 1, 24]
        # reshape src_embedds from (bs, model_dim, 1, max_ratio) to the (bs, time_step, model_dim) layout expected by the transformer
        # [bs, 512, 1, 24] --> [bs, 24, 512]
        src_embedds = src_embedds.squeeze(-2)
        src_embedds = src_embedds.permute(0, 2, 1)

        # position encode
        src_embedds = self.src_position(src_embedds)  # [bs, 24, 512]

        return self.encoder(src_embedds, src_mask)  # [bs, 24, 512]

    def decode(self, memory, src_mask, tgt, tgt_mask):
        target_embedds = self.tgt_embed(tgt)  # [bs, 20, 512]
        return self.decoder(target_embedds, memory, src_mask, tgt_mask)

def make_ocr_model(tgt_vocab, N=6, d_model=512, d_ff=2048, h=8, dropout=0.1):
    """
    Build model
        tgt_vocab: size of the output dictionary
        N: number of stacked base modules in the encoder and decoder
        d_model: embedding size used in the model, default 512
        d_ff: embedding size inside the FeedForward layer, default 2048
        h: number of heads in MultiHeadAttention; d_model must be divisible by it
        dropout: dropout ratio
    """
    c = copy.deepcopy

    # the torchvision resnet18 with pretrained weights is used as the feature extraction backbone
    # (newer torchvision releases replace pretrained=True with the weights= argument)
    backbone = models.resnet18(pretrained=True)
    backbone = nn.Sequential(*list(backbone.children())[:-2])    # remove the last two layers (global average pooling and fc layer)

    attn = MultiHeadedAttention(h, d_model)
    ff = PositionwiseFeedForward(d_model, d_ff, dropout)
    position = PositionalEncoding(d_model, dropout)
    # Build model
    model = OCR_EncoderDecoder(
        Encoder(EncoderLayer(d_model, c(attn), c(ff), dropout), N),
        Decoder(DecoderLayer(d_model, c(attn), c(attn), c(ff), dropout), N),
        backbone,       # src_embed: image feature extractor
        c(position),    # src_position: positional encoding for the image sequence
        nn.Sequential(Embeddings(d_model, tgt_vocab), c(position)),
        Generator(d_model, tgt_vocab))  # the generator is not called inside the class

    # Initialize parameters with Glorot / fan_avg.
    for child in model.children():
        if child is backbone:
            # freeze the backbone: its weights do not compute gradients
            for param in child.parameters():
                param.requires_grad = False
            # the pretrained backbone keeps its weights; only the other modules are initialized randomly
            continue
        for p in child.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)
    return model

With the class and function above, the transformer model can be built easily:

# build model
# use transformer as ocr recognize model
# the forward pass of the ocr_model built here does not apply the Generator
tgt_vocab = len(lbl2id_map.keys()) 
d_model = 512
ocr_model = make_ocr_model(tgt_vocab, N=5, d_model=d_model, d_ff=2048, h=8, dropout=0.1)

4.4 model training

Before training the model, we also need to define the loss criterion, the optimizer, and so on. This experiment uses strategies such as label smoothing and learning rate warmup during training. The functions implementing these strategies are in the train_utils.py file; the principle and code implementation of these two methods are not covered in detail here.

Label smoothing converts the original hard labels into soft labels, which increases the fault tolerance of the model and improves its generalization ability. LabelSmoothing() in the code implements label smoothing, and internally uses relative entropy (KL divergence) to compute the loss between the predicted values and the true values.
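
To make the idea of a soft label concrete, here is a minimal pure-Python sketch of how a smoothed target distribution can be built for a single token. It follows the formulation used in The Annotated Transformer; that train_utils.py builds its targets the same way is an assumption, and the function name `smoothed_target` is hypothetical.

```python
def smoothed_target(target, size, padding_idx=0, smoothing=0.1):
    """Return the soft label distribution for one target token id."""
    if target == padding_idx:
        # padding positions contribute nothing to the loss
        return [0.0] * size
    # spread the smoothing mass over all classes except the target and padding
    dist = [smoothing / (size - 2)] * size
    dist[target] = 1.0 - smoothing
    dist[padding_idx] = 0.0
    return dist


hard = smoothed_target(3, size=6, smoothing=0.0)  # one-hot: no smoothing
soft = smoothed_target(3, size=6, smoothing=0.2)  # most mass on class 3, rest shared
print(hard)
print(soft)
```

With smoothing set to 0.0 the target stays one-hot, so the criterion reduces to an ordinary hard-label loss.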

The warmup strategy controls the optimizer's learning rate during training: the learning rate first increases from a small value and then gradually decreases, which makes training more stable and helps the loss converge quickly. NoamOpt() in the code implements the warmup control; it wraps the Adam optimizer and adjusts the learning rate according to the number of iterations.

# train prepare
criterion = LabelSmoothing(size=tgt_vocab, padding_idx=0, smoothing=0.0)  # label smoothing (a value of 0.0 leaves the labels hard)
optimizer = torch.optim.Adam(ocr_model.parameters(), lr=0, betas=(0.9, 0.98), eps=1e-9)
model_opt = NoamOpt(d_model, 1, 400, optimizer)  # warmup
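
As a rough illustration of what the warmup does, the sketch below computes the learning rate schedule from The Annotated Transformer in plain Python. It assumes that the NoamOpt arguments above map to model_size=512, factor=1 and warmup=400, as they do in that reference implementation; the function name `noam_rate` is hypothetical.

```python
def noam_rate(step, model_size=512, factor=1.0, warmup=400):
    """Learning rate at a given optimizer step (step >= 1)."""
    return factor * model_size ** -0.5 * min(step ** -0.5, step * warmup ** -1.5)


# the rate grows linearly until step == warmup, then decays as 1 / sqrt(step)
for step in (1, 100, 400, 1600):
    print(step, noam_rate(step))
```

Under these assumptions the rate peaks at step 400 and has halved again by step 1600, which is why Adam can be created with lr=0: the wrapper overwrites the learning rate at every step.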

The code of the model training process is as follows. Validation runs once every 10 epochs, and the computation of a single epoch is encapsulated in the run_epoch() function.

# train & valid ...
for epoch in range(nrof_epochs):
    print(f"\nepoch {epoch}")
    print("train...")  # train
    ocr_model.train()  # enable dropout for training
    loss_compute = SimpleLossCompute(ocr_model.generator, criterion, model_opt)
    train_mean_loss = run_epoch(train_loader, ocr_model, loss_compute, device)

    if epoch % 10 == 0:
        print("valid...")  # validation
        ocr_model.eval()  # disable dropout for validation
        valid_loss_compute = SimpleLossCompute(ocr_model.generator, criterion, None)
        valid_mean_loss = run_epoch(valid_loader, ocr_model, valid_loss_compute, device)
        print(f"valid loss: {valid_mean_loss}")

        # save model
        torch.save(ocr_model.state_dict(), './trained_model/ocr_model.pt')

The SimpleLossCompute() class implements the loss calculation for the output of the transformer. When an instance is called, it takes three arguments (x, y, norm): x is the output of the decoder, y is the label data, and norm is the normalization coefficient of the loss, for which the number of valid tokens in the batch is used. Note that the Generator is only applied here, which completes the construction of the whole transformer network and closes the data computation flow.

The run_epoch() function does all the work of one training epoch, including data loading, the model forward pass, loss calculation and backpropagation, and prints information about the training progress.

def run_epoch(data_loader, model, loss_compute, device=None):
    "Standard Training and Logging Function"
    start = time.time()
    total_tokens = 0
    total_loss = 0
    tokens = 0
    for i, batch in enumerate(data_loader):
        img_input, encode_mask, decode_in, decode_out, decode_mask, ntokens = batch
        img_input = img_input.to(device)
        encode_mask = encode_mask.to(device)
        decode_in = decode_in.to(device)
        decode_out = decode_out.to(device)
        decode_mask = decode_mask.to(device)
        ntokens = torch.sum(ntokens).to(device)

        out = model.forward(img_input, decode_in, encode_mask, decode_mask)
        # out --> [bs, 20, 512]  prediction results
        # decode_out --> [bs, 20]  ground-truth labels
        # ntokens --> number of valid characters in the labels

        loss = loss_compute(out, decode_out, ntokens)  # Loss calculation
        total_loss += loss
        total_tokens += ntokens
        tokens += ntokens
        if i % 50 == 1:
            elapsed = time.time() - start
            print("Epoch Step: %d Loss: %f Tokens per Sec: %f" %
                    (i, loss / ntokens, tokens / elapsed))
            start = time.time()
            tokens = 0
    return total_loss / total_tokens

class SimpleLossCompute:
    "A simple loss compute and train function."
    def __init__(self, generator, criterion, opt=None):
        self.generator = generator
        self.criterion = criterion
        self.opt = opt
    def __call__(self, x, y, norm):
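        """
        norm: normalization coefficient of the loss; the number of valid tokens in the batch is used
        """
        # x --> out --> [bs, 20, 512]  decoder output
        # y --> decode_out --> [bs, 20]  ground-truth labels
        # norm --> ntokens --> number of valid characters in the labels
        x = self.generator(x)  # [bs, 20, tgt_vocab]
        # flatten the batch and time dimensions to match the label smoothing criterion
        x_ = x.contiguous().view(-1, x.size(-1))  # [20*bs, tgt_vocab]
        y_ = y.contiguous().view(-1)  # [20*bs]
        loss = self.criterion(x_, y_)
        loss /= norm
        if self.opt is not None:
            # training only: backpropagate and update the weights
            # (assumes NoamOpt exposes step() and .optimizer, as in The Annotated Transformer)
            loss.backward()
            self.opt.step()
            self.opt.optimizer.zero_grad()
        #return loss.data[0] * norm 
        return loss.item() * norm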
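
The interplay between norm in SimpleLossCompute and the running totals in run_epoch can be hard to see at first glance. The following pure-Python sketch mimics the bookkeeping with made-up numbers; it assumes the criterion returns a loss summed over all tokens in the batch, and the function name `epoch_mean_loss` is hypothetical.

```python
def epoch_mean_loss(batches):
    """batches: iterable of (summed_batch_loss, n_valid_tokens) pairs."""
    total_loss = 0.0
    total_tokens = 0
    for summed_loss, ntokens in batches:
        per_token = summed_loss / ntokens   # the value that is backpropagated
        total_loss += per_token * ntokens   # un-normalize before accumulating
        total_tokens += ntokens
    return total_loss / total_tokens


# made-up (loss, token count) pairs for three batches
batches = [(30.0, 10), (10.0, 10), (20.0, 80)]
print(epoch_mean_loss(batches))
```

Dividing by the token count keeps the gradient scale independent of how many valid characters a batch happens to contain, while multiplying it back before accumulating makes the epoch average a true per-token mean rather than a mean of per-batch means.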

4.5 greedy decoding

For convenience, we use the simplest approach, greedy decoding, to predict the OCR results directly. Because the model produces only one output at a time, we select the character with the highest probability in the output distribution as the result of the current prediction, and then predict the next character. This is what is called greedy decoding. See the greedy_decode() function in the code.
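
The loop itself is easy to illustrate without a real model. In the sketch below, `toy_next_token_probs` is a hypothetical stand-in for the decoder plus generator: it returns a probability list over 8 token ids for a given prefix. The start and end symbol ids (1 and 2) match the values passed to greedy_decode() in this article; everything else is made up for illustration.

```python
START, END = 1, 2


def toy_next_token_probs(prefix):
    """Hypothetical stand-in for the decoder: spells tokens 5, 7, 6 and then emits END."""
    script = {1: 5, 2: 7, 3: 6}
    best = script.get(len(prefix), END)
    probs = [0.02] * 8
    probs[best] = 0.86
    return probs


def greedy_decode_toy(next_token_probs, max_len, start_symbol=START, end_symbol=END):
    ys = [start_symbol]
    for _ in range(max_len - 1):
        probs = next_token_probs(ys)
        next_token = max(range(len(probs)), key=probs.__getitem__)  # argmax
        ys.append(next_token)
        if next_token == end_symbol:
            break
    return ys[1:]  # drop the start symbol


print(greedy_decode_toy(toy_next_token_probs, max_len=21))  # [5, 7, 6, 2]
```

Greedy decoding commits to the single best token at every step, so it is cheap but cannot recover from an early mistake; beam search is the usual alternative when that matters.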

In the experiment, each image is fed to the model individually and decoded greedily, and the prediction accuracy on the training set and the validation set is reported at the end.

# After training, greedy decoding is used to run inference on the training and validation sets, and the accuracy is computed
ocr_model.eval()  # switch off dropout for inference

print("greedy decode trainset")
total_img_num = 0
total_correct_num = 0
for batch_idx, batch in enumerate(train_loader):
    img_input, encode_mask, decode_in, decode_out, decode_mask, ntokens = batch
    img_input = img_input.to(device)
    encode_mask = encode_mask.to(device)
    # Get single image information
    bs = img_input.shape[0]
    for i in range(bs):
        cur_img_input = img_input[i].unsqueeze(0)
        cur_encode_mask = encode_mask[i].unsqueeze(0)
        cur_decode_out = decode_out[i]
        # Greedy decoding
        pred_result = greedy_decode(ocr_model, cur_img_input, cur_encode_mask, max_len=sequence_len, start_symbol=1, end_symbol=2)
        pred_result = pred_result.cpu()
        # Judge whether the prediction is correct
        is_correct = judge_is_correct(pred_result, cur_decode_out)
        total_correct_num += is_correct
        total_img_num += 1
        if not is_correct:
            # Print the label and the prediction for wrong cases
            print(cur_decode_out)
            print(pred_result)
total_correct_rate = total_correct_num / total_img_num * 100
print(f"total correct rate of trainset: {total_correct_rate}%")

# Same as training set decoding code
print("greedy decode validset")
total_img_num = 0
total_correct_num = 0
for batch_idx, batch in enumerate(valid_loader):
    img_input, encode_mask, decode_in, decode_out, decode_mask, ntokens = batch
    img_input = img_input.to(device)
    encode_mask = encode_mask.to(device)

    bs = img_input.shape[0]
    for i in range(bs):
        cur_img_input = img_input[i].unsqueeze(0)
        cur_encode_mask = encode_mask[i].unsqueeze(0)
        cur_decode_out = decode_out[i]

        pred_result = greedy_decode(ocr_model, cur_img_input, cur_encode_mask, max_len=sequence_len, start_symbol=1, end_symbol=2)
        pred_result = pred_result.cpu()

        is_correct = judge_is_correct(pred_result, cur_decode_out)
        total_correct_num += is_correct
        total_img_num += 1
        if not is_correct:
            # Print the label and the prediction for wrong cases
            print(cur_decode_out)
            print(pred_result)
total_correct_rate = total_correct_num / total_img_num * 100
print(f"total correct rate of validset: {total_correct_rate}%")

The greedy_decode() function is implemented as follows.

# greedy decode
def greedy_decode(model, src, src_mask, max_len, start_symbol, end_symbol):
    memory = model.encode(src, src_mask)
    # ys holds the sequence generated so far; it starts with only the start symbol, and each prediction is appended to its end
    ys = torch.ones(1, 1).fill_(start_symbol).type_as(src.data).long()
    for i in range(max_len-1):
        # the last two arguments are cut off in the source; ys and a causal mask are assumed here,
        # with subsequent_mask taken to be the helper from transformer.py (as in The Annotated Transformer)
        out = model.decode(memory, src_mask, ys,
                           subsequent_mask(ys.size(1)).type_as(src.data))
        prob = model.generator(out[:, -1])
        _, next_word = torch.max(prob, dim = 1)
        next_word = next_word.data[0]
        next_word = torch.ones(1, 1).type_as(src.data).fill_(next_word).long()
        ys = torch.cat([ys, next_word], dim=1)

        next_word = int(next_word)
        if next_word == end_symbol:
            # stop as soon as the end symbol is predicted
            break
        #ys = torch.cat([ys, torch.ones(1, 1).type_as(src.data).fill_(next_word)], dim=1)
    ys = ys[0, 1:]
    return ys

def judge_is_correct(pred, label):
    # Judge whether the predicted results of the model are consistent with the label
    pred_len = pred.shape[0]
    label = label[:pred_len]
    is_correct = 1 if label.equal(pred) else 0
    return is_correct
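
The comparison rule in judge_is_correct() is worth a closer look: the label is cut to the length of the prediction before comparing, so trailing padding in the label is ignored, but a prediction that stops too early or never emits the end symbol still counts as wrong. The list-based sketch below mirrors that logic without tensors; `judge_is_correct_lists` and the token ids are made up for illustration.

```python
def judge_is_correct_lists(pred, label):
    """List-based mirror of judge_is_correct: compare pred with the label prefix of equal length."""
    label = label[:len(pred)]
    return 1 if label == pred else 0


label = [17, 32, 41, 2, 0, 0]  # three characters, end symbol 2, then padding 0
print(judge_is_correct_lists([17, 32, 41, 2], label))            # 1: exact match up to the end symbol
print(judge_is_correct_lists([17, 32, 2], label))                # 0: stopped one character early
print(judge_is_correct_lists([17, 32, 41, 55, 55, 55], label))   # 0: never emitted the end symbol
```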

Run the following command to start the training with one click:

python ocr_by_transformer.py

The training log is as follows:

epoch 0
Epoch Step: 1 Loss: 5.142612 Tokens per Sec: 852.649109
Epoch Step: 51 Loss: 3.064528 Tokens per Sec: 2709.471436
Epoch Step: 1 Loss: 3.018526 Tokens per Sec: 1413.900391
valid loss: 2.7769546508789062

epoch 1
Epoch Step: 1 Loss: 3.440590 Tokens per Sec: 1303.567993
Epoch Step: 51 Loss: 2.711708 Tokens per Sec: 2743.414307


epoch 1499
Epoch Step: 1 Loss: 0.005739 Tokens per Sec: 1232.602783
Epoch Step: 51 Loss: 0.013249 Tokens per Sec: 2765.866211

greedy decode trainset
tensor([17, 32, 18, 19, 31, 50, 30, 10, 30, 10, 17, 32, 41, 55, 55, 55,  2,  0,
         0,  0])
tensor([17, 32, 18, 19, 31, 50, 30, 10, 30, 10, 17, 32, 41, 55, 55, 55, 55, 55,
        55, 55])
tensor([17, 32, 18, 19, 31, 50, 30, 10, 17, 32, 41, 55, 55,  2,  0,  0,  0,  0,
         0,  0])
tensor([17, 32, 18, 19, 31, 50, 30, 10, 17, 32, 41, 55, 55, 55, 55,  2])
total correct rate of trainset: 99.95376791493297%

greedy decode validset
tensor([10, 11, 28, 27, 25, 11, 47, 45,  2,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0,  0])
tensor([10, 11, 28, 27, 25, 11, 62,  2])


tensor([20, 12, 24, 35,  2,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0,  0])
tensor([20, 12, 21, 12, 22, 23, 34,  2])
total correct rate of validset: 92.72088353413655%


The above is the whole content of this article. Congratulations on reading here!

In this article, we first introduced a word recognition dataset from ICDAR2015, then briefly analyzed the characteristics of the data and built the character mapping table used for recognition. After that, we focused on the motivation and ideas behind introducing the transformer to solve OCR tasks, and walked through the details together with the code. Finally, we briefly went over the training-related logic and code.

This article is mainly intended to broaden your thinking and show other ways of applying the transformer in CV besides using it as a backbone. For the implementation of the Transformer model itself, refer to The Annotated Transformer. The way it is applied to OCR here is based entirely on the author's personal understanding, and is not guaranteed to carry over to more complex engineering problems. If you have questions about any detail in the article, please contact us to discuss. If you find errors, please point them out.

I hope you can gain something after reading!

Posted on Wed, 01 Dec 2021 22:14:28 -0500 by AlGale