First, here is the complete project code
Implementation environment
- java 1.8
- djl 0.12.0
- opencv-java 4.5.1 (pulled in via the javacv Maven artifact; manually configuring the native libraries is not covered here)
Implementation Principle
What the DJL framework can do for us
The DJL framework provides Java adapters for multiple inference engines, so after exporting a model to TorchScript (libtorch), ONNX, MXNet, or another supported format, we can easily run inference from Java.
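For reference, loading an exported model through DJL's engine-agnostic Criteria API looks roughly like this. This is a minimal sketch: the model path, engine name, and translator arguments are placeholders, not this project's actual configuration.

```java
import java.nio.file.Paths;

import ai.djl.inference.Predictor;
import ai.djl.modality.cv.Image;
import ai.djl.modality.cv.output.DetectedObjects;
import ai.djl.repository.zoo.Criteria;
import ai.djl.repository.zoo.ModelZoo;
import ai.djl.repository.zoo.ZooModel;

// Describe the model: input/output types, file location, and engine
Criteria<Image, DetectedObjects> criteria = Criteria.builder()
        .setTypes(Image.class, DetectedObjects.class)
        .optModelPath(Paths.get("yolov5s.torchscript.pt")) // placeholder path
        .optEngine("PyTorch") // or "OnnxRuntime", "MXNet", ... to match the export
        .optTranslator(new YoloTranslator(/* project-specific args */))
        .build();

// Load the model and run one prediction
try (ZooModel<Image, DetectedObjects> model = ModelZoo.loadModel(criteria);
     Predictor<Image, DetectedObjects> predictor = model.newPredictor()) {
    DetectedObjects results = predictor.predict(img);
}
```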
DJL cannot do the complex pre- and post-processing for us, however, so we still need the OpenCV framework or DJL's built-in NDArray objects to handle the computation that happens outside of inference.
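As a tiny example of what that looks like, here is heatmap normalization done with NDArray operations instead of nested Java loops (the shapes are illustrative only):

```java
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.Shape;

try (NDManager manager = NDManager.newBaseManager()) {
    // Fake 17-joint, 64x48 heatmaps, just for illustration
    NDArray heat = manager.randomUniform(0f, 1f, new Shape(17, 64, 48));
    // Normalize each joint's map into a probability distribution in one call
    NDArray probs = heat.div(heat.sum(new int[] {1, 2}, true));
    System.out.println(probs.sum(new int[] {1, 2})); // each map now sums to ~1
}
```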
Alphapose
Alphapose's inference model consists of two main parts
- yolo target detector
- alphapose's own single-person pose estimation network (SPPE)
Between the two models, alphapose applies an affine transformation to crop and scale the yolo detections to the SPPE input size, and the corresponding inverse transform to map the SPPE output back into the original image's coordinate system.
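In OpenCV terms, both directions can be built from the same three point correspondences. A minimal sketch follows; the point choices and sizes are illustrative, not AlphaPose's exact construction:

```java
import org.opencv.core.Mat;
import org.opencv.core.MatOfPoint2f;
import org.opencv.core.Point;
import org.opencv.imgproc.Imgproc;

double cx = 320, cy = 240, sx = 160, sy = 213; // example box center and scale
double inputW = 192, inputH = 256;             // common AlphaPose input size

// Map three corners of the scaled box onto the SPPE input rectangle
MatOfPoint2f src = new MatOfPoint2f(
        new Point(cx - sx / 2, cy - sy / 2),
        new Point(cx + sx / 2, cy - sy / 2),
        new Point(cx - sx / 2, cy + sy / 2));
MatOfPoint2f dst = new MatOfPoint2f(
        new Point(0, 0),
        new Point(inputW - 1, 0),
        new Point(0, inputH - 1));

Mat forward = Imgproc.getAffineTransform(src, dst); // image crop -> network input
Mat inverse = new Mat();
Imgproc.invertAffineTransform(forward, inverse);    // network output -> image coords
```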
Implementation Steps
1. Exporting the alphapose model
Export yolov5
There are many guides online for exporting yolov5; you can refer to my blog.
Exporting the single-person pose estimation network
2. Implementing Translator
Using YoloTranslator
You can refer to my blog (still that same post 😅).
Implement SPPETranslator
Translator is an inference template that DJL provides for us; we override its methods to implement the data preprocessing and post-processing that happen outside of inference.
1. We define SPPETranslator
```java
public class SPPETranslator extends BasePairTranslator<Mat, Rectangle, Joints> { ... }
```
BasePairTranslator<I, P, O> is a class I wrote myself: I is the input type, O is the output type, and P exists because our input has two parts, the original image plus the bounding box produced by yolo.
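The class itself is not shown in this post; a hypothetical sketch of its shape (the real implementation may differ) is:

```java
import ai.djl.translate.Batchifier;
import ai.djl.translate.Pipeline;
import ai.djl.translate.Translator;
import ai.djl.util.Pair;

// A DJL Translator whose input is a (payload, parameter) pair
public abstract class BasePairTranslator<I, P, O> implements Translator<Pair<I, P>, O> {

    // Shared preprocessing transforms, applied at the end of processInput
    protected Pipeline pipeline = new Pipeline();

    @Override
    public Batchifier getBatchifier() {
        return Batchifier.STACK; // stack single inputs into a batch
    }
}
```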
2. Data Preprocessing
Here, ctx.setAttachment / ctx.getAttachment is used to save the new bounding box produced by the affine transformation, which is later used to restore the coordinates of the output.
cropped_bboxes is a queue so that dynamic-batch inference can be supported: boxes are enqueued during preprocessing and dequeued in the same order during post-processing.
```java
@Override
public NDList processInput(TranslatorContext ctx, Pair<Mat, Rectangle> input) throws Exception {
    Mat frame = input.getKey().clone();
    Rectangle bbox = input.getValue();

    // Clamp the yolo box to the image bounds
    int x = (int) Math.max(0, bbox.getX());
    int y = (int) Math.max(0, bbox.getY());
    int w = Math.min(frame.width(), (int) (x + bbox.getWidth())) - x;
    int h = Math.min(frame.height(), (int) (y + bbox.getHeight())) - y;

    // Warp the person crop to the network input size; the returned rectangle
    // is the cropped box in original-image coordinates
    Rectangle croppedBBox = CVUtils.scale(frame, x, y, w, h);

    // Remember the box (FIFO) so processOutput can invert the transform
    Queue<Rectangle> cropped_bboxes = (Queue<Rectangle>) ctx.getAttachment("cropped_bboxes");
    if (cropped_bboxes == null) {
        cropped_bboxes = new LinkedList<>();
        ctx.setAttachment("cropped_bboxes", cropped_bboxes);
    }
    cropped_bboxes.add(croppedBBox);

    NDArray array = ImageUtils.mat2Image(frame).toNDArray(ctx.getNDManager(), Image.Flag.COLOR);
    return pipeline.transform(new NDList(array));
}
```
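The pipeline at the end of processInput holds the usual tensor transforms. A hypothetical setup is shown below; ToTensor converts HWC uint8 to CHW float in [0, 1], and the Normalize mean/std values are the standard ImageNet constants, which is an assumption rather than this project's verified config:

```java
import ai.djl.modality.cv.transform.Normalize;
import ai.djl.modality.cv.transform.ToTensor;
import ai.djl.translate.Pipeline;

Pipeline pipeline = new Pipeline();
pipeline.add(new ToTensor()); // HWC uint8 -> CHW float32 in [0, 1]
// Assumed ImageNet normalization constants
pipeline.add(new Normalize(
        new float[] {0.485f, 0.456f, 0.406f},
        new float[] {0.229f, 0.224f, 0.225f}));
```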
The key to this step is calling OpenCV from Java to apply an affine transformation to the image.
```java
/**
 * Convert box coordinates to center and scale, then warp the person crop
 * to the network input size.
 * Adapted from https://github.com/Microsoft/human-pose-estimation.pytorch
 *
 * @return the cropped box in original-image coordinates
 */
public static Rectangle scale(Mat mat, double x, double y, double w, double h,
                              double inputH, double inputW) {
    double aspectRatio = inputW / inputH;
    double scaleMult = 1.25;

    // box_to_center_scale
    double centerX = x + 0.5 * w;
    double centerY = y + 0.5 * h;
    if (w > aspectRatio * h) {
        h = w / aspectRatio;
    } else if (w < aspectRatio * h) {
        w = h * aspectRatio;
    }
    double scaleX = w * scaleMult;
    double scaleY = h * scaleMult;

    // Get the affine matrix and warp the crop to the network input size
    Mat trans = getAffineTransform(centerX, centerY, inputW, inputH, scaleX, false);
    Imgproc.warpAffine(mat, mat, trans, new Size(inputW, inputH), Imgproc.INTER_LINEAR);

    return new Rectangle(centerX - scaleX * 0.5, centerY - scaleY * 0.5, scaleX, scaleY);
}
```
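For example, warping one detected person box to a 256×192 input (a common AlphaPose input size; the exported model's size may differ):

```java
// frame is the full image; (x, y, w, h) is the clamped yolo box
Rectangle cropped = CVUtils.scale(frame, x, y, w, h, 256, 192);
// frame is now warped in place to the network input size;
// `cropped` is kept so the transform can be inverted later
```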
The effect of the transformation looks roughly like the image below.
3. Data processing after inference
This part is rewritten by following the alphapose source code. The vector computation can use the NDArray encapsulated by the DJL framework, which is faster than a for loop even on the CPU.
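Concretely, the code below implements an integral (soft-argmax) regression: the heatmap is normalized into a probability distribution, marginals are taken along each axis, and the predicted coordinate is the expectation

$$\hat{x} = \sum_{i=0}^{W-1} i \, p_x(i), \qquad \hat{y} = \sum_{j=0}^{H-1} j \, p_y(j)$$

where $p_x$ and $p_y$ are the marginal distributions of the normalized heatmap; the `mul` in integralOp plus the following `sum` compute exactly these expectations.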
```java
@Override
public Joints processOutput(TranslatorContext ctx, NDList list) {
    NDArray pred = list.singletonOrThrow().toDevice(Device.cpu(), false);
    int numJoints = (int) pred.getShape().get(0);
    int height = (int) pred.getShape().get(1);
    int width = (int) pred.getShape().get(2);

    // axis2 = {2}, axis2n3 = {2, 3}, axis2n4 = {2, 4}: cached axis arrays on
    // the class; ONES_NDARRAY is a cached column of ones used further down
    pred = Activation.sigmoid(pred.reshape(new Shape(1, numJoints, -1)));
    NDArray maxValues = pred.max(axis2, true).toType(DataType.FLOAT32, false);

    // Normalize each heatmap to a probability distribution
    NDArray heatmaps = pred
            .div(pred.sum(axis2, true))
            .reshape(1, numJoints, 1, height, width);

    // Marginal probabilities along each axis
    NDArray hmX = heatmaps.sum(axis2n3);
    NDArray hmY = heatmaps.sum(axis2n4);

    NDManager ndManager = NumpyUtils.ndManager;
    hmX = integralOp(hmX, ndManager);
    hmY = integralOp(hmY, ndManager);

    // Expected coordinates (soft-argmax)
    NDArray coordX = hmX.sum(axis2, true);
    NDArray coordY = hmY.sum(axis2, true);
    NDArray predJoints = coordX
            .concat(coordY, 2)
            .reshape(1, numJoints, 2)
            .toType(DataType.FLOAT32, false);

    // Pop the cropped box saved in processInput (same FIFO order)
    Rectangle bbox = (Rectangle) ((Queue<?>) ctx.getAttachment("cropped_bboxes")).poll();
    double x = bbox.getX();
    double y = bbox.getY();
    double w = bbox.getWidth();
    double h = bbox.getHeight();
    double centerX = x + 0.5 * w;
    double centerY = y + 0.5 * h;
    double scaleX = w;

    // Inverse affine transform: heatmap coordinates -> original image coordinates
    Mat trans = CVUtils.getAffineTransform(centerX, centerY, width, height, scaleX, true);
    NDArray ndTrans = CVUtils.transMat2NDArray(trans, ndManager).transpose(1, 0);

    // Homogeneous coordinates: append ones, then apply the 2x3 transform
    predJoints = predJoints.concat(ONES_NDARRAY, 2);
    NDArray xys = predJoints.matMul(ndTrans);

    float[] flattened = xys.toFloatArray();
    float[] flattenedConfidence = maxValues.toFloatArray();
    List<Joint> joints = new ArrayList<>(numJoints);
    for (int i = 0; i < numJoints; ++i) {
        joints.add(new Joint(flattened[2 * i], flattened[2 * i + 1], flattenedConfidence[i]));
    }
    return new Joints(joints);
}
```
integralOp
```java
private static NDArray integralOp(NDArray hm, NDManager ndManager) {
    Shape hmShape = hm.getShape();
    // Multiply each probability by its index along the last axis; summing
    // the result afterwards yields the expected coordinate
    NDArray arr = ndManager
            .arange(hmShape.get(hmShape.dimension() - 1))
            .toType(DataType.FLOAT32, false);
    return hm.mul(arr);
}
```
3. Combining the Models
This step essentially feeds an image to the yolo model, then feeds yolo's output to the SPPE model, with a simple format conversion in between.
```java
static void detect(Mat frame, YoloV5Detector detector, ParallelPoseEstimator ppe)
        throws IOException, ModelNotFoundException, MalformedModelException, TranslateException {
    Image img = ImageUtils.mat2Image(frame);
    long startTime = System.currentTimeMillis();

    DetectedObjects results = detector.detect(img);
    List<DetectedObject> detectedObjects = new ArrayList<>(results.getNumberOfObjects());
    List<Rectangle> jointsInput = new ArrayList<>(results.getNumberOfObjects());

    // Keep only "person" detections; their boxes become the SPPE inputs
    for (DetectedObject obj : results.<DetectedObject>items()) {
        if ("person".equals(obj.getClassName())) {
            detectedObjects.add(obj);
            jointsInput.add(obj.getBoundingBox().getBounds());
        }
    }

    List<Joints> joints = ppe.infer(frame, jointsInput);

    for (DetectedObject obj : detectedObjects) {
        Rectangle rectangle = obj.getBoundingBox().getBounds();
        String showText = String.format("%s: %.2f", obj.getClassName(), obj.getProbability());
        // rect (org.opencv.core.Rect) and color are reusable fields of the class
        rect.x = (int) rectangle.getX();
        rect.y = (int) rectangle.getY();
        rect.width = (int) rectangle.getWidth();
        rect.height = (int) rectangle.getHeight();
        // Draw the bounding box
        Imgproc.rectangle(frame, rect, color, 2);
        // Draw the class label
        Imgproc.putText(frame, showText, new Point(rect.x, rect.y),
                Imgproc.FONT_HERSHEY_COMPLEX, rectangle.getWidth() / 200, color);
    }

    for (Joints jointsItem : joints) {
        CVUtils.draw136KeypointsLight(frame, jointsItem);
    }

    boolean showFPS = true;
    if (showFPS) {
        double fps = 1000.0 / (System.currentTimeMillis() - startTime);
        System.out.println(String.format("%.2f", fps));
        Imgproc.putText(frame, String.format("FPS: %.2f", fps), new Point(0, 52),
                Imgproc.FONT_HERSHEY_COMPLEX, 0.5, ColorConst.COLOR_RED);
    }
}
```
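A hypothetical driver loop that feeds webcam frames through detect(); constructing the YoloV5Detector and ParallelPoseEstimator is project-specific and omitted here:

```java
import org.opencv.core.Mat;
import org.opencv.highgui.HighGui;
import org.opencv.videoio.VideoCapture;

VideoCapture cap = new VideoCapture(0); // default webcam
Mat frame = new Mat();
while (cap.read(frame)) {
    detect(frame, detector, ppe);       // draw boxes, labels, and keypoints
    HighGui.imshow("alphapose-java", frame);
    if (HighGui.waitKey(1) == 27) {     // Esc to quit
        break;
    }
}
cap.release();
```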
Implementation results
Demo
Because the SPPE model used is fairly lightweight, some keypoints are not especially accurate, but it is good enough.
In addition, I also implemented simple pipelining here, similar to alphapose's; you can find it in the project mentioned at the beginning.
Performance description
I ran a performance test. Pure inference speed in this framework is basically the same as in Python. For intermediate data processing, use for loops as little as possible and prefer the native methods of NDArray or OpenCV; otherwise performance may end up worse than Python's.
With some optimization, the SPPE stage above now runs about 1.6 times faster than the Python version.