Using the Deep Java Library (DJL) framework: AlphaPose in Java

First, here is the complete project code

Environment

  • Java 1.8
  • DJL 0.12.0
  • OpenCV Java 4.5.1 (provided through the javacv Maven dependency; manual configuration is not covered for now)

Implementation Principle

What the DJL framework can do for us

The DJL framework provides adapters for multiple inference engines, so after exporting a model in libtorch, ONNX, MXNet, or another supported format, we can easily run inference in Java.
DJL does not perform complex post-processing for us, so computations beyond inference itself may need OpenCV or DJL's built-in NDArray.

AlphaPose

AlphaPose's inference pipeline consists of two main parts:

  1. a YOLO object detector
  2. AlphaPose's own single-person pose estimation network (SPPE)

Between the two models, AlphaPose uses an affine transformation to crop and rescale each YOLO detection to the SPPE input size, and the corresponding inverse transformation to map the SPPE output back to the coordinate system of the original image.
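The crop-and-restore idea can be checked without OpenCV. The sketch below assumes the rotation-free case of `get_affine_transform` from human-pose-estimation.pytorch (the code that `scale()` later says it was adapted from): a box centred at (cx, cy) whose enlarged width is `scale` is zoomed uniformly by k = outW / scale, so that its centre lands at the centre of an outW x outH image. The class and method names are illustrative only.

```java
// Sketch only: rotation-free affine crop between detector and SPPE (no OpenCV).
// Assumes the rot = 0 case of get_affine_transform from human-pose-estimation.pytorch.
final class AffineCropSketch {

    /** 2x3 matrix mapping image coordinates to crop coordinates. */
    static double[][] forward(double cx, double cy, double scale, double outW, double outH) {
        double k = outW / scale;                  // uniform zoom factor
        return new double[][] {
            {k, 0, outW / 2 - k * cx},
            {0, k, outH / 2 - k * cy}
        };
    }

    /** 2x3 matrix mapping crop (or heatmap) coordinates back to the image. */
    static double[][] inverse(double cx, double cy, double scale, double outW, double outH) {
        double invK = scale / outW;               // undo the zoom
        return new double[][] {
            {invK, 0, cx - invK * outW / 2},
            {0, invK, cy - invK * outH / 2}
        };
    }

    /** Applies a 2x3 affine matrix to the point (x, y). */
    static double[] apply(double[][] m, double x, double y) {
        return new double[] {
            m[0][0] * x + m[0][1] * y + m[0][2],
            m[1][0] * x + m[1][1] * y + m[1][2]
        };
    }
}
```

For a box centred at (300, 200) with an enlarged width of 240 px and a 192 x 256 input, `forward` sends (250, 150) to (56, 88), and `inverse` sends (56, 88) back to (250, 150). This round trip is what the SPPE post-processing relies on when it maps keypoints back into the frame.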

Implementation Steps

1. Exporting the AlphaPose models

Exporting YOLOv5

There are many guides online for exporting YOLOv5; you can also refer to my blog post.

Exporting the single-person pose estimation network

2. Implementing the Translators

Implementing YoloTranslator

You can refer to my blog post (the same one as above 😅).

Implementing SPPETranslator

Translator is an inference template provided by DJL. By overriding its methods, we implement the data preprocessing before inference and the post-processing after it.

1. Defining SPPETranslator

public class SPPETranslator extends BasePairTranslator<Mat, Rectangle, Joints> {
...
}

BasePairTranslator<I, P, O> is a helper class I wrote myself. I is the input type, O is the output type, and P is needed because each input is a pair: the original image plus one bounding box from the YOLO output.
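`BasePairTranslator` itself is not shown in the post. As a rough, DJL-free illustration of what its three type parameters mean, here is a hypothetical analogue in which a `float[]` stands in for `NDList` and a `UnaryOperator` stands in for the model. In DJL the real class would presumably implement `Translator<Pair<I, P>, O>`; every name below is made up for illustration.

```java
import java.util.AbstractMap.SimpleEntry;
import java.util.Map;
import java.util.function.UnaryOperator;

// Hypothetical pair-translator template: I = image type, P = second input
// (one detector box), O = output. float[] is a stand-in for NDList.
abstract class PairTranslatorSketch<I, P, O> {
    abstract float[] processInput(Map.Entry<I, P> input);   // pre-processing
    abstract O processOutput(float[] modelOutput);           // post-processing

    // Pre-process, run the stand-in "model", post-process.
    O predict(I image, P box, UnaryOperator<float[]> model) {
        return processOutput(model.apply(processInput(new SimpleEntry<>(image, box))));
    }
}

// Toy subclass: image = grey pixels [row][col], box = {x, y, w, h},
// output = mean of whatever the model returns for the cropped pixels.
final class MeanInBoxTranslator extends PairTranslatorSketch<float[][], int[], Double> {
    @Override
    float[] processInput(Map.Entry<float[][], int[]> input) {
        float[][] img = input.getKey();
        int[] b = input.getValue();
        float[] crop = new float[b[2] * b[3]];
        for (int r = 0; r < b[3]; r++) {
            for (int c = 0; c < b[2]; c++) {
                crop[r * b[2] + c] = img[b[1] + r][b[0] + c];   // copy the box region
            }
        }
        return crop;
    }

    @Override
    Double processOutput(float[] out) {
        double sum = 0;
        for (float v : out) sum += v;
        return sum / out.length;
    }
}
```

The point of the pair type is that the image and the box travel together through one `predict` call, so pre-processing can crop with the box and post-processing never needs global state.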

2. Data Preprocessing

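Here, ctx.getAttachment / ctx.setAttachment is used to store the box produced by the affine transformation (its new size and position), which is later used to map the predicted keypoints back to the original image coordinates.
cropped_bboxes is kept in a queue so that dynamic batching works: when several boxes are processed as one batch, processOutput takes the boxes out in the same order in which processInput added them.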
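Why a queue rather than a single attachment: the sketch below assumes, as described above for dynamic batching, that one context is shared by a whole batch, `processInput` runs for every item first, and `processOutput` then runs for the outputs in the same order. A `HashMap` is a hypothetical stand-in for the `TranslatorContext` attachments.

```java
import java.util.ArrayList;
import java.util.HashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Queue;

// DJL-free sketch: a FIFO queue in a shared context pairs each output with its box.
final class BatchAttachmentSketch {

    static void processInput(Map<String, Object> ctx, String box) {
        @SuppressWarnings("unchecked")
        Queue<String> boxes = (Queue<String>) ctx.get("cropped_bboxes");
        if (boxes == null) {
            boxes = new LinkedList<>();
            ctx.put("cropped_bboxes", boxes);
        }
        boxes.add(box);                       // remember this item's box
    }

    static String processOutput(Map<String, Object> ctx, int modelOutput) {
        Queue<?> boxes = (Queue<?>) ctx.get("cropped_bboxes");
        Object box = boxes.poll();            // oldest box belongs to this output
        return box + " -> " + modelOutput;
    }

    /** Simulates one batch: all inputs first, then all outputs, in order. */
    static List<String> runBatch(List<String> boxes) {
        Map<String, Object> ctx = new HashMap<>();
        for (String b : boxes) processInput(ctx, b);
        List<String> results = new ArrayList<>();
        for (int i = 0; i < boxes.size(); i++) results.add(processOutput(ctx, i));
        return results;
    }
}
```

With a single attachment instead of a queue, every output in the batch would see only the last box that was stored.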

    @Override
    public NDList processInput(TranslatorContext ctx, Pair<Mat, Rectangle> input) throws Exception {
        // clone: CVUtils.scale warps the image in place
        Mat frame = input.getKey().clone();
        Rectangle bbox = input.getValue();
        // Clamp the YOLO box to the image; right/bottom edges come from the original box
        int x = (int) Math.max(0, bbox.getX());
        int y = (int) Math.max(0, bbox.getY());
        int w = (int) Math.min(frame.width(), bbox.getX() + bbox.getWidth()) - x;
        int h = (int) Math.min(frame.height(), bbox.getY() + bbox.getHeight()) - y;
        // Affine-crop the person to the SPPE input size
        // (inputH / inputW: assumed translator fields holding that size)
        Rectangle croppedBBox = CVUtils.scale(frame, x, y, w, h, inputH, inputW);

        // One FIFO queue per context, so processOutput can pair outputs with boxes
        @SuppressWarnings("unchecked")
        Queue<Rectangle> croppedBBoxes = (Queue<Rectangle>) ctx.getAttachment("cropped_bboxes");
        if (croppedBBoxes == null) {
            croppedBBoxes = new LinkedList<>();
            ctx.setAttachment("cropped_bboxes", croppedBBoxes);
        }
        croppedBBoxes.add(croppedBBox);

        NDArray array = ImageUtils.mat2Image(frame).toNDArray(ctx.getNDManager(), Image.Flag.COLOR);

        return pipeline.transform(new NDList(array));
    }
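The clamping at the top of `processInput` can be isolated as a small pure function. This hypothetical helper keeps the same integer truncation and computes the right and bottom edges from the detector's original box, so a box that starts left of (or above) the frame does not grow after clamping.

```java
// Hypothetical helper: clamp a detector box (bx, by, bw, bh) to a frameW x frameH image.
// Returns {x, y, w, h} in whole pixels, truncated like processInput does.
final class BoxClampSketch {
    static int[] clamp(double bx, double by, double bw, double bh, int frameW, int frameH) {
        int x = (int) Math.max(0, bx);                  // left edge inside the frame
        int y = (int) Math.max(0, by);                  // top edge inside the frame
        int w = (int) Math.min(frameW, bx + bw) - x;    // right edge from the original box
        int h = (int) Math.min(frameH, by + bh) - y;    // bottom edge from the original box
        return new int[] {x, y, w, h};
    }
}
```

For a 640 x 480 frame, a box at (-10, 20) of size 50 x 100 becomes (0, 20, 40, 100): only the 40 px that are actually inside the frame remain.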

The key to this step is calling OpenCV from Java to apply an affine transformation to the image.

    /**
     * Convert box coordinates to center and scale, then warp {@code mat} in place
     * so that the enlarged box fills an inputW x inputH image.
     * adapted from https://github.com/Microsoft/human-pose-estimation.pytorch
     *
     * @param mat    image to transform (modified in place)
     * @param x      left edge of the box
     * @param y      top edge of the box
     * @param w      box width
     * @param h      box height
     * @param inputH SPPE input height
     * @param inputW SPPE input width
     * @return the enlarged box, in original image coordinates, that was mapped to the input
     */
    public static Rectangle scale(Mat mat,
                                  double x, double y, double w, double h, double inputH, double inputW) {
        double aspectRatio = inputW / inputH;

        // enlarge the box by 25% so the whole person fits
        double scaleMult = 1.25;
        // box_to_center_scale
        double centerX = x + 0.5 * w;
        double centerY = y + 0.5 * h;

        // pad the shorter side so the box matches the input aspect ratio
        if (w > aspectRatio * h)
            h = w / aspectRatio;
        else if (w < aspectRatio * h)
            w = h * aspectRatio;
        double scaleX = w * scaleMult;
        double scaleY = h * scaleMult;

        // Get the affine matrix (no rotation) mapping the box onto the input size
        Mat trans = getAffineTransform(centerX, centerY, inputW, inputH, scaleX, false);
        // Apply the affine transformation
        Imgproc.warpAffine(mat, mat, trans, new Size(inputW, inputH), Imgproc.INTER_LINEAR);

        return new Rectangle(centerX - scaleX * 0.5, centerY - scaleY * 0.5, scaleX, scaleY);
    }
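The box-to-center-scale arithmetic inside `scale()` does not need OpenCV, so it can be sketched and checked on its own. This hypothetical helper mirrors those steps (aspect-ratio padding, then the 1.25 margin) and returns `{centerX, centerY, scaleX, scaleY}`; the warp itself is left out.

```java
// Hypothetical helper mirroring the box_to_center_scale part of scale() (no OpenCV).
// Returns {centerX, centerY, scaleX, scaleY}; parameter order follows scale().
final class CenterScaleSketch {
    static double[] boxToCenterScale(double x, double y, double w, double h,
                                     double inputH, double inputW) {
        double aspectRatio = inputW / inputH;
        double centerX = x + 0.5 * w;
        double centerY = y + 0.5 * h;
        // pad the shorter side so the box has the SPPE input's aspect ratio
        if (w > aspectRatio * h) {
            h = w / aspectRatio;
        } else if (w < aspectRatio * h) {
            w = h * aspectRatio;
        }
        double scaleMult = 1.25;   // 25% margin around the person, as in scale()
        return new double[] {centerX, centerY, w * scaleMult, h * scaleMult};
    }
}
```

For a 192 x 256 (width x height) input, a tall 60 x 200 box is widened to 150 x 200 before the margin, giving a 187.5 x 250 region around the same centre. Padding instead of stretching keeps the person's proportions intact after the warp.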

The effect of the transformation is roughly the same as the following

3. Post-processing after inference

This part is essentially a port of the AlphaPose source code. The vector computations use DJL's NDArray, which is faster than a for loop even on the CPU.

    @Override
    public Joints processOutput(TranslatorContext ctx, NDList list) {
        // axis2 = {2}, axis2n3 = {2, 3}, axis2n4 = {2, 4} and ONES_NDARRAY (ones of shape
        // (1, numJoints, 1)) are class constants that are not shown here.
        NDArray pred = list.singletonOrThrow().toDevice(Device.cpu(), false);
        int numJoints = (int) pred.getShape().get(0);
        int height = (int) pred.getShape().get(1);
        int width = (int) pred.getShape().get(2);
        pred = Activation.sigmoid(pred.reshape(new Shape(1, numJoints, -1)));
        // peak value of each heatmap = joint confidence
        NDArray maxValues = pred.max(axis2, true).toType(DataType.FLOAT32, false);
        // normalize each heatmap to a probability distribution
        NDArray heatmaps = pred
                .div(pred.sum(axis2, true))
                .reshape(1, numJoints, 1, height, width);

        // marginal distributions along x and y
        NDArray hmX = heatmaps.sum(axis2n3);
        NDArray hmY = heatmaps.sum(axis2n4);

        NDManager ndManager = NumpyUtils.ndManager;

        // expected coordinate = sum_i i * p(i)
        hmX = integralOp(hmX, ndManager);
        hmY = integralOp(hmY, ndManager);

        NDArray coordX = hmX.sum(axis2, true);
        NDArray coordY = hmY.sum(axis2, true);

        NDArray predJoints = coordX
                .concat(coordY, 2)
                .reshape(1, numJoints, 2)
                .toType(DataType.FLOAT32, false);

        // boxes come back in the same order processInput queued them
        Rectangle bbox = (Rectangle) ((Queue<?>) ctx.getAttachment("cropped_bboxes")).poll();
        double x = bbox.getX();
        double y = bbox.getY();
        double w = bbox.getWidth();
        double h = bbox.getHeight();
        double centerX = x + 0.5 * w, centerY = y + 0.5 * h;
        double scaleX = w;

        // inverse affine transform: heatmap coordinates -> original image coordinates
        Mat trans = CVUtils.getAffineTransform(centerX, centerY, width, height, scaleX, true);
        NDArray ndTrans = CVUtils.transMat2NDArray(trans, ndManager).transpose(1, 0);

        // homogeneous coordinates [x, y, 1], so one matMul applies the 2x3 matrix
        predJoints = predJoints
                .concat(ONES_NDARRAY, 2);

        NDArray xys = predJoints.matMul(ndTrans);

        float[] flattened = xys.toFloatArray();
        float[] flattenedConfidence = maxValues.toFloatArray();

        List<Joint> joints = new ArrayList<>(numJoints);
        for (int i = 0; i < numJoints; ++i) {
            joints.add(new Joint(
                    flattened[2 * i],
                    flattened[2 * i + 1],
                    flattenedConfidence[i]));
        }
        return new Joints(joints);
    }
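The last step of `processOutput` maps every joint with a single `matMul`: each (x, y) is extended to the homogeneous row [x, y, 1] and multiplied by the transposed 2 x 3 affine matrix, which is the same as applying the matrix to the column vector. The plain-Java version below spells that product out with loops for readability only; in the translator it stays a single NDArray operation. The class name is hypothetical.

```java
// Plain-Java sketch: [x y 1] * T^T for every joint, i.e. T applied to [x y 1]^T.
final class HomogeneousMapSketch {
    static double[][] mapJoints(double[][] joints, double[][] t) {
        int n = joints.length;
        double[][] out = new double[n][2];
        for (int i = 0; i < n; i++) {
            double[] row = {joints[i][0], joints[i][1], 1.0};   // homogeneous coordinates
            for (int j = 0; j < 2; j++) {                       // column j of T^T = row j of T
                double s = 0;
                for (int k = 0; k < 3; k++) s += row[k] * t[j][k];
                out[i][j] = s;
            }
        }
        return out;
    }
}
```

For example, with an inverse matrix that zooms by 5 and translates by (180, 40), the centre (24, 32) of a 48 x 64 heatmap maps to (300, 200) and the corner (0, 0) maps to (180, 40).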

integralOp

    private static NDArray integralOp(NDArray hm, NDManager ndManager) {
        Shape hmShape = hm.getShape();
        NDArray arr = ndManager
                .arange(hmShape.get(hmShape
                        .dimension() - 1)).toType(DataType.FLOAT32, false);
        return hm.mul(arr);
    }
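Together, `processOutput` and `integralOp` implement integral regression (a soft-argmax): the sigmoid heatmap is normalized into a probability distribution, each coordinate is the expected pixel index under it, and the confidence is the heatmap's peak. Below is a plain-Java sketch of the same computation for a single joint, written with loops only to make the arithmetic visible; the class name is hypothetical.

```java
// Plain-Java sketch of integral (soft-argmax) decoding for one joint's heatmap:
// sigmoid -> normalize to a distribution -> expected column / row index.
final class SoftArgmaxSketch {
    /** Returns {x, y, confidence} for a heatmap of raw logits [height][width]. */
    static double[] decode(double[][] logits) {
        int height = logits.length, width = logits[0].length;
        double[][] p = new double[height][width];
        double total = 0, maxVal = 0;
        for (int r = 0; r < height; r++) {
            for (int c = 0; c < width; c++) {
                p[r][c] = 1.0 / (1.0 + Math.exp(-logits[r][c]));  // sigmoid
                total += p[r][c];
                maxVal = Math.max(maxVal, p[r][c]);               // confidence = peak value
            }
        }
        double x = 0, y = 0;
        for (int r = 0; r < height; r++) {
            for (int c = 0; c < width; c++) {
                double prob = p[r][c] / total;   // normalized probability
                x += c * prob;                   // expected column index
                y += r * prob;                   // expected row index
            }
        }
        return new double[] {x, y, maxVal};
    }
}
```

Summing `c * prob` over the whole grid gives the same value as first taking the x marginal (`hmX`) and then weighting it by `arange(width)`, which is what the NDArray version does. Unlike a hard argmax, the result can fall between pixels.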

3. Combining the models

This step feeds an image to the YOLO model, passes the YOLO output to the SPPE, and does a simple format conversion in between.

    static void detect(Mat frame,
                       YoloV5Detector detector,
                       ParallelPoseEstimator ppe) throws IOException, ModelNotFoundException, MalformedModelException, TranslateException {
        Image img = ImageUtils.mat2Image(frame);
        long startTime = System.currentTimeMillis();

        // 1. Object detection with YOLOv5
        DetectedObjects results = detector.detect(img);
        List<DetectedObject> detectedObjects = new ArrayList<>(results.getNumberOfObjects());
        List<Rectangle> jointsInput = new ArrayList<>(results.getNumberOfObjects());
        for (DetectedObject obj : results.<DetectedObject>items()) {
            // only "person" boxes are passed on to the SPPE
            if ("person".equals(obj.getClassName())) {
                detectedObjects.add(obj);
                jointsInput.add(obj.getBoundingBox().getBounds());
            }
        }
        // 2. Pose estimation for every person box
        List<Joints> joints = ppe.infer(frame, jointsInput);

        // 3. Drawing (color is a Scalar constant defined elsewhere)
        for (DetectedObject obj : detectedObjects) {
            Rectangle rectangle = obj.getBoundingBox().getBounds();
            String showText = String.format("%s: %.2f", obj.getClassName(), obj.getProbability());
            Rect rect = new Rect((int) rectangle.getX(), (int) rectangle.getY(),
                    (int) rectangle.getWidth(), (int) rectangle.getHeight());
            // bounding box
            Imgproc.rectangle(frame, rect, color, 2);
            // class name and score
            Imgproc.putText(frame, showText,
                    new Point(rect.x, rect.y),
                    Imgproc.FONT_HERSHEY_COMPLEX,
                    rectangle.getWidth() / 200,
                    color);
        }
        for (Joints jointsItem : joints) {
            CVUtils.draw136KeypointsLight(frame, jointsItem);
        }

        boolean showFPS = true;
        if (showFPS) {
            double fps = 1000.0 / (System.currentTimeMillis() - startTime);
            System.out.printf("%.2f%n", fps);
            Imgproc.putText(frame, String.format("FPS: %.2f", fps),
                    new Point(0, 52),
                    Imgproc.FONT_HERSHEY_COMPLEX,
                    0.5,
                    ColorConst.COLOR_RED);
        }
    }

Results

Demonstration

Because the SPPE model is fairly lightweight, some keypoints are not very accurate, but the results are good enough.
In addition, the project implements simple pipelining similar to AlphaPose's; see the project linked at the beginning.
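The project's pipelining is not reproduced here. As one possible JDK-only way to overlap the two stages, the sketch below runs detection and pose estimation on two single-thread executors, so the detector can start on frame N+1 while frame N is still in pose estimation. `detect` and `estimate` are hypothetical stand-ins for the YOLO and SPPE calls.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.function.Function;

// Two-stage pipeline sketch: one thread per stage, results kept in frame order.
final class TwoStagePipelineSketch {
    static <F, D, R> List<R> run(List<F> frames, Function<F, D> detect, Function<D, R> estimate)
            throws InterruptedException, ExecutionException {
        ExecutorService detectStage = Executors.newSingleThreadExecutor();
        ExecutorService poseStage = Executors.newSingleThreadExecutor();
        try {
            List<Future<R>> pending = new ArrayList<>();
            for (F frame : frames) {
                Future<D> detections = detectStage.submit(() -> detect.apply(frame));
                // pose for frame N waits only on frame N's detections,
                // while the detector thread moves on to frame N+1
                pending.add(poseStage.submit(() -> estimate.apply(detections.get())));
            }
            List<R> results = new ArrayList<>(pending.size());
            for (Future<R> f : pending) results.add(f.get());   // propagates stage errors
            return results;
        } finally {
            detectStage.shutdown();
            poseStage.shutdown();
        }
    }
}
```

Results come back in frame order because each pose task waits only on its own frame's detection future. A real video loop would also bound the number of in-flight frames, which this sketch omits.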

Performance

I ran a performance test: pure inference speed is essentially the same as in Python. For the intermediate data processing, avoid for loops wherever possible and use the native methods of NDArray or OpenCV instead; otherwise performance may fall behind Python.
After some optimization, the SPPE stage above now runs 1.6 times faster than in Python.


Posted on Fri, 05 Nov 2021 12:39:10 -0400 by elenev