# det_test.py 4.8 KB
  1. import numpy as np
  2. from hbm_runtime import HB_HBMRuntime
  3. import os
  4. import cv2
  5. hb_dtype_map = {
  6. "U8": np.uint8,
  7. "S8": np.int8,
  8. "F32": np.float32,
  9. "F16": np.float16,
  10. "U16": np.uint16,
  11. "S16": np.int16,
  12. "S32": np.int32,
  13. "U32": np.uint32,
  14. "BOOL8": np.bool_,
  15. }
  16. MODEL_PATH = "../models/cowCheck/det/det_640x640.hbm"
  17. # 1️⃣ 加载模型
  18. model = HB_HBMRuntime(MODEL_PATH)
  19. model_name = model.model_names[0]
  20. print("Loaded model:", model_name)
  21. # 2️⃣ 查看输入信息
  22. input_names = model.input_names[model_name]
  23. input_shapes = model.input_shapes[model_name]
  24. input_dtypes = model.input_dtypes[model_name]
  25. print("Input info:")
  26. for name in input_names:
  27. print(" Name:", name)
  28. print(" Shape:", input_shapes[name])
  29. print(" Dtype:", input_dtypes[name])
  30. # 3️⃣ 使用真实图片测试
  31. IMAGE_DIR = "/home/sunrise/opt/dev/project/tools/images_20260316_142840"
  32. image_files = [f for f in os.listdir(IMAGE_DIR) if f.lower().endswith(".jpg")]
  33. image_files.sort()
  34. print(f"\nFound {len(image_files)} images.")
  35. for img_name in image_files:
  36. img_path = os.path.join(IMAGE_DIR, img_name)
  37. img = cv2.imread(img_path)
  38. if img is None:
  39. print(f"Skip invalid image: {img_name}")
  40. continue
  41. # ===== YOLOv8 Letterbox 预处理(保持比例)=====
  42. h0, w0 = img.shape[:2]
  43. target_size = 640
  44. r = min(target_size / h0, target_size / w0)
  45. new_w, new_h = int(w0 * r), int(h0 * r)
  46. resized = cv2.resize(img, (new_w, new_h))
  47. pad_w = target_size - new_w
  48. pad_h = target_size - new_h
  49. top = pad_h // 2
  50. bottom = pad_h - top
  51. left = pad_w // 2
  52. right = pad_w - left
  53. img_letterbox = cv2.copyMakeBorder(
  54. resized, top, bottom, left, right,
  55. cv2.BORDER_CONSTANT, value=(114, 114, 114)
  56. )
  57. img_display = img_letterbox.copy()
  58. img_input_float = img_letterbox.astype(np.float32) / 255.0
  59. input_tensors = {}
  60. for name in input_names:
  61. shape = input_shapes[name]
  62. dtype = hb_dtype_map.get(input_dtypes[name].name, np.float32)
  63. # 默认假设输入是 NCHW F32
  64. if len(shape) == 4:
  65. # HWC -> CHW
  66. img_input = np.transpose(img_input_float, (2, 0, 1))
  67. img_input = np.expand_dims(img_input, axis=0)
  68. img_input = img_input.astype(dtype)
  69. input_tensors[name] = img_input
  70. else:
  71. # 其他情况 fallback
  72. input_tensors[name] = np.zeros(shape, dtype=dtype)
  73. results = model.run(input_tensors)
  74. print(f"\nImage: {img_name}")
  75. for output_name, output_data in results[model_name].items():
  76. print(" Output:", output_name, "shape:", output_data.shape)
  77. # ===== YOLOv8 解码 =====
  78. # (1, 7, 8400) -> (8400, 7)
  79. preds = output_data[0].transpose(1, 0)
  80. boxes = preds[:, 0:4] # xywh (already decoded in export)
  81. cls_scores = preds[:, 4:] # class logits
  82. # YOLOv8 uses sigmoid (NOT softmax)
  83. cls_probs = 1.0 / (1.0 + np.exp(-cls_scores))
  84. cls_ids = np.argmax(cls_probs, axis=1)
  85. scores = np.max(cls_probs, axis=1)
  86. # 提高置信度阈值(避免全屏框)
  87. CONF_THRESH = 0.6
  88. # 过滤低置信度
  89. keep_mask = scores > CONF_THRESH
  90. boxes = boxes[keep_mask]
  91. scores = scores[keep_mask]
  92. cls_ids = cls_ids[keep_mask]
  93. if len(boxes) == 0:
  94. cv2.imshow("det_result", img_display)
  95. if cv2.waitKey(1) & 0xFF == ord('q'):
  96. break
  97. continue
  98. # 转换为 xyxy
  99. xyxy_boxes = []
  100. for b in boxes:
  101. cx, cy, w, h = b
  102. x1 = cx - w / 2
  103. y1 = cy - h / 2
  104. x2 = cx + w / 2
  105. y2 = cy + h / 2
  106. xyxy_boxes.append([x1, y1, x2 - x1, y2 - y1]) # cv2 NMSBoxes uses x,y,w,h
  107. # OpenCV NMS
  108. indices = cv2.dnn.NMSBoxes(
  109. xyxy_boxes,
  110. scores.tolist(),
  111. CONF_THRESH,
  112. 0.45
  113. )
  114. if len(indices) > 0:
  115. indices = indices.flatten()
  116. else:
  117. indices = []
  118. for i in indices:
  119. cx, cy, w, h = boxes[i]
  120. x1 = int(cx - w / 2)
  121. y1 = int(cy - h / 2)
  122. x2 = int(cx + w / 2)
  123. y2 = int(cy + h / 2)
  124. cv2.rectangle(img_display, (x1, y1), (x2, y2), (0, 255, 0), 2)
  125. cv2.putText(
  126. img_display,
  127. f"{cls_ids[i]}:{scores[i]:.2f}",
  128. (x1, y1 - 5),
  129. cv2.FONT_HERSHEY_SIMPLEX,
  130. 0.5,
  131. (0, 255, 0),
  132. 1
  133. )
  134. # ===== 视频流方式显示 =====
  135. cv2.imshow("det_result", img_display)
  136. # 1ms 刷新,按 q 退出
  137. if cv2.waitKey(1) & 0xFF == ord('q'):
  138. break
  139. cv2.destroyAllWindows()