[Practice] Applying ELECTRA to punctuation prediction in NLP

Punctuation prediction based on ELECTRA

1. Resources

⭐ ⭐ ⭐ Please give the project a Star to show your support! ⭐ ⭐ ⭐

Open source is not easy. I hope you can support it~

  • For more Transformer models (BERT, ERNIE, ViT, DeiT, Swin Transformer, etc.) and deep learning materials for CV and NLP, please refer to: awesome-DeepLearning

  • For more NLP models (BERT Series), please refer to: PaddleNLP

2. How ELECTRA works

2.1 introduction

Masked language modeling (MLM) pre-training methods such as BERT corrupt the input by replacing some tokens with [MASK] and then train a model to reconstruct the original text. Although they produce good results on downstream NLP tasks, they usually require a large amount of compute to be effective. As an alternative, the authors propose a more efficient pre-training task called Replaced Token Detection (RTD). Instead of masking the input, RTD corrupts it by replacing some tokens with plausible alternatives produced by a generator network. A discriminator model is then trained to predict, for every token, whether it was replaced by the generator or comes from the original input. Experiments show that this new pre-training task is more effective than MLM, because it is defined over all input tokens rather than only the small masked subset. With the same model size, data and compute, the contextual representations learned with RTD substantially outperform those learned by BERT.
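
As a toy illustration of RTD (a hypothetical example in the spirit of the paper's "the chef cooked the meal" figure, not code from the paper): the generator fills masked positions with plausible tokens, and the discriminator labels every position as original or replaced.

original  = ["the", "chef", "cooked", "the", "meal"]
masked    = ["the", "[MASK]", "cooked", "the", "[MASK]"]   # MLM input for the generator
generated = ["the", "chef", "cooked", "the", "soup"]       # generator samples plausible replacements
# RTD labels for the discriminator: 0 = original, 1 = replaced
rtd_labels = [int(g != o) for g, o in zip(generated, original)]
print(rtd_labels)  # [0, 0, 0, 0, 1]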

In the figure above (the plot on the left is an enlarged version of the one on the right), the vertical axis is the dev GLUE score and the horizontal axis is FLOPs (floating-point operations), the compute statistic reported by TensorFlow. The figure shows that, at the same scale, ELECTRA consistently beats BERT, and with longer training it reaches the performance of RoBERTa, the SOTA model at the time. The left plot also suggests that ELECTRA's performance still has room to improve.

2.2 model structure

ELECTRA's biggest contribution is the new pre-training task and framework mentioned in the introduction above: the generative MLM pre-training task is replaced with a discriminative RTD task, in which the model judges whether each token has been replaced. This raises a question: suppose we randomly replace some tokens in the input and then ask BERT to predict which ones were replaced; would that work? People have run this experiment, and the results are not very good, because purely random replacements are too easy to detect.

Instead, the authors use an MLM generator (G-BERT) to transform the input sentence and then hand the result to a discriminator (D-BERT), which judges which tokens have been modified, as follows:

In downstream tasks, it is generally the Discriminator part that is fine-tuned.
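
For example (a short sketch; Section 3.6 below does exactly this for punctuation prediction), PaddleNLP exposes the pre-trained discriminator with a task-specific head on top:

from paddlenlp.transformers import ElectraForTokenClassification

# load the discriminator weights and add a token-classification head for fine-tuning
model = ElectraForTokenClassification.from_pretrained("electra-base", num_classes=4)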

In the following sections, we introduce the principle and code implementation of each part of ELECTRA in detail. ELECTRA consists of the embedding module, the ELECTRA model (a Transformer encoder), the Generator and the Discriminator. The embedding module is the word embedding layer of the whole model.

2.2.1 ELECTRA model

from paddlenlp.transformers import PretrainedModel
from paddlenlp.transformers import ElectraPretrainedModel
# get_activation (used by the generator/discriminator prediction heads below)
# is defined in PaddleNLP's electra modeling module
from paddlenlp.transformers.electra.modeling import get_activation
import paddle
import paddle.nn as nn
import paddle.tensor as tensor
import paddle.nn.functional as F

Before implementing the ELECTRA model itself, we need to implement the embedding module, which consists of the word (input) embedding, segment (token type) embedding and position embedding.

class ElectraEmbeddings(nn.Layer):
    """Construct the embeddings from word, position and token_type embeddings."""

    def __init__(self, vocab_size, embedding_size, hidden_dropout_prob,
                 max_position_embeddings, type_vocab_size):
        super(ElectraEmbeddings, self).__init__()
        self.word_embeddings = nn.Embedding(vocab_size, embedding_size)
        self.position_embeddings = nn.Embedding(max_position_embeddings,
                                                embedding_size)
        self.token_type_embeddings = nn.Embedding(type_vocab_size,
                                                  embedding_size)

        self.layer_norm = nn.LayerNorm(embedding_size, epsilon=1e-12)
        self.dropout = nn.Dropout(hidden_dropout_prob)

    def forward(self, input_ids, token_type_ids=None, position_ids=None):
        if position_ids is None:
            ones = paddle.ones_like(input_ids, dtype="int64")
            seq_length = paddle.cumsum(ones, axis=-1)
            position_ids = seq_length - ones
            position_ids.stop_gradient = True

        if token_type_ids is None:
            token_type_ids = paddle.zeros_like(input_ids, dtype="int64")
        # input embedding 
        input_embeddings = self.word_embeddings(input_ids)
        # position embedding
        position_embeddings = self.position_embeddings(position_ids)
        # segment embedding
        token_type_embeddings = self.token_type_embeddings(token_type_ids)

        embeddings = input_embeddings + position_embeddings + token_type_embeddings
        embeddings = self.layer_norm(embeddings)
        embeddings = self.dropout(embeddings)
        return embeddings

The next step is to implement the ELECTRA model, which is essentially a Transformer encoder and is structurally no different from BERT.

class ElectraModel(ElectraPretrainedModel):
    def __init__(self, vocab_size, embedding_size, hidden_size,
                 num_hidden_layers, num_attention_heads, intermediate_size,
                 hidden_act, hidden_dropout_prob, attention_probs_dropout_prob,
                 max_position_embeddings, type_vocab_size, initializer_range,
                 pad_token_id):
        super(ElectraModel, self).__init__()
        self.pad_token_id = pad_token_id
        self.initializer_range = initializer_range
        self.embeddings = ElectraEmbeddings(
            vocab_size, embedding_size, hidden_dropout_prob,
            max_position_embeddings, type_vocab_size)

        if embedding_size != hidden_size:
            self.embeddings_project = nn.Linear(embedding_size, hidden_size)

        encoder_layer = nn.TransformerEncoderLayer(
            hidden_size,
            num_attention_heads,
            intermediate_size,
            dropout=hidden_dropout_prob,
            activation=hidden_act,
            attn_dropout=attention_probs_dropout_prob,
            act_dropout=0)
        self.encoder = nn.TransformerEncoder(encoder_layer, num_hidden_layers)

        self.init_weights()

    def get_input_embeddings(self):
        return self.embeddings.word_embeddings

    def set_input_embeddings(self, value):
        self.embeddings.word_embeddings = value

    def forward(self,
                input_ids,
                token_type_ids=None,
                position_ids=None,
                attention_mask=None):

        if attention_mask is None:
            attention_mask = paddle.unsqueeze(
                (input_ids == self.pad_token_id
                 ).astype(paddle.get_default_dtype()) * -1e9,
                axis=[1, 2])
        # Embedding
        embedding_output = self.embeddings(
            input_ids=input_ids,
            position_ids=position_ids,
            token_type_ids=token_type_ids)

        if hasattr(self, "embeddings_project"):
            embedding_output = self.embeddings_project(embedding_output)
        # Transformer Encoder
        encoder_outputs = self.encoder(embedding_output, attention_mask)

        return encoder_outputs
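
As a quick sanity check, the sketch below runs a sentence through PaddleNLP's built-in ElectraModel (aliased here so it does not clash with the class defined above); the re-implementation in this section is expected to behave the same way.

import paddle
from paddlenlp.transformers import ElectraModel as PretrainedElectraModel
from paddlenlp.transformers import ElectraTokenizer

tokenizer = ElectraTokenizer.from_pretrained("electra-base")
electra = PretrainedElectraModel.from_pretrained("electra-base")

inputs = tokenizer("electra replaces masked language modeling with replaced token detection")
input_ids = paddle.to_tensor([inputs["input_ids"]])
token_type_ids = paddle.to_tensor([inputs["token_type_ids"]])

sequence_output = electra(input_ids, token_type_ids)
print(sequence_output.shape)  # [1, seq_len, 768] for electra-base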

2.2.2 Generator

Next is the Generator part of ELECTRA. It loads the ElectraModel implemented above and adds a fully connected prediction head on top.

class ElectraGeneratorPredictions(nn.Layer):
    """Prediction layer for the generator, made up of two dense layers."""

    def __init__(self, embedding_size, hidden_size, hidden_act):
        super(ElectraGeneratorPredictions, self).__init__()

        self.layer_norm = nn.LayerNorm(embedding_size)
        self.dense = nn.Linear(hidden_size, embedding_size)
        self.act = get_activation(hidden_act)

    def forward(self, generator_hidden_states):
        hidden_states = self.dense(generator_hidden_states)
        hidden_states = self.act(hidden_states)
        hidden_states = self.layer_norm(hidden_states)

        return hidden_states

class ElectraGenerator(ElectraPretrainedModel):
    def __init__(self, electra):
        super(ElectraGenerator, self).__init__()

        self.electra = electra
        self.generator_predictions = ElectraGeneratorPredictions(
            self.electra.config["embedding_size"],
            self.electra.config["hidden_size"],
            self.electra.config["hidden_act"])

        if not self.tie_word_embeddings:
            self.generator_lm_head = nn.Linear(
                self.electra.config["embedding_size"],
                self.electra.config["vocab_size"])
        else:
            self.generator_lm_head_bias = paddle.fluid.layers.create_parameter(
                shape=[self.electra.config["vocab_size"]],
                dtype=paddle.get_default_dtype(),
                is_bias=True)
        self.init_weights()

    def get_input_embeddings(self):
        return self.electra.embeddings.word_embeddings

    def forward(self,
                input_ids=None,
                token_type_ids=None,
                position_ids=None,
                attention_mask=None):

        generator_sequence_output = self.electra(input_ids, token_type_ids,
                                                 position_ids, attention_mask)

        prediction_scores = self.generator_predictions(
            generator_sequence_output)
        if not self.tie_word_embeddings:
            prediction_scores = self.generator_lm_head(prediction_scores)
        else:
            prediction_scores = paddle.add(paddle.matmul(
                prediction_scores,
                self.get_input_embeddings().weight,
                transpose_y=True),
                                           self.generator_lm_head_bias)

        return prediction_scores

2.2.3 Discriminator

The next step is to implement the Discriminator, which likewise connects a fully connected prediction head on top of ELECTRA.

class ElectraDiscriminatorPredictions(nn.Layer):
    """Prediction layer for the discriminator, made up of two dense layers."""

    def __init__(self, hidden_size, hidden_act):
        super(ElectraDiscriminatorPredictions, self).__init__()

        self.dense = nn.Linear(hidden_size, hidden_size)
        self.dense_prediction = nn.Linear(hidden_size, 1)
        self.act = get_activation(hidden_act)

    def forward(self, discriminator_hidden_states):
        hidden_states = self.dense(discriminator_hidden_states)
        hidden_states = self.act(hidden_states)
        logits = self.dense_prediction(hidden_states).squeeze()

        return logits

class ElectraDiscriminator(ElectraPretrainedModel):
    def __init__(self, electra):
        super(ElectraDiscriminator, self).__init__()

        self.electra = electra
        self.discriminator_predictions = ElectraDiscriminatorPredictions(
            self.electra.config["hidden_size"],
            self.electra.config["hidden_act"])
        self.init_weights()

    def forward(self,
                input_ids,
                token_type_ids=None,
                position_ids=None,
                attention_mask=None):

        discriminator_sequence_output = self.electra(
            input_ids, token_type_ids, position_ids, attention_mask)

        logits = self.discriminator_predictions(discriminator_sequence_output)

        return logits

2.2.4 pre-training loss function

Finally, the pre-training loss function of ELECTRA. It combines the generator (MLM) loss and the discriminator (RTD) loss.
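
In the ELECTRA paper the two losses are combined as shown below, where λ weights the discriminator loss (the paper uses λ = 50); gen_weight and disc_weight in the code play the roles of the two weights:

$$\min_{\theta_G,\,\theta_D} \sum_{x \in \mathcal{X}} \mathcal{L}_{\mathrm{MLM}}(x, \theta_G) + \lambda\, \mathcal{L}_{\mathrm{Disc}}(x, \theta_D)$$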

class ElectraPretrainingCriterion(nn.Layer):
    def __init__(self, vocab_size, gen_weight, disc_weight):
        super(ElectraPretrainingCriterion, self).__init__()

        self.vocab_size = vocab_size
        self.gen_weight = gen_weight
        self.disc_weight = disc_weight
        self.gen_loss_fct = nn.CrossEntropyLoss(reduction='none')
        self.disc_loss_fct = nn.BCEWithLogitsLoss(reduction='none')

    def forward(self, generator_prediction_scores,
                discriminator_prediction_scores, generator_labels,
                discriminator_labels, attention_mask):
        # generator loss
        gen_loss = self.gen_loss_fct(
            paddle.reshape(generator_prediction_scores, [-1, self.vocab_size]),
            paddle.reshape(generator_labels, [-1]))
        # todo: we can remove 4 lines after when CrossEntropyLoss(reduction='mean') improved
        umask_positions = paddle.zeros_like(generator_labels).astype(
            paddle.get_default_dtype())
        mask_positions = paddle.ones_like(generator_labels).astype(
            paddle.get_default_dtype())
        mask_positions = paddle.where(generator_labels == -100, umask_positions,
                                      mask_positions)
        if mask_positions.sum() == 0:
            gen_loss = paddle.to_tensor([0.0])
        else:
            gen_loss = gen_loss.sum() / mask_positions.sum()

        # discriminator loss
        seq_length = discriminator_labels.shape[1]
        disc_loss = self.disc_loss_fct(
            paddle.reshape(discriminator_prediction_scores, [-1, seq_length]),
            discriminator_labels.astype(paddle.get_default_dtype()))
        if attention_mask is not None:
            umask_positions = paddle.ones_like(discriminator_labels).astype(
                paddle.get_default_dtype())
            mask_positions = paddle.zeros_like(discriminator_labels).astype(
                paddle.get_default_dtype())
            use_disc_loss = paddle.where(attention_mask, disc_loss,
                                         mask_positions)
            umask_positions = paddle.where(attention_mask, umask_positions,
                                           mask_positions)
            disc_loss = use_disc_loss.sum() / umask_positions.sum()
        else:
            total_positions = paddle.ones_like(discriminator_labels).astype(
                paddle.get_default_dtype())
            disc_loss = disc_loss.sum() / total_positions.sum()

        return self.gen_weight * gen_loss + self.disc_weight * disc_loss

2.3 punctuation prediction task

In the past, this research direction was called sentence boundary detection; now it is usually called punctuation prediction or punctuation restoration. Deep-learning approaches to punctuation prediction fall into three categories:

  • Methods based on acoustic features. These predict punctuation from the pauses speakers make, but in a real ASR system unnatural pauses degrade their accuracy.
  • Methods based on text features. Text data comes in very different styles; for example, a model trained on the People's Daily dataset hardly works on chat datasets.
  • Methods combining text and acoustic features. These work well, but they require a dataset that contains both audio and ASR transcripts, which is hard to obtain.

In this experiment, the Discriminator is used for the punctuation prediction task. Punctuation prediction is essentially a sequence labeling task. Three punctuation marks are predicted here: comma, period and question mark; readers who are interested can add other punctuation types as well.
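
As a minimal illustration (a hypothetical sentence, not taken from the dataset), each token is tagged with the punctuation mark that should follow it, and 'O' means no punctuation:

tokens = ["how", "are", "you", "today", "i", "am", "fine"]
labels = ["O",   "O",   "O",   "?",     "O", "O",  "."]
# restored text: "how are you today? i am fine."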

3. Code practice

3.1 import library package

import os
import xml.etree.ElementTree as ET
import codecs
from collections import Counter
import re
import ujson
import pandas as pd


import random
import time
import math
from functools import partial
import inspect
from tqdm import tqdm 
import collections
import numpy as np

import paddle
from paddle.io import DataLoader
from paddle.dataset.common import md5file

import paddlenlp as ppnlp
from paddlenlp.transformers import LinearDecayWithWarmup
from paddlenlp.transformers import ElectraForTokenClassification, ElectraTokenizer
from paddlenlp.data import Stack, Tuple, Pad, Dict
from paddlenlp.datasets import DatasetBuilder
from paddlenlp.utils.env import DATA_HOME

from sklearn import metrics # https://scikit-learn.org/stable/modules/classes.html#module-sklearn.metrics
from sklearn.metrics import classification_report

3.2 data preprocessing

Because the decompressed dataset is in XML format, the text needs to be extracted and then split to construct the dataset.
The dataset used here is IWSLT12, which comes from http://hltc.cs.ust.hk/iwslt/index.html . It can also be found among the AI Studio public datasets.

!unzip -o IWSLT12.zip -d IWSLT12
data_path = "IWSLT12/"
file_path = data_path + "IWSLT12.TALK.dev2010.en-fr.en.xml" 
xmlp = ET.XMLParser(encoding="utf-8")
tree = ET.parse(file_path, parser=xmlp)
root = tree.getroot()
docs = []
for doc_id in range(len(root[0])):
    doc_segs = []
    doc = root[0][doc_id]
    for seg in doc.iter('seg'): 
        doc_segs.append(seg.text)
    docs.extend(doc_segs)
dev_texts = [re.sub(r'\s+', ' ', ''.join(d)).strip() for d in docs]
with open(data_path + 'dev_texts.txt', 'w', encoding='utf-8') as f:
    for text in dev_texts:
        f.write(text + '\n')
file_path = data_path + "IWSLT12.TED.MT.tst2012.en-fr.en.xml"

xmlp = ET.XMLParser(encoding="utf-8")
tree = ET.parse(file_path, parser=xmlp)
root = tree.getroot()

docs = []

for doc_id in range(len(root[0])):
    doc_segs = []
    doc = root[0][doc_id]
    for seg in doc.iter('seg'):
        doc_segs.append(seg.text)
    docs.extend(doc_segs)

test_texts_2012 = [re.sub(r'\s+', ' ', ''.join(d)).strip() for d in docs]

with open(data_path + 'test_texts_2012.txt', 'w', encoding='utf-8') as f:
    for text in test_texts_2012:
        f.write(text + '\n')
file_path = data_path + "train.tags.en-fr.en.xml"
with open(file_path) as f:
    xml = f.read()
tree = ET.fromstring("<root>"+ xml + "</root>")
docs = []
for doc in tree.iter('transcript'):
    text_arr=doc.text.split('\n')
    text_arr=[item.strip() for item in text_arr if(len(item.strip())>2)]
    # print(text_arr)
    docs.extend(text_arr)
    # break


train_texts=docs
with open(data_path + 'train_texts.txt', 'w', encoding='utf-8') as f:
    for text in train_texts:
        f.write(text + '\n')
# data fetch
with open(data_path + 'train_texts.txt', 'r', encoding='utf-8') as f:
    train_text = f.readlines()
with open(data_path + 'dev_texts.txt', 'r', encoding='utf-8') as f:
    valid_text = f.readlines()
with open(data_path + 'test_texts_2012.txt', 'r', encoding='utf-8') as f:
    test_text = f.readlines()
train_text[0]
datasets = train_text, valid_text, test_text
def clean_text(text):
    '''
    Text cleaning: map every punctuation mark in the text to one of ',', '.' or '?' (or drop it).
    '''
    text = text.replace('!', '.')
    text = text.replace(':', ',')
    text = text.replace('--', ',')
    
    reg = "(?<=[a-zA-Z])-(?=[a-zA-Z]{2,})"
    r = re.compile(reg, re.DOTALL)
    text = r.sub(' ', text)
    
    text = re.sub(r'\s-\s', ' , ', text)
    
#     text = text.replace('-', ',')
    text = text.replace(';', '.')
    text = text.replace(' ,', ',')
    text = text.replace('♫', '')
    text = text.replace('...', '')
    text = text.replace('.\"', ',')
    text = text.replace('"', ',')

    text = re.sub(r'--\s?--', '', text)
    text = re.sub(r'\s+', ' ', text)
    
    text = re.sub(r',\s?,', ',', text)
    text = re.sub(r',\s?\.', '.', text)
    text = re.sub(r'\?\s?\.', '?', text)
    text = re.sub(r'\s+', ' ', text)
    
    text = re.sub(r'\s+\?', '?', text)
    text = re.sub(r'\s+,', ',', text)
    text = re.sub(r'\.[\s+\.]+', '. ', text)
    text = re.sub(r'\s+\.', '.', text)
    
    return text.strip().lower()
datasets = [[clean_text(text) for text in ds] for ds in datasets]

  • Use the ELECTRA tokenizer to tokenize the text, and then construct the dataset

model_name_or_path='electra-base'
tokenizer = ElectraTokenizer.from_pretrained(model_name_or_path)
punctuation_enc = {
        'O': '0',
        ',': '1',
        '.': '2',
        '?': '3',
    }
# Take a text sequence as an example to construct the data set required by the model
example_sentence="all the projections [ say that ] this one [ billion ] will [ only ] grow with one to two or three percent"

print('Use the example sentence to create the dataset:', example_sentence)

example_text=tokenizer.tokenize(example_sentence)
print(example_text)

label=[]
cur_text=[]
for item in example_text:
    if(item in punctuation_enc):
        print(item)
        label.pop()
        label.append(punctuation_enc[item])
    else:
        cur_text.append(item)
        label.append(punctuation_enc['O'])
# label=[item for item in text]
print(label)
print(cur_text)
print(len(label))
print(len(cur_text))
# Following the construction process above, package the data with format_data
def format_data(train_text):
    '''
    Generate token lists and the corresponding punctuation labels from the raw text.
    Returns:
        texts: a list of token lists, one per text sample
        labels: a list of label lists; each label is the punctuation mark that follows the token at that position
    '''
    labels=[]
    texts=[]
    for line in tqdm(train_text):
        line=line.strip()
        if(len(line)==2):
            print(line)
            continue
        text=tokenizer.tokenize(line)
        label=[]
        cur_text=[]
        flag=True
        for item in text:
            if(item in punctuation_enc):
                # print(item)
                if(len(label)>0):
                    label.pop()
                    label.append(punctuation_enc[item])
                else:
                    print(text)
                    flag=False
                    break
            else:
                cur_text.append(item)
                label.append(punctuation_enc['O'])
        if(flag):
            labels.append(label)
            texts.append(cur_text)
    return texts,labels
# Build training set
train_texts,labels=format_data(train_text)
print(len(train_texts))
print(train_texts[0])
print(labels[0])
def output_to_tsv(texts,labels,file_name):
    data=[]
    for text,label in zip(texts,labels):
        if(len(text)!=len(label)):
            print(text)
            print(label)
            continue
        data.append([' '.join(text),' '.join(label)])
    df=pd.DataFrame(data,columns=['text_a','label'])
    df.to_csv(file_name,index=False,sep='\t')
def output_to_train_tsv(texts,labels,file_name):
    data=[]
    for text,label in zip(texts,labels):
        if(len(text)!=len(label)):
            print(text)
            print(label)
            continue
        if(len(text)==0):
            continue
        data.append([' '.join(text),' '.join(label)])
    # data=data[65000:70000]
    df=pd.DataFrame(data,columns=['text_a','label'])
    df.to_csv(file_name,index=False,sep='\t')
output_to_train_tsv(train_texts,labels,'train.tsv')
test_texts,labels=format_data(test_text)
output_to_tsv(test_texts,labels,'test.tsv')
print(len(test_texts))
print(test_texts[0])
print(labels[0])

valid_texts,labels=format_data(valid_text)
output_to_tsv(valid_texts,labels,'dev.tsv')
print(len(valid_texts))
print(valid_texts[0])
print(labels[0])
raw_path='.'
train_file = os.path.join(raw_path, "train.tsv")
dev_file = os.path.join(raw_path, "dev.tsv")
train_data=pd.read_csv(train_file,sep='\t')
train_data.head()
# Number of training set samples after data cleaning
len(train_data)
def write_json(filename, dataset):
    print('write to ' + filename)
    with codecs.open(filename, mode="w", encoding="utf-8") as f:
        ujson.dump(dataset, f)

3.3 building Dataset

class TEDTalk(DatasetBuilder):

    SPLITS = {
        'train': 'train.tsv',
        'dev':'dev.tsv',
        'test': 'test.tsv'
    }

    def _get_data(self, mode, **kwargs):
        default_root='.'
        self.mode=mode
        filename = self.SPLITS[mode]
        fullname = os.path.join(default_root, filename)

        return fullname

    def _read(self, filename, *args):
        df=pd.read_csv(filename,sep='\t')
        for idx,row in df.iterrows():
            text=row['text_a']
            if(type(text)==float):
                print(text)
                continue
            tokens=row['text_a'].split()
            tags=row['label'].split()
            yield {"tokens": tokens, "labels": tags}

    def get_labels(self):

        return ["0", "1", "2", "3"]
def load_dataset(path_or_read_func,
                 name=None,
                 data_files=None,
                 splits=None,
                 lazy=None,
                 **kwargs):
    
    reader_cls = TEDTalk
    print(reader_cls)
    if not name:
        reader_instance = reader_cls(lazy=lazy, **kwargs)
    else:
        reader_instance = reader_cls(lazy=lazy, name=name, **kwargs)

    datasets = reader_instance.read_datasets(data_files=data_files, splits=splits)
    return datasets
def tokenize_and_align_labels(example, tokenizer, no_entity_id,
                              max_seq_len=512):
    labels = example['labels']
    example = example['tokens']
    # print(labels)
    tokenized_input = tokenizer(
        example,
        return_length=True,
        is_split_into_words=True,
        max_seq_len=max_seq_len)

    # -2 for [CLS] and [SEP]
    if len(tokenized_input['input_ids']) - 2 < len(labels):
        labels = labels[:len(tokenized_input['input_ids']) - 2]
    tokenized_input['labels'] = [no_entity_id] + labels + [no_entity_id]
    tokenized_input['labels'] += [no_entity_id] * (
        len(tokenized_input['input_ids']) - len(tokenized_input['labels']))
    # print(tokenized_input)
    return tokenized_input
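
The sketch below (assuming the tokenizer loaded earlier and no_entity_id = 0) shows how the labels are aligned with the tokenized input: one no_entity_id is prepended for [CLS] and one appended for [SEP].

sample = {"tokens": ["how", "are", "you"], "labels": [0, 0, 3]}
aligned = tokenize_and_align_labels(sample, tokenizer=tokenizer, no_entity_id=0)
print(aligned["input_ids"])  # ids for: [CLS] how are you [SEP]
print(aligned["labels"])     # [0, 0, 0, 3, 0] -- padded with no_entity_id at [CLS]/[SEP]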

Load dataset

# Create dataset, tokenizer and dataloader.
train_ds, test_ds = load_dataset('TEDTalk', splits=('train', 'test'), lazy=False)
label_list = train_ds.label_list
label_num = len(label_list)
# no_entity_id = label_num - 1
no_entity_id=0
print(label_list)

3.4 assembling batches and padding

batch_size=128
ignore_label = -100
max_seq_length=128
trans_func = partial(
        tokenize_and_align_labels,
        tokenizer=tokenizer,
        no_entity_id=no_entity_id,
        max_seq_len=max_seq_length)



batchify_fn = lambda samples, fn=Dict({
        'input_ids': Pad(axis=0, pad_val=tokenizer.pad_token_id, dtype='int32'),  # input
        'token_type_ids': Pad(axis=0, pad_val=tokenizer.pad_token_type_id, dtype='int32'),  # segment
        'seq_len': Stack(dtype='int64'),  # seq_len
        'labels': Pad(axis=0, pad_val=ignore_label, dtype='int64')  # label
    }): fn(samples)

train_batch_sampler = paddle.io.DistributedBatchSampler(
        train_ds, batch_size=batch_size, shuffle=True, drop_last=True)

train_ds = train_ds.map(trans_func)

train_data_loader = DataLoader(
        dataset=train_ds,
        collate_fn=batchify_fn,
        num_workers=0,
        batch_sampler=train_batch_sampler,
        return_list=True)

test_ds = test_ds.map(trans_func)

test_data_loader = DataLoader(
        dataset=test_ds,
        collate_fn=batchify_fn,
        num_workers=0,
        batch_size=batch_size,
        return_list=True)
for index,data in enumerate(train_data_loader):
    # print(len(data))
    print(index)
    print(data)
    break

3.5 model configuration

device='gpu'
num_train_epochs=1
warmup_steps=0

model_name_or_path='electra-base'
max_steps=-1
learning_rate=5e-5
adam_epsilon=1e-8
weight_decay=0.0
paddle.set_device(device)

global_step = 0
logging_steps=200 # logging interval (steps)
last_step = num_train_epochs * len(train_data_loader)
tic_train = time.time()
save_steps=200 # checkpoint saving interval (steps)
output_dir='checkpoints/' # Model save directory

3.6 model construction

# Define the model netword and its loss
model = ElectraForTokenClassification.from_pretrained(model_name_or_path, num_classes=label_num)

3.6.1 setting up the AdamW optimizer

num_training_steps = max_steps if max_steps > 0 else len(
        train_data_loader) * num_train_epochs
lr_scheduler = LinearDecayWithWarmup(learning_rate, num_training_steps,
                                         warmup_steps)

# Generate parameter names needed to perform weight decay.
# All bias and LayerNorm parameters are excluded.
decay_params = [
        p.name for n, p in model.named_parameters()
        if not any(nd in n for nd in ["bias", "norm"])
    ]

optimizer = paddle.optimizer.AdamW(
        learning_rate=lr_scheduler,
        epsilon=adam_epsilon,
        parameters=model.parameters(),
        weight_decay=weight_decay,
        apply_decay_param_fun=lambda x: x in decay_params)                                        

3.6.2 setting the cross-entropy loss function

loss_fct = paddle.nn.loss.CrossEntropyLoss(ignore_index=ignore_label)
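
Note that ignore_index=-100 matches the pad_val=ignore_label used in batchify_fn above, so padded label positions contribute nothing to the loss. A tiny check (a sketch):

example_logits = paddle.randn([1, 3, label_num])   # [batch, seq_len, num_classes]
example_labels = paddle.to_tensor([[2, 0, -100]])  # last position is padding
print(loss_fct(example_logits, example_labels))    # averaged over the 2 real positions only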

3.6.3 setting evaluation method

metric = paddle.metric.Accuracy()
def compute_metrics(labels, decodes, lens):
    decodes = [x for batch in decodes for x in batch]
    lens = [x for batch in lens for x in batch]
    labels=[x for batch in labels for x in batch]
    outputs = []
    nb_correct=0
    nb_true=0
    val_f1s=[]
    label_vals=[0,1,2,3]
    y_trues=[]
    y_preds=[]
    for idx, end in enumerate(lens):
        y_true = labels[idx][:end].tolist()
        y_pred = [x for x in decodes[idx][:end]]
        nb_correct += sum(y_t == y_p for y_t, y_p in zip(y_true, y_pred))
        nb_true+=len(y_true)
        y_trues.extend(y_true)
        y_preds.extend(y_pred)

    score = nb_correct / nb_true
    # val_f1 = metrics.f1_score(y_trues, y_preds, average='micro', labels=label_vals)

    result=classification_report(y_trues, y_preds)
    # print(val_f1)   
    return score,result

3.7 model training

def evaluate(model, loss_fct, data_loader, label_num):
    model.eval()
    pred_list = []
    len_list = []
    labels_list=[]
    for batch in data_loader:
        input_ids, token_type_ids, length, labels = batch
        logits = model(input_ids, token_type_ids)
        loss = loss_fct(logits, labels)
        avg_loss = paddle.mean(loss)
        pred = paddle.argmax(logits, axis=-1)
        pred_list.append(pred.numpy())
        len_list.append(length.numpy())
        labels_list.append(labels.numpy())
    accuracy,result=compute_metrics(labels_list, pred_list, len_list)
    print("eval loss: %f, accuracy: %f" % (avg_loss, accuracy))
    print(result)
    model.train()
# evaluate(model, loss_fct, metric, test_data_loader,label_num)
for epoch in range(num_train_epochs):
    for step, batch in enumerate(train_data_loader):
        global_step += 1
        input_ids, token_type_ids, _, labels = batch
        logits = model(input_ids, token_type_ids)
        loss = loss_fct(logits, labels)
        avg_loss = paddle.mean(loss)
        if global_step % logging_steps == 0:
            print("global step %d, epoch: %d, batch: %d, loss: %f, speed: %.2f step/s"
                    % (global_step, epoch, step, avg_loss,
                       logging_steps / (time.time() - tic_train)))
            tic_train = time.time()
        avg_loss.backward()
        optimizer.step()
        lr_scheduler.step()
        optimizer.clear_grad()
        if global_step % save_steps == 0 or global_step == last_step:
            if paddle.distributed.get_rank() == 0:
                    evaluate(model, loss_fct, test_data_loader,
                                label_num)
                    paddle.save(model.state_dict(),os.path.join(output_dir,
                                                "model_%d.pdparams" % global_step))

3.8 model saving

paddle.save(model.state_dict(),os.path.join(output_dir,
                                                "model_final.pdparams"))

3.9 model prediction

3.9.1 loading the trained model

init_checkpoint_path=os.path.join(output_dir,'model_final.pdparams')
model_dict = paddle.load(init_checkpoint_path)
model.set_dict(model_dict)
punctuation_dec = {
        '0': 'O',
        '1': ',',
        '2': '.',
        '3': '?',
    }

3.9.2 prediction output

def parse_decodes(input_words, id2label, decodes, lens):
    decodes = [x for batch in decodes for x in batch]
    lens = [x for batch in lens for x in batch]

    outputs = []
    for idx, end in enumerate(lens):
        sent = input_words[idx]['tokens']
        tags = [id2label[x] for x in decodes[idx][1:end]]
        sent_out = []
        tags_out = []
        for s, t in zip(sent, tags):
            if(t=='0'):
                sent_out.append(s)
            else:
                # sent_out.append(s)
                sent_out.append(s+punctuation_dec[t])
        sent=' '.join(sent_out)
        sent=sent.replace(' ##','')
        outputs.append(sent)
    return outputs
id2label = dict(enumerate(test_ds.label_list))
raw_data = test_ds.data
model.eval()
pred_list = []
len_list = []
for step, batch in enumerate(test_data_loader):
    input_ids, token_type_ids, length, labels = batch
    logits = model(input_ids, token_type_ids)
    pred = paddle.argmax(logits, axis=-1)
    pred_list.append(pred.numpy())
    len_list.append(length.numpy())
preds = parse_decodes(raw_data, id2label, pred_list, len_list)

3.9.3 write to file

file_path = "results.txt"
with open(file_path, "w", encoding="utf8") as fout:
    fout.write("\n".join(preds))
    # Print some examples
print("The results have been saved in the file: %s, some examples are shown below: " % file_path)
print("\n".join(preds[:5]))

4. More PaddleEdu information

1. The PaddleEdu one-stop deep learning online encyclopedia awesome-DeepLearning also offers the following; stay tuned:

  • Introduction to deep learning
  • Deep learning Q&A
  • Featured courses
  • Industrial practice

If you have any questions while using PaddleEdu, you are welcome to raise them in awesome-DeepLearning. For more deep learning materials, please refer to the PaddlePaddle deep learning platform.

Remember to give it a Star ⭐ and bookmark it~~

2. PaddleEdu technical exchange group (QQ)

More than 2000 students are already studying together in the QQ group. You are welcome to scan the QR code to join us.
