Optimization of BERT as a Service

1. Service deployment

PyZMQ is the Python binding for ZeroMQ, which provides a lightweight and fast messaging implementation. A simple example of client/server messaging looks like this:

import zmq
import zmq.decorators as zmqd

@zmqd.socket(zmq.PUSH)
def send(sock):
    sock.bind('tcp://*:5555')
    sock.send(b'hello')
 
# in another process   
@zmqd.socket(zmq.PULL)
def recv(sock):
    sock.connect('tcp://localhost:5555')
    print(sock.recv())  # shows b'hello'
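
To see the round trip end to end, the two decorated functions can be driven from separate processes. A minimal sketch, assuming the send/recv definitions above:

from multiprocessing import Process

# The zmq.decorators.socket decorator creates the socket and injects it as the
# first argument, so send() and recv() are called with no arguments here.
if __name__ == '__main__':
    p = Process(target=send)   # PUSH side: binds tcp://*:5555 and sends b'hello'
    p.start()
    recv()                     # PULL side: connects and prints b'hello'
    p.join()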

2. Service acceleration

The overall approach to service acceleration is as follows:

  • Freeze: convert the dynamic graph to a static graph and turn variables into constants, i.e. tf.Variable -> tf.constant
  • Prune: delete all nodes and edges in the graph that are not needed for inference
  • Quantize: convert tf.float32 to tf.float16 or tf.uint8 (a sketch of weight quantization follows the freeze/prune code below)

TensorFlow provides APIs for freezing and pruning. You only need to define the input and output nodes, for example:
input_tensors = [input_ids, input_mask, input_type_ids]
output_tensors = [pooled]
from tensorflow.python.tools.optimize_for_inference_lib import optimize_for_inference
from tensorflow.python.framework.graph_util import convert_variables_to_constants

# get graph
tmp_g = tf.get_default_graph().as_graph_def()

sess = tf.Session()
# load parameters then freeze
sess.run(tf.global_variables_initializer())
# n.name[:-2] strips the ':0' suffix, turning tensor names into node names
tmp_g = convert_variables_to_constants(sess, tmp_g, [n.name[:-2] for n in output_tensors])

# pruning: strip nodes and operations that are not needed for inference
dtypes = [n.dtype for n in input_tensors]
tmp_g = optimize_for_inference(tmp_g, [n.name[:-2] for n in input_tensors],
    [n.name[:-2] for n in output_tensors],
    [dtype.as_datatype_enum for dtype in dtypes], False)
    
with tf.gfile.GFile('optimized.graph', 'wb') as f:
    f.write(tmp_g.SerializeToString())
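
The quantization step is not shown above. A minimal sketch using TF 1.x's graph_transforms tool and its 'quantize_weights' transform (the node names below are assumed to match the input/output tensors used earlier) could look like this:

import tensorflow as tf
from tensorflow.tools.graph_transforms import TransformGraph

# load the frozen, pruned graph produced above
graph_def = tf.GraphDef()
with tf.gfile.GFile('optimized.graph', 'rb') as f:
    graph_def.ParseFromString(f.read())

# 'quantize_weights' stores the weights as 8-bit values, shrinking the graph on disk
quantized_g = TransformGraph(graph_def,
    ['input_ids', 'input_mask', 'input_type_ids'],   # input node names (assumed)
    ['pooled'],                                      # output node name (assumed)
    ['quantize_weights'])

with tf.gfile.GFile('quantized.graph', 'wb') as f:
    f.write(quantized_g.SerializeToString())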

3. Reduce service latency

The key is to initialize the BERT model only once and keep a single long-running predict loop alive, feeding it requests through a generator:

import tensorflow as tf
from zmq.utils import jsonapi  # JSON helpers that ship with pyzmq
# convert_lst_to_features, max_seq_len, bert_model_fn and send_back come from
# the bert-as-service code base and are assumed to be defined elsewhere

def input_fn_builder(sock):
    def gen():
        while True:
            # receive request
            client_id, raw_msg = sock.recv_multipart()
            msg = jsonapi.loads(raw_msg)
            tmp_f = convert_lst_to_features(msg)
            yield {'client_id': client_id,
                   'input_ids': [f.input_ids for f in tmp_f],
                   'input_mask': [f.input_mask for f in tmp_f],
                   'input_type_ids': [f.input_type_ids for f in tmp_f]}

    def input_fn():
        return (tf.data.Dataset.from_generator(gen,
            output_types={'input_ids': tf.int32, 'input_mask': tf.int32, 'input_type_ids': tf.int32, 'client_id': tf.string},
            output_shapes={'client_id': (), 'input_ids': (None, max_seq_len), 'input_mask': (None, max_seq_len),'input_type_ids': (None, max_seq_len)})
                .prefetch(10))
    return input_fn
# initialize the BERT model once
estimator = tf.estimator.Estimator(model_fn=bert_model_fn)
# keep listening and predicting; 'sock' is the ZeroMQ socket that receives requests
for result in estimator.predict(input_fn_builder(sock), yield_single_examples=False):
    send_back(result)

If a GPU is available, the prefetch(10) call lets the next batches be prepared on the CPU while the current one is being computed, which speeds up the service by about 10%.
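
The bert_model_fn used above is assumed to be defined elsewhere. A minimal sketch of one that loads the frozen graph from section 2 and passes client_id through so results can be routed back (the output tensor name 'final_encodes:0' is illustrative):

def bert_model_fn(features, labels, mode, params):
    # load the frozen, optimized graph produced in section 2
    with tf.gfile.GFile('optimized.graph', 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
    # wire the Dataset features into the imported graph's input placeholders
    output = tf.import_graph_def(graph_def,
        input_map={k + ':0': features[k]
                   for k in ('input_ids', 'input_mask', 'input_type_ids')},
        return_elements=['final_encodes:0'])  # assumed name of the output tensor
    # pass client_id through so each prediction can be sent back to its client
    return tf.estimator.EstimatorSpec(mode=mode, predictions={
        'client_id': features['client_id'],
        'encodes': output[0]})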

4. Improve service scalability

Suppose multiple clients send requests to the server at the same time. Before any parallel computation can happen, the server has to decide how to handle reception. Should it accept the first request, hold the connection until it sends back the result, and only then take the second request? What happens when there are 100 clients? Should the server manage 100 connections with the same logic?
Consider another scenario: one client sends 10K sentences every 10 milliseconds. The server parallelizes the work into sub-tasks and assigns them to multiple GPU workers. Then another client joins and sends one sentence per second. This small-batch client should, in theory, get its result immediately. Unfortunately, because all GPU workers are busy computing and receiving data for the first client, the second client never gets a time slot until the server has finished 100 batches (10K sentences per batch) from the first client.

Scalability and load-balancing issues arise when multiple clients connect to one server. In bert-as-service this is handled by a ventilator-worker-sink pipeline built from push/pull and publish/subscribe sockets. The ventilator acts like a batch scheduler and load balancer: it splits large client requests into mini-jobs and balances their load before sending them to the workers. A worker receives mini-jobs from the ventilator, performs the actual BERT inference, and sends the results to the sink. The sink collects the outputs of all the workers' mini-jobs, checks that every request from the ventilator is complete, and publishes the full results back to the client. The overall structure is shown in the figure in the original post linked below.
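
A minimal sketch of the ventilator-worker-sink pattern in PyZMQ (the ports, message format, and encode() callback are illustrative, not the actual bert-as-service implementation):

import zmq

def ventilator(jobs):
    ctx = zmq.Context()
    sender = ctx.socket(zmq.PUSH)            # fan mini-jobs out to the workers
    sender.bind('tcp://*:5557')
    for job_id, batch in jobs:               # jobs: requests already split into mini-batches
        sender.send_json({'job_id': job_id, 'batch': batch})

def worker(encode):
    ctx = zmq.Context()
    receiver = ctx.socket(zmq.PULL)          # pull mini-jobs from the ventilator
    receiver.connect('tcp://localhost:5557')
    sender = ctx.socket(zmq.PUSH)            # push results to the sink
    sender.connect('tcp://localhost:5558')
    while True:
        msg = receiver.recv_json()
        result = encode(msg['batch'])        # the actual BERT inference step
        sender.send_json({'job_id': msg['job_id'], 'result': result})

def sink():
    ctx = zmq.Context()
    receiver = ctx.socket(zmq.PULL)          # collect results from all workers
    receiver.bind('tcp://*:5558')
    publisher = ctx.socket(zmq.PUB)          # publish finished results to clients
    publisher.bind('tcp://*:5559')
    while True:
        msg = receiver.recv_json()
        # a real sink would also verify that every mini-job of a request has arrived
        publisher.send_json(msg)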

Original link: Serving Google BERT in Production using Tensorflow and ZeroMQ
