End-to-end variable-length text recognition: CRNN code implementation

CRNN is a classic and widely used recognition algorithm in OCR. For its theoretical background, see my previous article; this post focuses on the CRNN implementation and its recognition results.

Data processing

Using image-processing techniques, we synthetically generated a large number of text images, 3.6 million samples in total. Examples are shown in the figure below.

We split the samples into a training set and a test set at a 10:1 ratio and stored each split in its own text file.

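Each line of these text files holds an image path followed by its text label, separated by whitespace; this is the format the conversion script below parses.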
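A minimal sketch of reading one such line, assuming exactly this "<image path> <label>" layout (the function name and the sample path are hypothetical). It splits only at the first whitespace, so a label that itself contains spaces stays whole; the conversion script's `line.split()[1]` would keep only the label's first word.

```python
def parse_label_line(line):
    """Split one '<image path> <label>' line into (path, label).

    maxsplit=1 splits only at the first whitespace run, so spaces
    inside the label itself are preserved.
    """
    path, label = line.rstrip('\r\n').split(maxsplit=1)
    return path, label


# hypothetical line from a label file
print(parse_label_line('/images/000001.jpg hello world\n'))
# ('/images/000001.jpg', 'hello world')
```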

This gives us the raw dataset. When training deep-learning models on images, the raw dataset is usually converted into LMDB format so that data can be loaded efficiently during training, so we convert ours as well. The conversion code follows. The idea is simple: read each image and its text label, collect the key/value pairs in a dictionary (cache), and write them into the LMDB database with the put method of an lmdb write transaction, flushing the cache every 1000 samples.


import lmdb
import cv2
import numpy as np
import os


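def checkImageIsValid(imageBin):
    """Return True if imageBin decodes to a non-empty image."""
    if imageBin is None:
        return False
    try:
        # np.fromstring is deprecated for binary data; np.frombuffer replaces it
        imageBuf = np.frombuffer(imageBin, dtype=np.uint8)
        img = cv2.imdecode(imageBuf, cv2.IMREAD_GRAYSCALE)
        # imdecode returns None for undecodable data, so .shape raises here
        imgH, imgW = img.shape[0], img.shape[1]
    except Exception:
        return False
    return imgH * imgW != 0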


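def writeCache(env, cache):
    # LMDB keys and values must be bytes in Python 3
    with env.begin(write=True) as txn:
        for k, v in cache.items():
            if isinstance(k, str):
                k = k.encode('utf-8')
            if isinstance(v, str):
                v = v.encode('utf-8')
            txn.put(k, v)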


def createDataset(outputPath, imagePathList, labelList, lexiconList=None, checkValid=True):
    """
    Create LMDB dataset for CRNN training.
    ARGS:
        outputPath    : LMDB output path
        imagePathList : list of image path
        labelList     : list of corresponding groundtruth texts
        lexiconList   : (optional) list of lexicon lists
        checkValid    : if true, check the validity of every image
    """
    assert (len(imagePathList) == len(labelList))
    nSamples = len(imagePathList)
    env = lmdb.open(outputPath, map_size=1099511627776)
    cache = {}
    cnt = 1
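    for i in range(nSamples):
        # each line of the list file is "<image path> <label>"; keep only the path
        imagePath = imagePathList[i].split()[0]
        label = labelList[i]
        # prepend '.' as in the original setup, so the listed paths resolve
        # relative to the working directory
        fullPath = '.' + imagePath
        if not os.path.exists(fullPath):
            print('%s does not exist' % fullPath)
            continue

        # read the encoded image as raw bytes ('rb'), not as text
        with open(fullPath, 'rb') as f:
            imageBin = f.read()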

        if checkValid:
            if not checkImageIsValid(imageBin):
                print('%s is not a valid image' % imagePath)
                continue
        imageKey = 'image-%09d' % cnt
        labelKey = 'label-%09d' % cnt
        cache[imageKey] = imageBin
        cache[labelKey] = label
        if lexiconList:
            lexiconKey = 'lexicon-%09d' % cnt
            cache[lexiconKey] = ' '.join(lexiconList[i])
        if cnt % 1000 == 0:
            writeCache(env, cache)
            cache = {}
            print('Written %d / %d' % (cnt, nSamples))
        cnt += 1
    nSamples = cnt - 1
    cache['num-samples'] = str(nSamples)
    writeCache(env, cache)
    print('Created dataset with %d samples' % nSamples)


OUT_PATH = '../crnn_train_lmdb'
IN_PATH = './train.txt'

if __name__ == '__main__':
    outputPath = OUT_PATH
    os.makedirs(OUT_PATH, exist_ok=True)
    with open(IN_PATH, encoding='utf-8') as imgdata:
        imagePathList = list(imgdata)

    # the label is the second whitespace-separated field of each line
    labelList = [line.split()[1] for line in imagePathList]
    createDataset(outputPath, imagePathList, labelList)
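The key layout and the flush-every-1000 logic of createDataset can be exercised without LMDB. A minimal sketch, assuming (hypothetically) that samples arrive as (image bytes, label) pairs and that a callback stands in for writeCache; it skips the image-validity check and the lexicon keys. Keys and values are built as bytes, as the Python lmdb binding requires in Python 3.

```python
def build_records(samples, flush, batch_size=1000):
    """Mimic createDataset's key scheme: image-%09d / label-%09d (1-based),
    flushed every batch_size samples, with 'num-samples' written last.

    samples: iterable of (image_bytes, label_str) pairs
    flush:   callable receiving a dict of bytes -> bytes (stand-in for writeCache)
    Returns the number of samples written.
    """
    cache = {}
    cnt = 0
    for cnt, (image_bin, label) in enumerate(samples, start=1):
        cache[b'image-%09d' % cnt] = image_bin
        cache[b'label-%09d' % cnt] = label.encode('utf-8')
        if cnt % batch_size == 0:
            flush(cache)
            cache = {}
    cache[b'num-samples'] = str(cnt).encode('utf-8')
    flush(cache)  # remaining samples plus the sample count
    return cnt


batches = []
n = build_records([(b'fake-image-bytes', 'abc')] * 2500, batches.append)
print(n, len(batches))                  # 2500 3
print(batches[-1][b'num-samples'])      # b'2500'
```

Because each flush is a separate write transaction, memory use stays bounded by one batch regardless of the dataset size.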


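Running the script above on the training list and on the test list (changing IN_PATH and OUT_PATH accordingly) produces the LMDB databases for the training set and the test set.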

Another data-preparation step worth emphasizing is digitizing the text labels: each character (Chinese character, English letter, punctuation mark) is represented by an integer id. For example, one character might get id 1, another id 1000, a punctuation mark id 90, and so on; id 0 is reserved for the CTC blank. This encoding and decoding can be stored in a dictionary: during training the labels are encoded into ids, and during prediction the network output is decoded back into text.


import torch


class strLabelConverter(object):
    """Convert between str and label.

    NOTE:
        Insert `blank` to the alphabet for CTC.

    Args:
        alphabet (str): set of the possible characters.
        ignore_case (bool, default=False): whether or not to ignore all of the case.
    """

    def __init__(self, alphabet, ignore_case=False):
        self._ignore_case = ignore_case
        if self._ignore_case:
            alphabet = alphabet.lower()
        self.alphabet = alphabet + '-'  # for `-1` index

        self.dict = {}
        for i, char in enumerate(alphabet):
            # NOTE: 0 is reserved for 'blank' required by wrap_ctc
            self.dict[char] = i + 1

    def encode(self, text):
        """Support batch or single str.

        Args:
            text (str or list of str): texts to convert.

        Returns:
            torch.IntTensor [length_0 + length_1 + ... length_{n - 1}]: encoded texts.
            torch.IntTensor [n]: length of each text.
        """

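        if isinstance(text, (str, bytes)):
            text = [text]  # treat a single string as a batch of one
        length = []
        result = []
        for item in text:
            # labels read back from LMDB are bytes in Python 3; decode them to str
            if isinstance(item, bytes):
                item = item.decode('utf-8', 'strict')
            if self._ignore_case:
                item = item.lower()

            length.append(len(item))
            for char in item:
                index = self.dict[char]  # 1-based id; 0 is the CTC blank
                result.append(index)

        return (torch.IntTensor(result), torch.IntTensor(length))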

    def decode(self, t, length, raw=False):
        """Decode encoded texts back into strs.

        Args:
            torch.IntTensor [length_0 + length_1 + ... length_{n - 1}]: encoded texts.
            torch.IntTensor [n]: length of each text.

        Raises:
            AssertionError: when the texts and its length does not match.

        Returns:
            text (str or list of str): texts to convert.
        """
        if length.numel() == 1:
            length = length[0]
            assert t.numel() == length, "text with length: {} does not match declared length: {}".format(t.numel(),
                                                                                                         length)
            if raw:
                return ''.join([self.alphabet[i - 1] for i in t])
            else:
                char_list = []
                for i in range(length):
                    if t[i] != 0 and (not (i > 0 and t[i - 1] == t[i])):
                        char_list.append(self.alphabet[t[i] - 1])
                return ''.join(char_list)
        else:
            # batch mode
            assert t.numel() == length.sum(), "texts with length: {} does not match declared length: {}".format(
                t.numel(), length.sum())
            texts = []
            index = 0
            for i in range(length.numel()):
                l = length[i]
                texts.append(
                    self.decode(
                        t[index:index + l], torch.IntTensor([l]), raw=raw))
                index += l
            return texts
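The raw=False branch of decode is CTC best-path (greedy) decoding: take the most likely class at each frame (done with preds.max(2) in the inference code), merge consecutive repeats, then drop blanks (id 0). A framework-free sketch of the same collapse rule, with a hypothetical three-letter alphabet; note that a blank between two identical ids keeps both characters.

```python
def ctc_greedy_collapse(ids, alphabet, blank=0):
    """Collapse a per-frame id sequence like strLabelConverter.decode(raw=False):
    keep an id if it is not blank and differs from the previous frame's id.
    Ids are 1-based into `alphabet` because 0 is reserved for the blank."""
    chars = []
    prev = None
    for i in ids:
        if i != blank and i != prev:
            chars.append(alphabet[i - 1])
        prev = i
    return ''.join(chars)


# frames: a a - a b b  ->  "aab" (the blank separates the two a's)
print(ctc_greedy_collapse([1, 1, 0, 1, 2, 2], 'abc'))  # aab
```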

Network design

According to the CRNN paper, CRNN consists of three parts, CNN → RNN → CTC, corresponding to the convolutional layers, the recurrent layers and the transcription layer. The CNN extracts a feature sequence from the input image; the RNN, a bidirectional LSTM (BiLSTM), models the context within that sequence and predicts a label distribution for each frame; and CTC aligns the per-frame predictions with the label sequence and produces the final output.

To feed the CNN features into the recurrent layers:

  • First, the image is scaled to 32 × W × 3 (height × width × channels).
  • After the CNN, it becomes 1 × (W/4) × 512.
  • The LSTM then takes T = W/4 time steps, each a D = 512-dimensional feature vector.

This is the idealized setting; the network input described in the CRNN paper is a normalized 100 × 32 grayscale image, i.e. the height is fixed at 32 pixels. The CNN is a VGG-style network (called Vgg_16 in the code below, although it has seven convolutional layers). Notably, CRNN uses 1 × 2 (w × h) rectangular pooling windows in the 3rd and 4th max-pooling layers instead of VGG's 2 × 2 square windows. Text images are usually short and wide, so this halves the height while keeping the width, and therefore the feature sequence, long; according to the paper, the resulting rectangular receptive fields also help with narrow characters such as 'i' and 'l'. Batch normalization is added after the 5th and 6th convolutions to speed up convergence. The input of CRNN is a grayscale image, so the input depth is 1. The CNN outputs a feature map of size 512 × 1 × w (c × h × w); with the layers below, w = W/4 - 3 for an input width W divisible by 4, e.g. 22 for a 100-pixel-wide image.
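The sequence length can be checked by tracing the spatial size through the Vgg_16 class defined later in this post. A minimal pure-Python sketch, using the output-size formula from the PyTorch Conv2d/MaxPool2d documentation (dilation 1, floor mode); the helper names are hypothetical. Only the pooling layers and the final 2 × 2 convolution change the size, because every 3 × 3 convolution uses padding=1.

```python
def out_len(size, kernel, stride=1, padding=0):
    """Output length along one axis for Conv2d / MaxPool2d with dilation 1."""
    return (size + 2 * padding - kernel) // stride + 1


def vgg16_output_hw(h, w):
    """Trace (height, width) through the Vgg_16 module of this post."""
    h, w = out_len(h, 2, 2), out_len(w, 2, 2)  # pooling1: 2x2, stride 2
    h, w = out_len(h, 2, 2), out_len(w, 2, 2)  # pooling2: 2x2, stride 2
    h, w = out_len(h, 1, 2), out_len(w, 2, 1)  # pooling3: kernel (1, 2), stride (2, 1)
    h, w = out_len(h, 1, 2), out_len(w, 2, 1)  # pooling4: same as pooling3
    h, w = out_len(h, 2), out_len(w, 2)        # convolution7: 2x2 kernel, no padding
    return h, w


print(vgg16_output_hw(32, 100))  # (1, 22)
print(vgg16_output_hw(32, 76))   # (1, 16)
```

The height always reaches 1 for a 32-pixel input, which is what the `assert h == 1` in CRNN.forward relies on.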

Next, the RNN part. It consists of two stacked bidirectional LSTM layers, each with 256 hidden units. The output of the RNN part has shape (s, b, class_num), where s is the sequence length, b the batch size, and class_num the total number of character classes (including the CTC blank).

Note that torch.nn.LSTM in PyTorch expects a three-dimensional input tensor, and the meaning of each dimension must not be mixed up: by default the first dimension is the sequence, the second is the mini-batch, and the third is the input features. If you do not use mini-batches, the batch dimension can be set to 1, but it must still be present.

Input of the LSTM

input of shape (seq_len, batch, input_size): tensor containing the features of the input sequence.
The input can also be a packed variable length sequence.
input shape (a, b, c):
a: seq_len     -> sequence length
b: batch       -> batch size
c: input_size  -> number of input features

To meet the LSTM's input requirements, we reshape the CNN output into [seq_len, batch, input_size]: first squeeze removes the h dimension (which must be 1), giving [b, c, w]; then permute(2, 0, 1) reorders the dimensions into [w, b, c] = [seq_len, batch, input_size]. The result has size [w, batch, 512] (e.g. [22, batch, 512] for a 100-pixel-wide input with the network below) and can be fed to the RNN layer.


x = self.cnn(x)
b, c, h, w = x.size()
# print(x.size()): b,c,h,w
assert h == 1   # "the height of conv must be 1"
x = x.squeeze(2)  # remove h dimension, b *512 * width
x = x.permute(2, 0, 1)  # [w, b, c] = [seq_len, batch, input_size]
x = self.rnn(x)
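The squeeze/permute step can be traced on a dummy array. This sketch uses NumPy rather than PyTorch (ndarray.squeeze and ndarray.transpose with an explicit axis order behave like Tensor.squeeze and Tensor.permute here); the sizes are illustrative.

```python
import numpy as np

# illustrative CNN output: batch 4, 512 channels, height 1, width 22
b, c, h, w = 4, 512, 1, 22
feat = np.random.rand(b, c, h, w).astype(np.float32)

seq = feat.squeeze(2)         # [b, c, w]: drop the height axis (it must be 1)
seq = seq.transpose(2, 0, 1)  # [w, b, c] = [seq_len, batch, input_size]

print(seq.shape)  # (22, 4, 512)
# time step t of sample j is the 512-dim column of feature map j at x-position t
print(np.array_equal(seq[5, 2], feat[2, :, 0, 5]))  # True
```

Each column of the feature map thus becomes one LSTM time step, read left to right across the image.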

The output of the RNN layer has the following format. Because the LSTM is bidirectional, the output feature dimension is hidden_unit * 2.

Outputs: output, (h_n, c_n)
output of shape (seq_len, batch, num_directions * hidden_size)
h_n of shape (num_layers * num_directions, batch, hidden_size)
c_n of shape (num_layers * num_directions, batch, hidden_size)

We then apply a linear layer, self.embedding1 = torch.nn.Linear(hidden_unit * 2, 512), to project the output back to 512 dimensions and feed it into the second LSTM layer. After the second LSTM layer, another linear layer, torch.nn.Linear(hidden_unit * 2, class_num), maps each time step to class_num scores, so the output of the whole RNN part has one score per text class.

import torch
import torch.nn.functional as F


class Vgg_16(torch.nn.Module):

    def __init__(self):
        super(Vgg_16, self).__init__()
        self.convolution1 = torch.nn.Conv2d(1, 64, 3, padding=1)
        self.pooling1 = torch.nn.MaxPool2d(2, stride=2)
        self.convolution2 = torch.nn.Conv2d(64, 128, 3, padding=1)
        self.pooling2 = torch.nn.MaxPool2d(2, stride=2)
        self.convolution3 = torch.nn.Conv2d(128, 256, 3, padding=1)
        self.convolution4 = torch.nn.Conv2d(256, 256, 3, padding=1)
        self.pooling3 = torch.nn.MaxPool2d((1, 2), stride=(2, 1))  # kernel (h=1, w=2), stride (h=2, w=1): halves the height, width shrinks by only 1
        self.convolution5 = torch.nn.Conv2d(256, 512, 3, padding=1)
        self.BatchNorm1 = torch.nn.BatchNorm2d(512)
        self.convolution6 = torch.nn.Conv2d(512, 512, 3, padding=1)
        self.BatchNorm2 = torch.nn.BatchNorm2d(512)
        self.pooling4 = torch.nn.MaxPool2d((1, 2), stride=(2, 1))
        self.convolution7 = torch.nn.Conv2d(512, 512, 2)

    def forward(self, x):
        x = F.relu(self.convolution1(x), inplace=True)
        x = self.pooling1(x)
        x = F.relu(self.convolution2(x), inplace=True)
        x = self.pooling2(x)
        x = F.relu(self.convolution3(x), inplace=True)
        x = F.relu(self.convolution4(x), inplace=True)
        x = self.pooling3(x)
        x = self.convolution5(x)
        x = F.relu(self.BatchNorm1(x), inplace=True)
        x = self.convolution6(x)
        x = F.relu(self.BatchNorm2(x), inplace=True)
        x = self.pooling4(x)
        x = F.relu(self.convolution7(x), inplace=True)
        return x  # [b, 512, 1, W/4 - 3], e.g. [b, 512, 1, 22] for a 32 x 100 input


class RNN(torch.nn.Module):
    def __init__(self, class_num, hidden_unit):
        super(RNN, self).__init__()
        self.Bidirectional_LSTM1 = torch.nn.LSTM(512, hidden_unit, bidirectional=True)
        self.embedding1 = torch.nn.Linear(hidden_unit * 2, 512)
        self.Bidirectional_LSTM2 = torch.nn.LSTM(512, hidden_unit, bidirectional=True)
        self.embedding2 = torch.nn.Linear(hidden_unit * 2, class_num)

    def forward(self, x):
        x = self.Bidirectional_LSTM1(x)   # LSTM output: output, (h_n, c_n)
        T, b, h = x[0].size()   # x[0]: (seq_len, batch, num_directions * hidden_size)
        x = self.embedding1(x[0].view(T * b, h))  # pytorch view() reshape as [T * b, nOut]
        x = x.view(T, b, -1)  # [T, b, 512]
        x = self.Bidirectional_LSTM2(x)
        T, b, h = x[0].size()
        x = self.embedding2(x[0].view(T * b, h))
        x = x.view(T, b, -1)
        return x  # [T, b, class_num]


# output: [s,b,class_num]
class CRNN(torch.nn.Module):
    def __init__(self, class_num, hidden_unit=256):
        super(CRNN, self).__init__()
        self.cnn = torch.nn.Sequential()
        self.cnn.add_module('vgg_16', Vgg_16())
        self.rnn = torch.nn.Sequential()
        self.rnn.add_module('rnn', RNN(class_num, hidden_unit))

    def forward(self, x):
        x = self.cnn(x)
        b, c, h, w = x.size()
        # print(x.size()): b,c,h,w
        assert h == 1   # "the height of conv must be 1"
        x = x.squeeze(2)  # remove h dimension, b *512 * width
        x = x.permute(2, 0, 1)  # [w, b, c] = [seq_len, batch, input_size]
        # x = x.transpose(0, 2)
        # x = x.transpose(1, 2)
        x = self.rnn(x)
        return x
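RNN.forward flattens time and batch with view(T * b, h) before each linear layer and restores the shape afterwards. A small NumPy sketch (illustrative sizes, random weights) confirming that this equals applying the linear layer independently at every time step of every sample:

```python
import numpy as np

rng = np.random.default_rng(0)
T, b, h, out = 22, 4, 512, 512        # illustrative sizes (h = hidden_unit * 2)
x = rng.standard_normal((T, b, h))    # stands in for the BiLSTM output x[0]
W = rng.standard_normal((out, h))     # Linear weight: (out_features, in_features)
bias = rng.standard_normal(out)

# what RNN.forward does: flatten time and batch, apply the affine map, restore
y = (x.reshape(T * b, h) @ W.T + bias).reshape(T, b, -1)

# the same map applied separately at every (time step, sample) position
y_loop = np.stack([np.stack([W @ x[t, j] + bias for j in range(b)]) for t in range(T)])

print(y.shape, np.allclose(y, y_loop))  # (22, 4, 512) True
```

In current PyTorch, torch.nn.Linear already accepts inputs of shape (*, in_features), so the reshape is a convenience rather than a requirement.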

Loss function design

With the CNN and RNN layers done, we now design the transcription layer, which translates the RNN output into the final recognition result and is what makes variable-length text recognition possible. PyTorch versions before 1.0 had no built-in CTC loss, so this implementation uses the warp-ctc binding from GitHub (since PyTorch 1.0, torch.nn.CTCLoss is built in and can be used instead). Install warp-ctc as follows:

git clone https://github.com/SeanNaren/warp-ctc.git
cd warp-ctc
mkdir build; cd build
cmake ..
make
cd ../pytorch_binding/
python setup.py install
cd ../build
cp libwarpctc.so /usr/lib

After installation, CTCLoss can be called directly; the small example below illustrates its usage.

import torch
from warpctc_pytorch import CTCLoss
ctc_loss = CTCLoss()
# expected shape of seqLength x batchSize x alphabet_size
probs = torch.FloatTensor([[[0.1, 0.6, 0.1, 0.1, 0.1], [0.1, 0.1, 0.6, 0.1, 0.1]]]).transpose(0, 1).contiguous()
labels = torch.IntTensor([1, 2])
label_sizes = torch.IntTensor([2])
probs_sizes = torch.IntTensor([2])
probs.requires_grad_(True)  # tells autograd to compute gradients for probs
cost = ctc_loss(probs, labels, probs_sizes, label_sizes)
cost.backward()

Signatures of the constructor and of forward:

CTCLoss(size_average=False, length_average=False)
    # size_average (bool): normalize the loss by the batch size (default: False)
    # length_average (bool): normalize the loss by the total number of frames in the batch. If True, supersedes size_average (default: False)

forward(acts, labels, act_lens, label_lens)
    # acts: Tensor of (seqLength x batch x outputDim) containing output activations from network (before softmax)
    # labels: 1 dimensional Tensor containing all the targets of the batch in one large sequence
    # act_lens: Tensor of size (batch) containing size of each output sequence from the network
    # label_lens: Tensor of (batch) containing label length of each example

As the example shows, CTCLoss takes four inputs: the network activations (before softmax), all target labels of the batch concatenated into one 1-D tensor, the length of each output sequence, and the length of each label. Following this example, we set up the CTC loss for CRNN:


preds = net(image)
preds_size = Variable(torch.IntTensor([preds.size(0)] * batch_size))  # preds.size(0) = T, the output sequence length
cost = criterion(preds, text, preds_size, length) / batch_size   # length holds the length of each text label; dividing by batch_size averages the loss
cost.backward()
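To make concrete what CTCLoss computes from these inputs, here is a small NumPy sketch of the CTC forward (alpha) recursion for a single sequence. It is a plain reference implementation for illustration, not warp-ctc's code; like warp-ctc it takes pre-softmax activations and applies the (log-)softmax itself, and it uses 0 as the blank.

```python
import numpy as np


def log_softmax(x):
    m = x.max(axis=-1, keepdims=True)
    return x - m - np.log(np.exp(x - m).sum(axis=-1, keepdims=True))


def ctc_loss(acts, labels, blank=0):
    """CTC negative log-likelihood of `labels` for one sequence.

    acts:   (T, C) pre-softmax activations for T frames and C classes
    labels: list of class ids without blanks, e.g. [1, 2]
    """
    log_probs = log_softmax(np.asarray(acts, dtype=np.float64))
    T = log_probs.shape[0]
    # extended sequence with a blank around every label: [-, l1, -, l2, ..., -]
    ext = [blank]
    for lab in labels:
        ext += [lab, blank]
    S = len(ext)

    # alpha[t, s]: log-probability of all paths emitting ext[:s + 1] in frames 0..t
    alpha = np.full((T, S), -np.inf)
    alpha[0, 0] = log_probs[0, ext[0]]
    if S > 1:
        alpha[0, 1] = log_probs[0, ext[1]]
    for t in range(1, T):
        for s in range(S):
            prev = [alpha[t - 1, s]]                   # stay on the same symbol
            if s >= 1:
                prev.append(alpha[t - 1, s - 1])       # advance by one symbol
            if s >= 2 and ext[s] != blank and ext[s] != ext[s - 2]:
                prev.append(alpha[t - 1, s - 2])       # skip a blank between different labels
            alpha[t, s] = np.logaddexp.reduce(prev) + log_probs[t, ext[s]]

    # valid paths end on the last label or on the trailing blank
    total = alpha[T - 1, S - 1]
    if S > 1:
        total = np.logaddexp(total, alpha[T - 1, S - 2])
    return -total


# the toy example above: 2 frames, 5 classes, target "1 2"
acts = np.array([[0.1, 0.6, 0.1, 0.1, 0.1],
                 [0.1, 0.1, 0.6, 0.1, 0.1]])
print(round(ctc_loss(acts, [1, 2]), 4))  # 2.4629
```

With size_average=False (warp-ctc's default), the criterion sums this quantity over the batch, which is why the training code divides the cost by batch_size.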

Network training design

Next, we flesh out the training procedure. We wrote a trainBatch function that performs one gradient update on a mini-batch.

def trainBatch(net, criterion, optimizer, train_iter):
    data = next(train_iter)  # built-in next() works for any Python 3 iterator
    cpu_images, cpu_texts = data
    batch_size = cpu_images.size(0)
    lib.dataset.loadData(image, cpu_images)
    t, l = converter.encode(cpu_texts)
    lib.dataset.loadData(text, t)
    lib.dataset.loadData(length, l)

    preds = net(image)
    #print("preds.size=%s" % preds.size)
    preds_size = Variable(torch.IntTensor([preds.size(0)] * batch_size))  # preds.size(0)=w=22
    cost = criterion(preds, text, preds_size, length) / batch_size  # length= a list that contains the len of text label in a batch
    net.zero_grad()
    cost.backward()
    optimizer.step()
    return cost

The overall training procedure is: create the CTC loss object → create the CRNN network → initialize the image, text and length tensors → initialize the optimizer; then loop over the epochs, running validation and saving a checkpoint every fixed number of iterations. The CRNN paper uses the Adadelta optimizer, but in my experiments Adadelta converged very slowly, so I switched to RMSprop, which made the model converge much faster.


    criterion = CTCLoss()

    net = Net.CRNN(n_class)
    print(net)

    net.apply(lib.utility.weights_init)

    image = torch.FloatTensor(Config.batch_size, 3, Config.img_height, Config.img_width)
    text = torch.IntTensor(Config.batch_size * 5)
    length = torch.IntTensor(Config.batch_size)

    if cuda:
        net.cuda()
        image = image.cuda()
        criterion = criterion.cuda()

    image = Variable(image)
    text = Variable(text)
    length = Variable(length)

    loss_avg = lib.utility.averager()

    optimizer = optim.RMSprop(net.parameters(), lr=Config.lr)
    #optimizer = optim.Adadelta(net.parameters(), lr=Config.lr)
    #optimizer = optim.Adam(net.parameters(), lr=Config.lr,
                           #betas=(Config.beta1, 0.999))

    for epoch in range(Config.epoch):
        train_iter = iter(train_loader)
        i = 0
        while i < len(train_loader):
            for p in net.parameters():
                p.requires_grad = True
            net.train()

            cost = trainBatch(net, criterion, optimizer, train_iter)
            loss_avg.add(cost)
            i += 1

            if i % Config.display_interval == 0:
                print('[%d/%d][%d/%d] Loss: %f' %
                      (epoch, Config.epoch, i, len(train_loader), loss_avg.val()))
                loss_avg.reset()

            if i % Config.test_interval == 0:
                val(net, test_dataset, criterion)

            # do checkpointing
            if i % Config.save_interval == 0:
                torch.save(
                    net.state_dict(), '{0}/netCRNN_{1}_{2}.pth'.format(Config.model_dir, epoch, i))
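The loop above uses torch.optim.RMSprop. For reference, a minimal NumPy sketch of the plain RMSprop update (not centered, no momentum) for one parameter array, using the default hyperparameters given in the PyTorch documentation (lr=0.01, alpha=0.99, eps=1e-8); Config.lr in this post may differ, and the function name is hypothetical.

```python
import numpy as np


def rmsprop_step(param, grad, square_avg, lr=0.01, alpha=0.99, eps=1e-8):
    """One RMSprop update (not centered, no momentum, no weight decay).

    square_avg is the running average of squared gradients; each gradient is
    divided by its root mean square before being scaled by lr.
    """
    square_avg = alpha * square_avg + (1 - alpha) * grad ** 2
    param = param - lr * grad / (np.sqrt(square_avg) + eps)
    return param, square_avg


p, v = rmsprop_step(np.array([1.0]), np.array([2.0]), np.zeros(1))
# p is about 0.9 and v about 0.04: starting from square_avg = 0, the first
# step has size lr * |g| / sqrt((1 - alpha) * g**2) = 10 * lr
```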

Training process and test design

The following figure shows the CRNN training process. There are 6732 text classes; we train for 20 epochs with batch_size = 64, giving 51244 iterations per epoch (51244 × 64 ≈ 3.28 million training images).

After four epochs, the loss drops to about 0.1 and the accuracy rises to about 0.98.

Next, we write the inference code. First initialize the CRNN network and load the trained model; then read the image to be recognized, convert it to grayscale and resize it to a height of 32; feed it to the network; and finally decode the network output into text.


import time
import torch
import os
from torch.autograd import Variable
import lib.convert
import lib.dataset
from PIL import Image
import Net.net as Net
import alphabets
import sys
import Config

os.environ['CUDA_VISIBLE_DEVICES'] = "4"

crnn_model_path = './bs64_model/netCRNN_9_48000.pth'
IMG_ROOT = './test_images'
running_mode = 'gpu'
alphabet = alphabets.alphabet
nclass = len(alphabet) + 1


def crnn_recognition(cropped_image, model):
    converter = lib.convert.strLabelConverter(alphabet)  # Label conversion

    image = cropped_image.convert('L')  # Image graying

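    ### Testing images should be scaled to height 32 with the width scaled
    # proportionally (but at least 100 pixels). Note that the active line below
    # instead scales the width by the fixed ratio Config.infer_img_w / 280;
    # the commented-out lines scale the width in proportion to the height.
    w = int(image.size[0] / (280 * 1.0 / Config.infer_img_w))
    #scale = image.size[1] * 1.0 / Config.img_height
    #w = int(image.size[0] / scale)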

    transformer = lib.dataset.resizeNormalize((w, Config.img_height))
    image = transformer(image)
    # move the input to whatever device the model lives on (CPU or GPU)
    image = image.to(next(model.parameters()).device)
    image = image.view(1, *image.size())
    image = Variable(image)

    model.eval()
    preds = model(image)

    _, preds = preds.max(2)
    preds = preds.transpose(1, 0).contiguous().view(-1)

    preds_size = Variable(torch.IntTensor([preds.size(0)]))
    sim_pred = converter.decode(preds.data, preds_size.data, raw=False)  # Predictive output decoded into text
    print('results: {0}'.format(sim_pred))


if __name__ == '__main__':

    # crnn network
    model = Net.CRNN(nclass)
    
    # Load the trained model. The loading mode of CPU and GPU is different. They need to be handled separately
    if running_mode == 'gpu' and torch.cuda.is_available():
        model = model.cuda()
        model.load_state_dict(torch.load(crnn_model_path))
    else:
        model.load_state_dict(torch.load(crnn_model_path, map_location='cpu'))

    print('loading pretrained model from {0}'.format(crnn_model_path))

    files = sorted(os.listdir(IMG_ROOT))  # Sort by filename
    for file in files:
        started = time.time()
        full_path = os.path.join(IMG_ROOT, file)
        print("=============================================")
        print("ocr image is %s" % full_path)
        image = Image.open(full_path)

        crnn_recognition(image, model)
        finished = time.time()
        print('elapsed time: {0}'.format(finished - started))
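The commented-out scale lines in crnn_recognition resize the width in proportion to the height. A minimal sketch of that computation with a 100-pixel minimum width added; the target height 32 and minimum 100 are the values mentioned in this post, and the function name is hypothetical. The result would then be passed as the width to lib.dataset.resizeNormalize((w, Config.img_height)).

```python
def resized_width(orig_w, orig_h, target_h=32, min_w=100):
    """Width after scaling an image to height target_h with its aspect ratio
    kept, clamped to at least min_w pixels."""
    return max(min_w, int(round(orig_w * target_h / orig_h)))


print(resized_width(600, 48))  # 400
print(resized_width(120, 64))  # 100, since the proportional width 60 is below min_w
```

Keeping the aspect ratio avoids stretching or squashing characters, at the cost of a variable sequence length T from image to image, which CTC handles naturally.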


Recognition results and summary

First, I took several images from the test set and ran them through the model; all of them were recognized correctly.

I also cropped random text lines from document images and scanned images and fed them to the model. The results were also good, with nearly everything recognized correctly, which suggests the model generalizes well beyond its synthetic training data.

I also cropped text regions from a scanned VAT invoice to check whether the model still recognizes them reliably:

A short summary: for end-to-end variable-length text recognition, CRNN is a classic algorithm, and it performs well in practice. As the results above show, although the training data was generated synthetically, the model recognizes text from PDF documents, scanned images and similar sources very well. To further improve recognition on text images from a specific domain, you can add a large number of such images to the training set. The complete CRNN code is available on my GitHub.


Posted on Sat, 16 May 2020 02:53:26 -0400 by jkohns