经验首页 前端设计 程序设计 Java相关 移动开发 数据库/运维 软件/图像 大数据/云计算 其他经验
当前位置:技术经验 » Java相关 » Java » 查看文章
Java 调用 PaddleDetection 模型
来源:cnblogs  作者:hligy  时间:2023/2/20 15:17:32  对本文有异议

文章地址

介绍

训练好的模型要给业务调用,deepjavalibrary/djl:Java 中与引擎无关的深度学习框架 (github.com) 可以完成这件事,它支持使用 Java 调用 PyTorch、TensorFlow、MXNet、ONNX、PaddlePaddle 等引擎的模型(也支持部分引擎的模型构建和训练),本文只介绍调用 PaddlePaddle 引擎的模型调用。

调用模型流程:

  1. 导出模型(我更喜欢 ONNX 格式,它在 CPU 上推理也挺快的,可以快速测试,但有的算子不支持导出),确认模型输入输出
  2. 编写 Java 加载模型以及处理输入输出的代码

PaddleDetection 模型导出

导出模型

Anaconda 配置一个 PaddleDetection 的环境,cpu 版本即可(paddlepaddle==2.2.2),下载 PaddleDetection 工程,修改工程中 configs/runtime.yml 的属性 use_gpufalse

下面以 configs/pphuman/pedestrian_yolov3/pedestrian_yolov3_darknet.yml 为例介绍整个流程,导出模型:

  1. $ python tools/export_model.py -c configs/pphuman/pedestrian_yolov3/pedestrian_yolov3_darknet.yml -o weights=https://paddledet.bj.bcebos.com/models/pedestrian_yolov3_darknet.pdparams --output_dir pedestrian_yolov3_darknet

再转换为 ONNX:

  1. $ paddle2onnx --model_dir pedestrian_yolov3_darknet/pedestrian_yolov3_darknet --model_filename model.pdmodel --params_filename model.pdiparams --opset_version 11 --save_file pedestrianYolov3.onnx --enable_onnx_checker True

确认输入输出

PaddleDetection 模型导出教程 中查看模型输入输出参数,再通过 Netorn 打开前面导出的 ONNX 模型详细确认

image-20230217150543704

Java 读取模型及推理

依赖

  1. <dependencies>
  2. <dependency>
  3. <groupId>ai.djl</groupId>
  4. <artifactId>api</artifactId>
  5. </dependency>
  6. <!--混合引擎,因为有的引擎 NDArray 不支持-->
  7. <dependency>
  8. <groupId>ai.djl.mxnet</groupId>
  9. <artifactId>mxnet-engine</artifactId>
  10. </dependency>
  11. <dependency>
  12. <groupId>ai.djl.onnxruntime</groupId>
  13. <artifactId>onnxruntime-engine</artifactId>
  14. </dependency>
  15. <dependency>
  16. <groupId>ai.djl</groupId>
  17. <artifactId>model-zoo</artifactId>
  18. </dependency>
  19. <!--使用 openpnp 的 opencv 加快图片读取-->
  20. <dependency>
  21. <groupId>ai.djl.opencv</groupId>
  22. <artifactId>opencv</artifactId>
  23. </dependency>
  24. </dependencies>
  25. <dependencyManagement>
  26. <dependencies>
  27. <dependency>
  28. <groupId>ai.djl</groupId>
  29. <artifactId>bom</artifactId>
  30. <version>0.20.0</version>
  31. <type>pom</type>
  32. <scope>import</scope>
  33. </dependency>
  34. </dependencies>
  35. </dependencyManagement>

处理输入输出

确定输入参数为图片原形状 im_shape、图片(需要归一化)image、比例 scale_factor,输出为预测框和预测数量,参数详细说明见前面提到的 PaddleDetection 模型导出教程中的说明。

  1. import ai.djl.modality.cv.Image;
  2. import ai.djl.modality.cv.output.BoundingBox;
  3. import ai.djl.modality.cv.output.DetectedObjects;
  4. import ai.djl.modality.cv.output.Rectangle;
  5. import ai.djl.modality.cv.transform.Normalize;
  6. import ai.djl.modality.cv.transform.Resize;
  7. import ai.djl.modality.cv.transform.ToTensor;
  8. import ai.djl.ndarray.NDArray;
  9. import ai.djl.ndarray.NDList;
  10. import ai.djl.ndarray.NDManager;
  11. import ai.djl.ndarray.types.DataType;
  12. import ai.djl.translate.NoBatchifyTranslator;
  13. import ai.djl.translate.Pipeline;
  14. import ai.djl.translate.TranslatorContext;
  15. import java.util.ArrayList;
  16. import java.util.Collections;
  17. import java.util.List;
  18. // 非批量输入输出应实现 NoBatchifyTranslator 接口,而不是 Translator
  19. public class PedestrianTranslator implements NoBatchifyTranslator<Image, DetectedObjects> {
  20. private final Pipeline pipeline;
  21. private final float threshold;
  22. private final List<String> classes;
  23. private final float imageWidth = 608f;
  24. private final float imageHeight = 608f;
  25. public PedestrianTranslator(float threshold) {
  26. // 定义图片预处理过程
  27. pipeline = new Pipeline();
  28. pipeline.add(new Resize((int) imageWidth, (int) imageHeight)) // resize 为模型图片输入格式,变成 608 * 608 * 3,HWC
  29. .add(new ToTensor()) // HWC -> CHW
  30. .add(new Normalize(new float[]{0.485f, 0.456f, 0.406f}, new float[]{0.229f, 0.224f, 0.225f})) // 归一化
  31. .add(array -> array.expandDims(0)); // CHW -> NCHW
  32. // 预测阈值
  33. this.threshold = threshold;
  34. // 类别
  35. classes = Collections.singletonList("pedestrian");
  36. }
  37. @Override
  38. public NDList processInput(TranslatorContext ctx, Image input) {
  39. // 内存管理器,负责 NDArray 的内存回收
  40. NDManager manager = ctx.getNDManager();
  41. // 通过构造函数定义好的管道把图片转换到模型需要的图片格式。NDList 是一个集合,与 List<NDArray> 类似
  42. NDList ndList = pipeline.transform(new NDList(input.toNDArray(manager, Image.Flag.COLOR)));
  43. // 添加原图尺寸参数
  44. ndList.add(0, manager.create(new float[]{input.getHeight(), input.getWidth()}).expandDims(0));
  45. // 添加原图片尺寸与输入图片尺寸的比值
  46. ndList.add(manager.create(new float[]{input.getHeight() / 608f, input.getWidth() / 608f}).expandDims(0));
  47. return ndList;
  48. }
  49. @Override
  50. public DetectedObjects processOutput(TranslatorContext ctx, NDList list) {
  51. // 获取第一个参数预测结果,第二个预测数量没什么用
  52. NDArray result = list.get(0);
  53. /*
  54. result demo:
  55. ND: (3, 6) cpu() float32
  56. [[ 0. , 0.9759, 10.0805, 276.1631, 298.1623, 586.246 ],
  57. [ 0. , 0.955 , 486.306 , 221.0572, 585.966 , 480.4897],
  58. [ 0. , 0.8031, 295.0543, 206.104 , 395.3066, 485.3789],
  59. ]
  60. */
  61. // 获取类别
  62. int[] classIndices = result.get(":, 0").toType(DataType.INT32, true).flatten().toIntArray();
  63. // 获取置信度
  64. double[] probs = result.get(":, 1").toType(DataType.FLOAT64, true).toDoubleArray();
  65. // 获取预测的目标数量
  66. int detected = Math.toIntExact(probs.length);
  67. // 获取矩形框左上角 x 坐标比例(第 2 列)
  68. NDArray xMin = result.get(":, 2:3").clip(0, imageWidth).div(imageWidth);
  69. // 获取矩形框左上角 y 坐标比例(第 3 列)
  70. NDArray yMin = result.get(":, 3:4").clip(0, imageHeight).div(imageHeight);
  71. // 获取矩形框右上角 x 坐标比例(第 4 列)
  72. NDArray xMax = result.get(":, 4:5").clip(0, imageWidth).div(imageWidth);
  73. // 获取矩形框右上角 y 坐标比例(第 5 列)
  74. NDArray yMax = result.get(":, 5:6").clip(0, imageHeight).div(imageHeight);
  75. // 转为可以直接绘制的数据,分别是矩形框左上角的 x 和 y 坐标、矩形框的宽和高,均为比例
  76. float[] boxX = xMin.toFloatArray();
  77. float[] boxY = yMin.toFloatArray();
  78. float[] boxWidth = xMax.sub(xMin).toFloatArray();
  79. float[] boxHeight = yMax.sub(yMin).toFloatArray();
  80. // 封装成 DetectedObjects 对象输出
  81. List<String> retClasses = new ArrayList<>(detected);
  82. List<Double> retProbs = new ArrayList<>(detected);
  83. List<BoundingBox> retBB = new ArrayList<>(detected);
  84. for (int i = 0; i < detected; i++) {
  85. // 类别不存在或者置信度低于预测阈值则跳过
  86. if (classIndices[i] < 0 || probs[i] < threshold) {
  87. continue;
  88. }
  89. retClasses.add(classes.get(0));
  90. retProbs.add(probs[i]);
  91. retBB.add(new Rectangle(boxX[i], boxY[i], boxWidth[i], boxHeight[i]));
  92. }
  93. return new DetectedObjects(retClasses, retProbs, retBB);
  94. }
  95. }

这里涉及的 NDArray 操作比较多,使用官方实现的 Transform 和 Pipeline 可以简化代码,不过手动调 NDImageUtils 更清晰。简单说几个 API:

  1. expandDims:增加维度,比如 Pipeline 的一个 Transform Lambda 将 CHW 前面加一个维度变成 NCHW
  2. get:查看 NDIndex API(方法注释上均有代码样例说明)、百度 numpy 索引切片或 NDArray 教程,搞懂 :,
  3. clip:限制数值,数值越界就取该方法传入的值

加载模型

  1. import ai.djl.MalformedModelException;
  2. import ai.djl.modality.cv.Image;
  3. import ai.djl.modality.cv.output.DetectedObjects;
  4. import ai.djl.repository.zoo.Criteria;
  5. import ai.djl.repository.zoo.ModelNotFoundException;
  6. import ai.djl.repository.zoo.ZooModel;
  7. import ai.djl.training.util.ProgressBar;
  8. import java.io.IOException;
  9. import java.nio.file.Paths;
  10. public class Models {
  11. public static ZooModel<Image, DetectedObjects> getModel() throws ModelNotFoundException, MalformedModelException, IOException {
  12. return Criteria.builder()
  13. .optEngine("OnnxRuntime") // 选择引擎
  14. .setTypes(Image.class, DetectedObjects.class) // 设置输入输出
  15. .optModelPath(Paths.get("D:\\Repository\\Github\\PaddleDetection\\pedestrian_yolov3_darknet.onnx")) // 设置模型地址。Jar 包、Zip 包根据 API 自行配置
  16. .optProgress(new ProgressBar()) // 进度条
  17. .optTranslator(new PedestrianTranslator(.5f)) // 默认的转换器,不是线程安全的
  18. .build().loadModel();
  19. }
  20. }

推理

  1. import ai.djl.Device;
  2. import ai.djl.MalformedModelException;
  3. import ai.djl.inference.Predictor;
  4. import ai.djl.modality.cv.Image;
  5. import ai.djl.modality.cv.ImageFactory;
  6. import ai.djl.modality.cv.output.DetectedObjects;
  7. import ai.djl.repository.zoo.ModelNotFoundException;
  8. import ai.djl.repository.zoo.ZooModel;
  9. import ai.djl.translate.TranslateException;
  10. import java.io.IOException;
  11. import java.nio.file.Files;
  12. import java.nio.file.Paths;
  13. public class Inference {
  14. public static void main(String[] args) throws IOException, MalformedModelException, TranslateException, ModelNotFoundException {
  15. String imageFilePath = "C:\\Users\\DELL\\Desktop\\2.png";
  16. // 加载模型
  17. try (ZooModel<Image, DetectedObjects> model = Models.getModel()) {
  18. // 新建一个推理,使用 GPU
  19. try (Predictor<Image, DetectedObjects> predictor = model.newPredictor(Device.gpu())) {
  20. Image image = ImageFactory.getInstance().fromFile(Paths.get(imageFilePath));
  21. // 推理
  22. DetectedObjects result = predictor.predict(image);
  23. // 绘制矩形框
  24. image.drawBoundingBoxes(result);
  25. image.save(Files.newOutputStream(Paths.get("output.png")), "png");
  26. }
  27. }
  28. }
  29. }

CPU GPU 配置

没有配置 cuda 的话自动下载 CPU 所需的文件,有 cuda 的话会自动寻找匹配 cuda 版本的文件,目前官网上的 cuda 版本是 10.2 和 11.2。

也可以通过配置 jar 来指定 CPU 还是 GPU,以 ONNX 为例(详见DJL Hybrid engines ONNX):

  1. <dependency>
  2. <groupId>ai.djl.onnxruntime</groupId>
  3. <artifactId>onnxruntime-engine</artifactId>
  4. <version>0.20.0</version>
  5. <scope>runtime</scope>
  6. <exclusions>
  7. <exclusion>
  8. <groupId>com.microsoft.onnxruntime</groupId>
  9. <artifactId>onnxruntime</artifactId>
  10. </exclusion>
  11. </exclusions>
  12. </dependency>
  13. <dependency>
  14. <groupId>com.microsoft.onnxruntime</groupId>
  15. <artifactId>onnxruntime_gpu</artifactId>
  16. <version>1.13.1</version>
  17. <scope>runtime</scope>
  18. </dependency>

注意

  1. 最需要知道的是导出的模型的输入和输出,否则不知道怎么写 Translator
  2. DJL 运行所需的文件挺大的,它会在第一次运行时下载,网卡流量在动就等会吧(在 /${HOME}/.djl.ai/ 下)
  3. 通常第一次推理比较慢,建议预热一次
  4. 多线程建议每个线程一个 Predictor

Jupyter Notebook

附上可以直接运行的 notebook:d2l/paddledetection.ipynb at master · hligaty/d2l (github.com)。Maven 下载依赖比较慢,建议手动下载依赖放到 /${HOME}/.ivy2/cache/ 下。

参考与推荐

PaddleDetection 安装

PaddleDetection 模型导出教程

PaddleDetection 模型导出为 ONNX 格式教程

DJL 引擎

AIAS_人工智能加速器|Java SDK|中台|套件

PaddleOCR 的 Java 高性能部署

frankfliu/IJava:用于执行Java代码的Jupyter内核

原文链接:https://www.cnblogs.com/hligy/p/17137087.html

 友情链接:直通硅谷  点职佳  北美留学生论坛

本站QQ群:前端 618073944 | Java 606181507 | Python 626812652 | C/C++ 612253063 | 微信 634508462 | 苹果 692586424 | C#/.net 182808419 | PHP 305140648 | 运维 608723728

W3xue 的所有内容仅供测试,对任何法律问题及风险不承担任何责任。通过使用本站内容随之而来的风险与本站无关。
关于我们  |  意见建议  |  捐助我们  |  报错有奖  |  广告合作、友情链接(目前9元/月)请联系QQ:27243702 沸活量
皖ICP备17017327号-2 皖公网安备34020702000426号