Punctuation prediction based on ELECTRA
1. Resources
⭐ ⭐ ⭐ If this project helps you, please give it a Star! ⭐ ⭐ ⭐
Open source is not easy; your support keeps it going~

- For more Transformer models (BERT, ERNIE, ViT, DeiT, Swin Transformer, etc.) and deep learning materials for CV and NLP, please refer to: awesome-DeepLearning
- For more NLP models (the BERT series), please refer to: PaddleNLP
2. Principles of ELECTRA
2.1 introduction
Masked language modeling (MLM) pre-training methods such as BERT corrupt the input by replacing some tokens with [MASK] and then train a model to reconstruct the original text. Although they produce good results on downstream NLP tasks, they generally require large amounts of compute to be effective. As an alternative, the authors propose a more efficient pre-training task called Replaced Token Detection (RTD). Instead of masking the input, RTD corrupts it by using a generator network to produce plausible replacement tokens. A discriminator model is then trained to predict, for every token, whether it has been replaced by the generator. Experiments show that this new pre-training task is more efficient than MLM because it is defined over all input tokens rather than only the small masked subset. With the same model size, data, and compute, the contextual representations learned with RTD substantially outperform those learned by BERT.
In the comparison figure from the paper (the left plot is a zoomed-in version of the right one), the vertical axis is the GLUE score on the dev set and the horizontal axis is pre-training FLOPs (floating point operations, as reported by TensorFlow). At every scale, ELECTRA clearly outperforms BERT, and with longer training it reaches the level of RoBERTa, the SOTA model at the time. The left plot also suggests that ELECTRA's score had not yet saturated and still had room to rise.
2.2 model structure
ELECTRA's main contribution is the new pre-training task and framework mentioned in the introduction: the generative MLM pre-training task is replaced with a discriminative RTD task in which the model judges whether each token has been replaced. A natural question follows: could we simply replace some input words at random and train BERT to predict whether each word was replaced? Experiments have been done on this, and the results are not good, because purely random replacements are too easy to detect.
Instead, the authors use an MLM generator (G-BERT) to transform the input sentence and then hand the result to a discriminator (D-BERT), which judges which tokens have been modified, as illustrated by the toy sketch below.
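As a toy sketch of the idea (the sentence and the replacement are invented, in the style of the example from the ELECTRA paper; no model is involved here):

```python
# Toy illustration of Replaced Token Detection (no model involved).
original  = ["the", "chef", "cooked", "the", "meal"]
masked    = ["the", "chef", "[MASK]", "the", "[MASK]"]   # MLM input to the generator
generated = ["the", "chef", "ate",    "the", "meal"]     # generator's plausible reconstruction
# discriminator target: 1 = token was replaced, 0 = token is original
rtd_labels = [int(g != o) for g, o in zip(generated, original)]
print(rtd_labels)   # [0, 0, 1, 0, 0]
```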
In downstream tasks, only the Discriminator part is generally used for fine-tuning.
In the following sections, we introduce the principle and code implementation of each part of ELECTRA in detail. ELECTRA consists of the embedding module, the ElectraModel encoder, the Generator, and the Discriminator; the embedding module is the word embedding layer of the whole model.
2.2.1 ELECTRA model
```python
from paddlenlp.transformers import PretrainedModel
from paddlenlp.transformers import ElectraPretrainedModel
import paddle
import paddle.nn as nn
import paddle.tensor as tensor
import paddle.nn.functional as F
```
Before implementing the ELECTRA module, we need to implement the embedding module, which combines the input (word) embedding, segment (token type) embedding, and position embedding.
```python
class ElectraEmbeddings(nn.Layer):
    """Construct the embeddings from word, position and token_type embeddings."""

    def __init__(self, vocab_size, embedding_size, hidden_dropout_prob,
                 max_position_embeddings, type_vocab_size):
        super(ElectraEmbeddings, self).__init__()
        self.word_embeddings = nn.Embedding(vocab_size, embedding_size)
        self.position_embeddings = nn.Embedding(max_position_embeddings, embedding_size)
        self.token_type_embeddings = nn.Embedding(type_vocab_size, embedding_size)
        self.layer_norm = nn.LayerNorm(embedding_size, epsilon=1e-12)
        self.dropout = nn.Dropout(hidden_dropout_prob)

    def forward(self, input_ids, token_type_ids=None, position_ids=None):
        if position_ids is None:
            ones = paddle.ones_like(input_ids, dtype="int64")
            seq_length = paddle.cumsum(ones, axis=-1)
            position_ids = seq_length - ones
            position_ids.stop_gradient = True
        if token_type_ids is None:
            token_type_ids = paddle.zeros_like(input_ids, dtype="int64")

        # input embedding
        input_embeddings = self.word_embeddings(input_ids)
        # position embedding
        position_embeddings = self.position_embeddings(position_ids)
        # segment embedding
        token_type_embeddings = self.token_type_embeddings(token_type_ids)

        embeddings = input_embeddings + position_embeddings + token_type_embeddings
        embeddings = self.layer_norm(embeddings)
        embeddings = self.dropout(embeddings)
        return embeddings
```
The next step is to implement the ElectraModel itself, which is essentially a Transformer encoder and structurally no different from BERT.
```python
class ElectraModel(ElectraPretrainedModel):
    def __init__(self, vocab_size, embedding_size, hidden_size, num_hidden_layers,
                 num_attention_heads, intermediate_size, hidden_act,
                 hidden_dropout_prob, attention_probs_dropout_prob,
                 max_position_embeddings, type_vocab_size, initializer_range,
                 pad_token_id):
        super(ElectraModel, self).__init__()
        self.pad_token_id = pad_token_id
        self.initializer_range = initializer_range
        self.embeddings = ElectraEmbeddings(
            vocab_size, embedding_size, hidden_dropout_prob,
            max_position_embeddings, type_vocab_size)

        if embedding_size != hidden_size:
            self.embeddings_project = nn.Linear(embedding_size, hidden_size)

        encoder_layer = nn.TransformerEncoderLayer(
            hidden_size,
            num_attention_heads,
            intermediate_size,
            dropout=hidden_dropout_prob,
            activation=hidden_act,
            attn_dropout=attention_probs_dropout_prob,
            act_dropout=0)
        self.encoder = nn.TransformerEncoder(encoder_layer, num_hidden_layers)

        self.init_weights()

    def get_input_embeddings(self):
        return self.embeddings.word_embeddings

    def set_input_embeddings(self, value):
        self.embeddings.word_embeddings = value

    def forward(self,
                input_ids,
                token_type_ids=None,
                position_ids=None,
                attention_mask=None):
        if attention_mask is None:
            attention_mask = paddle.unsqueeze(
                (input_ids == self.pad_token_id
                 ).astype(paddle.get_default_dtype()) * -1e9,
                axis=[1, 2])

        # Embedding
        embedding_output = self.embeddings(
            input_ids=input_ids,
            position_ids=position_ids,
            token_type_ids=token_type_ids)

        if hasattr(self, "embeddings_project"):
            embedding_output = self.embeddings_project(embedding_output)

        # Transformer Encoder
        encoder_outputs = self.encoder(embedding_output, attention_mask)
        return encoder_outputs
```
2.2.2 Generator
Next is the Generator part of ELECTRA. It wraps the previously implemented ElectraModel and attaches a prediction head (fully connected layers) on top.
```python
class ElectraGeneratorPredictions(nn.Layer):
    """Prediction layer for the generator, made up of two dense layers."""

    def __init__(self, embedding_size, hidden_size, hidden_act):
        super(ElectraGeneratorPredictions, self).__init__()
        self.layer_norm = nn.LayerNorm(embedding_size)
        self.dense = nn.Linear(hidden_size, embedding_size)
        # get_activation maps the activation name (e.g. "gelu") to its function
        self.act = get_activation(hidden_act)

    def forward(self, generator_hidden_states):
        hidden_states = self.dense(generator_hidden_states)
        hidden_states = self.act(hidden_states)
        hidden_states = self.layer_norm(hidden_states)
        return hidden_states
```
```python
class ElectraGenerator(ElectraPretrainedModel):
    def __init__(self, electra):
        super(ElectraGenerator, self).__init__()

        self.electra = electra
        self.generator_predictions = ElectraGeneratorPredictions(
            self.electra.config["embedding_size"],
            self.electra.config["hidden_size"],
            self.electra.config["hidden_act"])

        if not self.tie_word_embeddings:
            self.generator_lm_head = nn.Linear(
                self.electra.config["embedding_size"],
                self.electra.config["vocab_size"])
        else:
            self.generator_lm_head_bias = paddle.fluid.layers.create_parameter(
                shape=[self.electra.config["vocab_size"]],
                dtype=paddle.get_default_dtype(),
                is_bias=True)
        self.init_weights()

    def get_input_embeddings(self):
        return self.electra.embeddings.word_embeddings

    def forward(self,
                input_ids=None,
                token_type_ids=None,
                position_ids=None,
                attention_mask=None):
        generator_sequence_output = self.electra(input_ids, token_type_ids,
                                                 position_ids, attention_mask)

        prediction_scores = self.generator_predictions(
            generator_sequence_output)
        if not self.tie_word_embeddings:
            prediction_scores = self.generator_lm_head(prediction_scores)
        else:
            # with tied word embeddings, project back to the vocabulary using the
            # (transposed) input embedding matrix plus a bias
            prediction_scores = paddle.add(
                paddle.matmul(prediction_scores,
                              self.get_input_embeddings().weight,
                              transpose_y=True),
                self.generator_lm_head_bias)
        return prediction_scores
```
2.2.3 Discriminator
The next step is the implementation of the Discriminator, which is likewise a fully connected prediction head on top of the ELECTRA encoder.
```python
class ElectraDiscriminatorPredictions(nn.Layer):
    """Prediction layer for the discriminator, made up of two dense layers."""

    def __init__(self, hidden_size, hidden_act):
        super(ElectraDiscriminatorPredictions, self).__init__()

        self.dense = nn.Linear(hidden_size, hidden_size)
        self.dense_prediction = nn.Linear(hidden_size, 1)
        self.act = get_activation(hidden_act)

    def forward(self, discriminator_hidden_states):
        hidden_states = self.dense(discriminator_hidden_states)
        hidden_states = self.act(hidden_states)
        # one logit per token: "was this token replaced?"
        logits = self.dense_prediction(hidden_states).squeeze()
        return logits
```
```python
class ElectraDiscriminator(ElectraPretrainedModel):
    def __init__(self, electra):
        super(ElectraDiscriminator, self).__init__()

        self.electra = electra
        self.discriminator_predictions = ElectraDiscriminatorPredictions(
            self.electra.config["hidden_size"],
            self.electra.config["hidden_act"])
        self.init_weights()

    def forward(self,
                input_ids,
                token_type_ids=None,
                position_ids=None,
                attention_mask=None):
        discriminator_sequence_output = self.electra(
            input_ids, token_type_ids, position_ids, attention_mask)

        logits = self.discriminator_predictions(discriminator_sequence_output)
        return logits
```
2.2.4 pre-training loss function
Finally, we implement ELECTRA's pre-training loss, which combines the generator (MLM) loss and the discriminator (RTD) loss.
```python
class ElectraPretrainingCriterion(nn.Layer):
    def __init__(self, vocab_size, gen_weight, disc_weight):
        super(ElectraPretrainingCriterion, self).__init__()

        self.vocab_size = vocab_size
        self.gen_weight = gen_weight
        self.disc_weight = disc_weight
        self.gen_loss_fct = nn.CrossEntropyLoss(reduction='none')
        self.disc_loss_fct = nn.BCEWithLogitsLoss(reduction='none')

    def forward(self, generator_prediction_scores,
                discriminator_prediction_scores, generator_labels,
                discriminator_labels, attention_mask):
        # generator loss
        gen_loss = self.gen_loss_fct(
            paddle.reshape(generator_prediction_scores, [-1, self.vocab_size]),
            paddle.reshape(generator_labels, [-1]))
        # todo: we can remove 4 lines after when CrossEntropyLoss(reduction='mean') improved
        umask_positions = paddle.zeros_like(generator_labels).astype(
            paddle.get_default_dtype())
        mask_positions = paddle.ones_like(generator_labels).astype(
            paddle.get_default_dtype())
        mask_positions = paddle.where(generator_labels == -100, umask_positions,
                                      mask_positions)
        if mask_positions.sum() == 0:
            gen_loss = paddle.to_tensor([0.0])
        else:
            gen_loss = gen_loss.sum() / mask_positions.sum()

        # discriminator loss
        seq_length = discriminator_labels.shape[1]
        disc_loss = self.disc_loss_fct(
            paddle.reshape(discriminator_prediction_scores, [-1, seq_length]),
            discriminator_labels.astype(paddle.get_default_dtype()))
        if attention_mask is not None:
            umask_positions = paddle.ones_like(discriminator_labels).astype(
                paddle.get_default_dtype())
            mask_positions = paddle.zeros_like(discriminator_labels).astype(
                paddle.get_default_dtype())
            use_disc_loss = paddle.where(attention_mask, disc_loss,
                                         mask_positions)
            umask_positions = paddle.where(attention_mask, umask_positions,
                                           mask_positions)
            disc_loss = use_disc_loss.sum() / umask_positions.sum()
        else:
            total_positions = paddle.ones_like(discriminator_labels).astype(
                paddle.get_default_dtype())
            disc_loss = disc_loss.sum() / total_positions.sum()

        return self.gen_weight * gen_loss + self.disc_weight * disc_loss
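To make the weighting concrete, here is a minimal, hypothetical usage sketch with random tensors. The shapes and the -100 "not masked" convention follow the code above; the weights of 1 and 50 mirror the setting reported in the ELECTRA paper, and all tensor values are made up.

```python
# Hypothetical usage sketch: random logits, one "masked" position per sequence.
batch_size, seq_len, vocab_size = 2, 8, 30522
criterion = ElectraPretrainingCriterion(vocab_size=vocab_size, gen_weight=1.0, disc_weight=50.0)

gen_scores = paddle.randn([batch_size, seq_len, vocab_size])   # generator (MLM) logits
disc_scores = paddle.randn([batch_size, seq_len])              # discriminator (RTD) logits
# -100 marks positions that were not masked; position 2 pretends to be masked (target id 1024)
gen_labels = paddle.to_tensor(
    [[-100, -100, 1024, -100, -100, -100, -100, -100]] * batch_size, dtype='int64')
# 1 marks tokens the generator replaced, 0 marks original tokens
disc_labels = paddle.to_tensor(
    [[0, 0, 1, 0, 0, 0, 0, 0]] * batch_size, dtype='int64')
attention_mask = paddle.ones([batch_size, seq_len], dtype='bool')   # no padding in this toy batch

total_loss = criterion(gen_scores, disc_scores, gen_labels, disc_labels, attention_mask)
print(total_loss)   # gen_weight * MLM loss + disc_weight * RTD loss
```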
2.3 punctuation prediction task
This research direction used to be called sentence boundary detection and is now usually called punctuation prediction or punctuation restoration. Deep learning approaches to punctuation prediction fall into three categories:
- Methods based on acoustic features. These predict punctuation from the speaker's pauses, but in a real ASR system unnatural pauses degrade their accuracy.
- Methods based on text features. Text data varies widely across domains; for example, a model trained on the People's Daily dataset tends to work poorly on chat data.
- Methods combining text and acoustic features. These are effective, but they require datasets that contain both audio and ASR transcripts, which are hard to obtain.
In this experiment, the Discriminator is used for the punctuation prediction task, which is essentially a sequence labeling task. Three punctuation marks are predicted: comma, period, and question mark. Readers can add other punctuation types if interested. A small sketch of the labeling scheme is shown below.
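As a rough sketch of this labeling scheme (the sentence is made up, and the readable label names stand in for the numeric ids defined later in section 3.2):

```python
# Each word receives the label of the punctuation mark that follows it, or 'O' for none.
sentence = "yes , we did . did you see it ?"
words, labels = [], []
for token in sentence.split():
    if token in {',', '.', '?'}:
        labels[-1] = {',': 'COMMA', '.': 'PERIOD', '?': 'QUESTION'}[token]
    else:
        words.append(token)
        labels.append('O')

print(list(zip(words, labels)))
# [('yes', 'COMMA'), ('we', 'O'), ('did', 'PERIOD'),
#  ('did', 'O'), ('you', 'O'), ('see', 'O'), ('it', 'QUESTION')]
```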
3. Code practice
3.1 import library package
```python
import os
import xml.etree.ElementTree as ET
import codecs
from collections import Counter
import re
import ujson
import pandas as pd
import random
import time
import math
from functools import partial
import inspect
from tqdm import tqdm
import collections

import numpy as np
import paddle
from paddle.io import DataLoader
from paddle.dataset.common import md5file

import paddlenlp as ppnlp
from paddlenlp.transformers import LinearDecayWithWarmup
from paddlenlp.transformers import ElectraForTokenClassification, ElectraTokenizer
from paddlenlp.data import Stack, Tuple, Pad, Dict
from paddlenlp.datasets import DatasetBuilder
from paddlenlp.utils.env import DATA_HOME

from sklearn import metrics  # https://scikit-learn.org/stable/modules/classes.html#module-sklearn.metrics
from sklearn.metrics import classification_report
```
3.2 data preprocessing
Because the decompressed dataset is in XML format, the text needs to be extracted and split before the dataset can be constructed.
The dataset used here is IWSLT12, from http://hltc.cs.ust.hk/iwslt/index.html . It can also be found among the AI Studio public datasets.
```shell
!unzip -o IWSLT12.zip -d IWSLT12
```
data_path = "IWSLT12/" file_path = data_path + "IWSLT12.TALK.dev2010.en-fr.en.xml" xmlp = ET.XMLParser(encoding="utf-8") tree = ET.parse(file_path, parser=xmlp) root = tree.getroot()
docs = [] for doc_id in range(len(root[0])): doc_segs = [] doc = root[0][doc_id] for seg in doc.iter('seg'): doc_segs.append(seg.text) docs.extend(doc_segs)
dev_texts = [re.sub(r'\s+', ' ', ''.join(d)).strip() for d in docs] with open(data_path + 'dev_texts.txt', 'w', encoding='utf-8') as f: for text in dev_texts: f.write(text + '\n')
file_path = data_path + "IWSLT12.TED.MT.tst2012.en-fr.en.xml" xmlp = ET.XMLParser(encoding="utf-8") tree = ET.parse(file_path, parser=xmlp) root = tree.getroot() docs = [] for doc_id in range(len(root[0])): doc_segs = [] doc = root[0][doc_id] for seg in doc.iter('seg'): doc_segs.append(seg.text) docs.extend(doc_segs) test_texts_2012 = [re.sub(r'\s+', ' ', ''.join(d)).strip() for d in docs] with open(data_path + 'test_texts_2012.txt', 'w', encoding='utf-8') as f: for text in test_texts_2012: f.write(text + '\n')
file_path = data_path + "train.tags.en-fr.en.xml" with open(file_path) as f: xml = f.read() tree = ET.fromstring("<root>"+ xml + "</root>")
docs = [] for doc in tree.iter('transcript'): text_arr=doc.text.split('\n') text_arr=[item.strip() for item in text_arr if(len(item.strip())>2)] # print(text_arr) docs.extend(text_arr) # break
train_texts=docs with open(data_path + 'train_texts.txt', 'w', encoding='utf-8') as f: for text in train_texts: f.write(text + '\n')
```python
# data fetch
with open(data_path + 'train_texts.txt', 'r', encoding='utf-8') as f:
    train_text = f.readlines()
with open(data_path + 'dev_texts.txt', 'r', encoding='utf-8') as f:
    valid_text = f.readlines()
with open(data_path + 'test_texts_2012.txt', 'r', encoding='utf-8') as f:
    test_text = f.readlines()

train_text[0]

datasets = train_text, valid_text, test_text
```
```python
def clean_text(text):
    '''
    Text cleaning: map every punctuation symbol to one of '', '.', ',' or '?'
    '''
    text = text.replace('!', '.')
    text = text.replace(':', ',')
    text = text.replace('--', ',')

    reg = "(?<=[a-zA-Z])-(?=[a-zA-Z]{2,})"
    r = re.compile(reg, re.DOTALL)
    text = r.sub(' ', text)

    text = re.sub(r'\s-\s', ' , ', text)
    text = text.replace(';', '.')
    text = text.replace(' ,', ',')
    text = text.replace('♫', '')
    text = text.replace('...', '')
    text = text.replace('.\"', ',')
    text = text.replace('"', ',')

    text = re.sub(r'--\s?--', '', text)
    text = re.sub(r'\s+', ' ', text)
    text = re.sub(r',\s?,', ',', text)
    text = re.sub(r',\s?\.', '.', text)
    text = re.sub(r'\?\s?\.', '?', text)
    text = re.sub(r'\s+', ' ', text)
    text = re.sub(r'\s+\?', '?', text)
    text = re.sub(r'\s+,', ',', text)
    text = re.sub(r'\.[\s+\.]+', '. ', text)
    text = re.sub(r'\s+\.', '.', text)

    return text.strip().lower()


datasets = [[clean_text(text) for text in ds] for ds in datasets]
```
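A quick check of what clean_text does (the input sentence is made up):

```python
print(clean_text("So -- what do you think? It's great!"))
# expected: so, what do you think? it's great.
```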
- Use the ELECTRA tokenizer to split the text into wordpieces, and then construct the dataset.
```python
model_name_or_path = 'electra-base'
tokenizer = ElectraTokenizer.from_pretrained(model_name_or_path)

# label ids: 'O' = no punctuation after the token
punctuation_enc = {
    'O': '0',
    ',': '1',
    '.': '2',
    '?': '3',
}

# Take one sentence as an example of how the dataset required by the model is constructed
example_sentence = "all the projections [ say that ] this one [ billion ] will [ only ] grow with one to two or three percent"
print('Use the example sentence to create the dataset:', example_sentence)

example_text = tokenizer.tokenize(example_sentence)
print(example_text)

label = []
cur_text = []
for item in example_text:
    if item in punctuation_enc:
        print(item)
        # attach the punctuation label to the previous token
        label.pop()
        label.append(punctuation_enc[item])
    else:
        cur_text.append(item)
        label.append(punctuation_enc['O'])

print(label)
print(cur_text)
print(len(label))
print(len(cur_text))
```
```python
# Package the cleaned text into (tokens, labels) pairs as illustrated above
def format_data(train_text):
    '''
    Generate the token list and the corresponding punctuation labels for each line.
    return:
        texts: list of token lists, one per text sample
        labels: list of label lists; each label marks the punctuation that follows the token
    '''
    labels = []
    texts = []
    for line in tqdm(train_text):
        line = line.strip()
        if len(line) == 2:
            # skip lines that are too short to carry content
            print(line)
            continue
        text = tokenizer.tokenize(line)
        label = []
        cur_text = []
        flag = True
        for item in text:
            if item in punctuation_enc:
                if len(label) > 0:
                    label.pop()
                    label.append(punctuation_enc[item])
                else:
                    # the line starts with punctuation; drop this sample
                    print(text)
                    flag = False
                    break
            else:
                cur_text.append(item)
                label.append(punctuation_enc['O'])
        if flag:
            labels.append(label)
            texts.append(cur_text)
    return texts, labels
```
```python
# Build training set
train_texts, labels = format_data(train_text)

print(len(train_texts))
print(train_texts[0])
print(labels[0])
```
```python
def output_to_tsv(texts, labels, file_name):
    data = []
    for text, label in zip(texts, labels):
        if len(text) != len(label):
            print(text)
            print(label)
            continue
        data.append([' '.join(text), ' '.join(label)])
    df = pd.DataFrame(data, columns=['text_a', 'label'])
    df.to_csv(file_name, index=False, sep='\t')


def output_to_train_tsv(texts, labels, file_name):
    data = []
    for text, label in zip(texts, labels):
        if len(text) != len(label):
            print(text)
            print(label)
            continue
        if len(text) == 0:
            continue
        data.append([' '.join(text), ' '.join(label)])
    df = pd.DataFrame(data, columns=['text_a', 'label'])
    df.to_csv(file_name, index=False, sep='\t')


output_to_train_tsv(train_texts, labels, 'train.tsv')
```
```python
test_texts, labels = format_data(test_text)
output_to_tsv(test_texts, labels, 'test.tsv')
print(len(test_texts))
print(test_texts[0])
print(labels[0])

valid_texts, labels = format_data(valid_text)
output_to_tsv(valid_texts, labels, 'dev.tsv')
print(len(valid_texts))
print(valid_texts[0])
print(labels[0])
```
```python
raw_path = '.'
train_file = os.path.join(raw_path, "train.tsv")
dev_file = os.path.join(raw_path, "dev.tsv")
train_data = pd.read_csv(train_file, sep='\t')
train_data.head()

# Number of training samples after data cleaning
len(train_data)
```
```python
def write_json(filename, dataset):
    print('write to ' + filename)
    with codecs.open(filename, mode="w", encoding="utf-8") as f:
        ujson.dump(dataset, f)
```
3.3 building Dataset
```python
class TEDTalk(DatasetBuilder):
    SPLITS = {
        'train': 'train.tsv',
        'dev': 'dev.tsv',
        'test': 'test.tsv'
    }

    def _get_data(self, mode, **kwargs):
        default_root = '.'
        self.mode = mode
        filename = self.SPLITS[mode]
        fullname = os.path.join(default_root, filename)
        return fullname

    def _read(self, filename, *args):
        df = pd.read_csv(filename, sep='\t')
        for idx, row in df.iterrows():
            text = row['text_a']
            if type(text) == float:
                # skip empty rows that pandas read as NaN
                print(text)
                continue
            tokens = row['text_a'].split()
            tags = row['label'].split()
            yield {"tokens": tokens, "labels": tags}

    def get_labels(self):
        return ["0", "1", "2", "3"]
```
```python
def load_dataset(path_or_read_func, name=None, data_files=None, splits=None, lazy=None, **kwargs):
    reader_cls = TEDTalk
    print(reader_cls)
    if not name:
        reader_instance = reader_cls(lazy=lazy, **kwargs)
    else:
        reader_instance = reader_cls(lazy=lazy, name=name, **kwargs)

    datasets = reader_instance.read_datasets(data_files=data_files, splits=splits)
    return datasets
```
```python
def tokenize_and_align_labels(example, tokenizer, no_entity_id, max_seq_len=512):
    labels = example['labels']
    example = example['tokens']

    tokenized_input = tokenizer(
        example,
        return_length=True,
        is_split_into_words=True,
        max_seq_len=max_seq_len)

    # -2 for [CLS] and [SEP]
    if len(tokenized_input['input_ids']) - 2 < len(labels):
        labels = labels[:len(tokenized_input['input_ids']) - 2]
    tokenized_input['labels'] = [no_entity_id] + labels + [no_entity_id]
    tokenized_input['labels'] += [no_entity_id] * (
        len(tokenized_input['input_ids']) - len(tokenized_input['labels']))
    return tokenized_input
```
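For intuition, here is a quick, hypothetical call of this function (it reuses the tokenizer loaded in section 3.2; the sample tokens and labels are made up):

```python
# Illustrative only: labels for the [CLS]/[SEP]/padding positions become no_entity_id.
sample = {'tokens': ['thank', 'you', 'so', 'much', 'chris'], 'labels': [0, 0, 0, 0, 2]}
encoded = tokenize_and_align_labels(sample, tokenizer=tokenizer, no_entity_id=0, max_seq_len=128)
print(encoded['input_ids'])   # wordpiece ids, including [CLS] and [SEP]
print(encoded['labels'])      # aligned labels, padded with no_entity_id for the special tokens
```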
Load the dataset:
```python
# Create dataset, tokenizer and dataloader.
train_ds, test_ds = load_dataset('TEDTalk', splits=('train', 'test'), lazy=False)

label_list = train_ds.label_list
label_num = len(label_list)
no_entity_id = 0   # label id for "no punctuation"
print(label_list)
```
3.4 assembling Batch and Padding
```python
batch_size = 128
ignore_label = -100   # label value ignored by the loss (also used to pad label sequences)
max_seq_length = 128

trans_func = partial(
    tokenize_and_align_labels,
    tokenizer=tokenizer,
    no_entity_id=no_entity_id,
    max_seq_len=max_seq_length)

batchify_fn = lambda samples, fn=Dict({
    'input_ids': Pad(axis=0, pad_val=tokenizer.pad_token_id, dtype='int32'),             # input
    'token_type_ids': Pad(axis=0, pad_val=tokenizer.pad_token_type_id, dtype='int32'),   # segment
    'seq_len': Stack(dtype='int64'),                                                     # sequence length
    'labels': Pad(axis=0, pad_val=ignore_label, dtype='int64')                           # label
}): fn(samples)

train_batch_sampler = paddle.io.DistributedBatchSampler(
    train_ds, batch_size=batch_size, shuffle=True, drop_last=True)

train_ds = train_ds.map(trans_func)
train_data_loader = DataLoader(
    dataset=train_ds,
    collate_fn=batchify_fn,
    num_workers=0,
    batch_sampler=train_batch_sampler,
    return_list=True)

test_ds = test_ds.map(trans_func)
test_data_loader = DataLoader(
    dataset=test_ds,
    collate_fn=batchify_fn,
    num_workers=0,
    batch_size=batch_size,
    return_list=True)
```
```python
for index, data in enumerate(train_data_loader):
    print(index)
    print(data)
    break
```
3.5 model configuration
```python
device = 'gpu'
num_train_epochs = 1
warmup_steps = 0
model_name_or_path = 'electra-base'
max_steps = -1
learning_rate = 5e-5
adam_epsilon = 1e-8
weight_decay = 0.0

paddle.set_device(device)

global_step = 0
logging_steps = 200             # logging interval (steps)
last_step = num_train_epochs * len(train_data_loader)
tic_train = time.time()
save_steps = 200                # model saving interval (steps)
output_dir = 'checkpoints/'     # model saving directory
```
3.6 model construction
```python
# Define the model network and its loss
model = ElectraForTokenClassification.from_pretrained(
    model_name_or_path, num_classes=label_num)
```
3.6.1 setting up the AdamW optimizer
```python
num_training_steps = max_steps if max_steps > 0 else len(
    train_data_loader) * num_train_epochs

lr_scheduler = LinearDecayWithWarmup(learning_rate, num_training_steps,
                                     warmup_steps)

# Generate parameter names needed to perform weight decay.
# All bias and LayerNorm parameters are excluded.
decay_params = [
    p.name for n, p in model.named_parameters()
    if not any(nd in n for nd in ["bias", "norm"])
]
optimizer = paddle.optimizer.AdamW(
    learning_rate=lr_scheduler,
    epsilon=adam_epsilon,
    parameters=model.parameters(),
    weight_decay=weight_decay,
    apply_decay_param_fun=lambda x: x in decay_params)
```
3.6.2 setting the cross-entropy loss function
```python
loss_fct = paddle.nn.loss.CrossEntropyLoss(ignore_index=ignore_label)
```
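Because the label padding value in batchify_fn is the same ignore_label (-100), padded positions contribute nothing to this loss. A small, made-up check:

```python
# Illustrative only: the second position is label-padded with -100 and is ignored.
toy_logits = paddle.to_tensor([[[2.0, 0.1, 0.1, 0.1],
                                [0.1, 2.0, 0.1, 0.1]]])   # shape [batch=1, seq=2, num_classes=4]
toy_labels = paddle.to_tensor([[0, -100]])                # shape [batch=1, seq=2]
print(loss_fct(toy_logits, toy_labels))                   # averaged over the first position only
```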
3.6.3 setting evaluation method
```python
metric = paddle.metric.Accuracy()

def compute_metrics(labels, decodes, lens):
    decodes = [x for batch in decodes for x in batch]
    lens = [x for batch in lens for x in batch]
    labels = [x for batch in labels for x in batch]

    nb_correct = 0
    nb_true = 0
    y_trues = []
    y_preds = []
    for idx, end in enumerate(lens):
        y_true = labels[idx][:end].tolist()
        y_pred = [x for x in decodes[idx][:end]]
        nb_correct += sum(y_t == y_p for y_t, y_p in zip(y_true, y_pred))
        nb_true += len(y_true)
        y_trues.extend(y_true)
        y_preds.extend(y_pred)

    score = nb_correct / nb_true
    result = classification_report(y_trues, y_preds)
    return score, result
```
3.7 model training
```python
def evaluate(model, loss_fct, data_loader, label_num):
    model.eval()
    pred_list = []
    len_list = []
    labels_list = []
    for batch in data_loader:
        input_ids, token_type_ids, length, labels = batch
        logits = model(input_ids, token_type_ids)
        loss = loss_fct(logits, labels)
        avg_loss = paddle.mean(loss)
        pred = paddle.argmax(logits, axis=-1)
        pred_list.append(pred.numpy())
        len_list.append(length.numpy())
        labels_list.append(labels.numpy())

    accuracy, result = compute_metrics(labels_list, pred_list, len_list)
    print("eval loss: %f, accuracy: %f" % (avg_loss, accuracy))
    print(result)
    model.train()
```
```python
# optionally evaluate the model before training:
# evaluate(model, loss_fct, test_data_loader, label_num)
```
```python
for epoch in range(num_train_epochs):
    for step, batch in enumerate(train_data_loader):
        global_step += 1
        input_ids, token_type_ids, _, labels = batch
        logits = model(input_ids, token_type_ids)
        loss = loss_fct(logits, labels)
        avg_loss = paddle.mean(loss)

        if global_step % logging_steps == 0:
            print("global step %d, epoch: %d, batch: %d, loss: %f, speed: %.2f step/s"
                  % (global_step, epoch, step, avg_loss,
                     logging_steps / (time.time() - tic_train)))
            tic_train = time.time()

        avg_loss.backward()
        optimizer.step()
        lr_scheduler.step()
        optimizer.clear_grad()

        if global_step % save_steps == 0 or global_step == last_step:
            if paddle.distributed.get_rank() == 0:
                evaluate(model, loss_fct, test_data_loader, label_num)
                paddle.save(model.state_dict(),
                            os.path.join(output_dir, "model_%d.pdparams" % global_step))
```
3.8 model saving
```python
paddle.save(model.state_dict(), os.path.join(output_dir, "model_final.pdparams"))
```
3.9 model prediction
3.9.1 loading the trained model
```python
init_checkpoint_path = os.path.join(output_dir, 'model_final.pdparams')
model_dict = paddle.load(init_checkpoint_path)
model.set_dict(model_dict)

# map label ids back to punctuation marks
punctuation_dec = {
    '0': 'O',
    '1': ',',
    '2': '.',
    '3': '?',
}
```
3.9.2 prediction output
```python
def parse_decodes(input_words, id2label, decodes, lens):
    decodes = [x for batch in decodes for x in batch]
    lens = [x for batch in lens for x in batch]

    outputs = []
    for idx, end in enumerate(lens):
        sent = input_words[idx]['tokens']
        tags = [id2label[x] for x in decodes[idx][1:end]]
        sent_out = []
        for s, t in zip(sent, tags):
            if t == '0':
                sent_out.append(s)
            else:
                # append the predicted punctuation mark to the token
                sent_out.append(s + punctuation_dec[t])
        sent = ' '.join(sent_out)
        sent = sent.replace(' ##', '')   # merge wordpieces back into words
        outputs.append(sent)
    return outputs
```
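As a quick illustration of what parse_decodes returns, here is a made-up single prediction (tokens, label ids, and lengths are invented; in the real pipeline they come from the dataloader and the model):

```python
# Hypothetical example: one sentence, a period predicted after the last token.
toy_words = [{'tokens': ['thank', 'you', 'so', 'much', 'chris']}]
toy_id2label = {0: '0', 1: '1', 2: '2', 3: '3'}
toy_decodes = [np.array([[0, 0, 0, 0, 0, 2]])]   # position 0 is the [CLS] prediction
toy_lens = [np.array([6])]
print(parse_decodes(toy_words, toy_id2label, toy_decodes, toy_lens))
# expected: ['thank you so much chris.']
```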
```python
id2label = dict(enumerate(test_ds.label_list))
raw_data = test_ds.data

model.eval()
pred_list = []
len_list = []
for step, batch in enumerate(test_data_loader):
    input_ids, token_type_ids, length, labels = batch
    logits = model(input_ids, token_type_ids)
    pred = paddle.argmax(logits, axis=-1)
    pred_list.append(pred.numpy())
    len_list.append(length.numpy())

preds = parse_decodes(raw_data, id2label, pred_list, len_list)
```
3.9.3 write to file
file_path = "results.txt" with open(file_path, "w", encoding="utf8") as fout: fout.write("\n".join(preds)) # Print some examples print("The results have been saved in the file: %s, some examples are shown below: " % file_path) print("\n".join(preds[:5]))
4. More PaddleEdu information
1. The PaddleEdu one-stop deep learning online encyclopedia awesome-DeepLearning also offers the following resources, so stay tuned:
- Introduction to deep learning
- Deep learning Q&A
- Featured courses
- Industrial practice

If you have any questions while using PaddleEdu, you are welcome to raise them in awesome-DeepLearning. For more deep learning materials, please refer to the PaddlePaddle deep learning platform.
Remember to give it a Star ⭐ and add it to your favorites~~
2. PaddleEdu technical exchange group (QQ)
At present, more than 2000 students are studying together in the QQ group. You are welcome to scan the code to join us.