| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179 |
import argparse
import os

import cv2
import numpy as np

from hbm_runtime import HB_HBMRuntime
# Translate the runtime's dtype names into NumPy scalar types.
# Naming scheme visible in the keys: "U" = unsigned int, "S" = signed int,
# "F" = float, followed by the bit width; "BOOL8" maps to NumPy's bool.
hb_dtype_map = dict(
    [
        # 8-bit integers
        ("U8", np.uint8),
        ("S8", np.int8),
        # floating point
        ("F32", np.float32),
        ("F16", np.float16),
        # 16-bit integers
        ("U16", np.uint16),
        ("S16", np.int16),
        # 32-bit integers
        ("S32", np.int32),
        ("U32", np.uint32),
        # boolean
        ("BOOL8", np.bool_),
    ]
)
# Default locations; both can be overridden on the command line (see parse_args).
MODEL_PATH = "../models/cowCheck/det/det_640x640.hbm"
IMAGE_DIR = "/home/sunrise/opt/dev/project/tools/images_20260316_142840"

TARGET_SIZE = 640  # edge length of the square letterbox fed to the model
PAD_VALUE = (114, 114, 114)  # grey padding colour of the YOLOv8-style letterbox
# Confidence threshold, deliberately high to suppress spurious full-frame boxes.
CONF_THRESH = 0.6
NMS_IOU_THRESH = 0.45
IMAGE_EXTS = (".jpg", ".jpeg", ".png")
WINDOW_NAME = "det_result"
BOX_COLOR = (0, 255, 0)


def parse_args(argv=None):
    """Parse command-line options; every option defaults to the constants above."""
    parser = argparse.ArgumentParser(
        description="Run the HBM YOLOv8 detector over a folder of images."
    )
    parser.add_argument("--model", default=MODEL_PATH, help="path to the .hbm model")
    parser.add_argument("--images", default=IMAGE_DIR, help="directory of test images")
    parser.add_argument("--conf", type=float, default=CONF_THRESH,
                        help="confidence threshold")
    parser.add_argument("--iou", type=float, default=NMS_IOU_THRESH,
                        help="NMS IoU threshold")
    return parser.parse_args(argv)


def letterbox(img, target_size=TARGET_SIZE, pad_value=PAD_VALUE):
    """Fit *img* into a ``target_size`` square while preserving aspect ratio.

    The image is scaled by ``min(target_size / h, target_size / w)`` and the
    remaining area is padded with ``pad_value``, split as evenly as possible
    between opposite sides.

    Returns:
        The padded ``target_size x target_size`` image, same channel order
        and dtype as *img*.
    """
    h0, w0 = img.shape[:2]
    r = min(target_size / h0, target_size / w0)
    # max(1, ...) keeps cv2.resize from receiving a zero-sized target for
    # extremely thin images.
    new_w, new_h = max(1, int(w0 * r)), max(1, int(h0 * r))
    resized = cv2.resize(img, (new_w, new_h))
    pad_w, pad_h = target_size - new_w, target_size - new_h
    top, left = pad_h // 2, pad_w // 2
    return cv2.copyMakeBorder(
        resized, top, pad_h - top, left, pad_w - left,
        cv2.BORDER_CONSTANT, value=pad_value,
    )


def build_input_tensors(img_letterbox, input_names, input_shapes, input_dtypes):
    """Build the ``{input_name: ndarray}`` dict passed to ``model.run``.

    4-D inputs receive the letterboxed image. The layout is NCHW unless the
    declared shape has 3 channels in the last axis only, in which case NHWC
    is used. Floating-point inputs get pixels scaled to [0, 1]; integer
    inputs get the raw 0-255 pixel values. Inputs of any other rank are
    filled with zeros so the run can still proceed.

    NOTE(review): cv2.imread yields BGR and no channel swap is done here (same
    as before). If the model expects RGB (typical for Ultralytics exports),
    feed ``img_letterbox[..., ::-1]`` instead — confirm against the model's
    compile configuration.
    NOTE(review): signed-integer inputs (e.g. S8) presumably need model-specific
    quantisation rather than raw pixels — verify if such a model is used.
    """
    tensors = {}
    for name in input_names:
        shape = tuple(input_shapes[name])
        dtype = np.dtype(hb_dtype_map.get(input_dtypes[name].name, np.float32))
        if len(shape) != 4:
            # Unknown input kind: best-effort zeros, as the original script did.
            tensors[name] = np.zeros(shape, dtype=dtype)
            continue
        if np.issubdtype(dtype, np.floating):
            pixels = img_letterbox.astype(np.float32) / 255.0
        else:
            # Casting [0, 1] floats to an integer type would zero the image.
            pixels = img_letterbox
        if shape[1] != 3 and shape[-1] == 3:
            data = pixels[np.newaxis]  # NHWC
        else:
            data = np.transpose(pixels, (2, 0, 1))[np.newaxis]  # HWC -> NCHW
        # A transposed view is not C-contiguous; hand the runtime a dense buffer.
        tensors[name] = np.ascontiguousarray(data, dtype=dtype)
    return tensors


def decode_predictions(output, conf_thresh=CONF_THRESH):
    """Decode one YOLOv8-style head output of shape (1, 4 + num_classes, N).

    Only batch element 0 is used. Rows 0-3 are taken as centre-x, centre-y,
    width, height (already decoded by the export); the remaining rows are
    per-class scores.

    Ultralytics exports usually apply sigmoid to the class scores already;
    applying it a second time squeezes every score into [0.5, 0.73], so a 0.6
    threshold would pass anything originally above ~0.41. Sigmoid is therefore
    applied only when the values are clearly logits (some fall outside [0, 1]).

    Returns:
        ``(boxes, scores, cls_ids)`` for anchors whose best class score is
        strictly greater than *conf_thresh*; boxes are (K, 4) centre-xywh.

    Raises:
        ValueError: if *output* is not 3-D with more than 4 channels.
    """
    arr = np.asarray(output)
    if arr.ndim != 3 or arr.shape[1] <= 4:
        raise ValueError(
            f"expected (batch, 4 + num_classes, num_anchors) output, got {arr.shape}"
        )
    preds = arr[0].T  # (num_anchors, 4 + num_classes)
    boxes = preds[:, :4].astype(np.float32)
    cls_scores = preds[:, 4:].astype(np.float32)
    if cls_scores.size and (cls_scores.min() < 0.0 or cls_scores.max() > 1.0):
        cls_scores = 1.0 / (1.0 + np.exp(-cls_scores))
    cls_ids = np.argmax(cls_scores, axis=1)
    scores = np.max(cls_scores, axis=1)
    keep = scores > conf_thresh
    return boxes[keep], scores[keep], cls_ids[keep]


def xywh_to_tlwh(boxes):
    """Convert centre-xywh boxes to top-left-xywh (the form cv2 NMSBoxes takes)."""
    tlwh = np.array(boxes, dtype=np.float32).reshape(-1, 4)  # copy; input untouched
    tlwh[:, :2] -= tlwh[:, 2:4] / 2.0
    return tlwh


def nms_indices(boxes, scores, conf_thresh, iou_thresh):
    """Run class-agnostic OpenCV NMS and return kept indices as a 1-D int array."""
    if len(boxes) == 0:
        return np.empty(0, dtype=int)
    kept = cv2.dnn.NMSBoxes(
        xywh_to_tlwh(boxes).tolist(),
        np.asarray(scores, dtype=float).tolist(),
        conf_thresh,
        iou_thresh,
    )
    # Depending on the OpenCV version this is a (K,), (K, 1) array or an
    # empty tuple; normalise all of them.
    return np.asarray(kept, dtype=int).reshape(-1)


def draw_detections(canvas, boxes, scores, cls_ids, keep):
    """Draw the kept centre-xywh boxes and ``cls:score`` labels onto *canvas* in place."""
    for i in keep:
        cx, cy, w, h = boxes[i]
        x1, y1 = int(cx - w / 2), int(cy - h / 2)
        x2, y2 = int(cx + w / 2), int(cy + h / 2)
        cv2.rectangle(canvas, (x1, y1), (x2, y2), BOX_COLOR, 2)
        cv2.putText(
            canvas,
            f"{cls_ids[i]}:{scores[i]:.2f}",
            (x1, max(y1 - 5, 12)),  # keep labels of top-edge boxes on screen
            cv2.FONT_HERSHEY_SIMPLEX,
            0.5,
            BOX_COLOR,
            1,
        )


def main(argv=None):
    """Load the model, run it over every image in the folder and show detections.

    Press ``q`` in the display window to stop the whole run.
    """
    args = parse_args(argv)

    # 1. Load the model.
    model = HB_HBMRuntime(args.model)
    model_name = model.model_names[0]
    print("Loaded model:", model_name)

    # 2. Report input metadata.
    input_names = model.input_names[model_name]
    input_shapes = model.input_shapes[model_name]
    input_dtypes = model.input_dtypes[model_name]
    print("Input info:")
    for name in input_names:
        print(" Name:", name)
        print(" Shape:", input_shapes[name])
        print(" Dtype:", input_dtypes[name])

    # 3. Run on real images.
    image_files = sorted(
        f for f in os.listdir(args.images) if f.lower().endswith(IMAGE_EXTS)
    )
    print(f"\nFound {len(image_files)} images.")

    try:
        for img_name in image_files:
            img = cv2.imread(os.path.join(args.images, img_name))
            if img is None:
                print(f"Skip invalid image: {img_name}")
                continue

            img_letterbox = letterbox(img)
            img_display = img_letterbox.copy()
            input_tensors = build_input_tensors(
                img_letterbox, input_names, input_shapes, input_dtypes
            )
            results = model.run(input_tensors)
            print(f"\nImage: {img_name}")

            for output_name, output_data in results[model_name].items():
                print(" Output:", output_name, "shape:", output_data.shape)
                try:
                    boxes, scores, cls_ids = decode_predictions(output_data, args.conf)
                except ValueError as err:
                    print(f"  skip output {output_name}: {err}")
                    continue
                keep = nms_indices(boxes, scores, args.conf, args.iou)
                # Assumes box coordinates are in letterbox pixel space, so they
                # are drawn on the letterboxed image rather than the original.
                draw_detections(img_display, boxes, scores, cls_ids, keep)

            # Video-style display: 1 ms refresh; 'q' ends the whole run (the
            # previous break only left the per-output loop).
            cv2.imshow(WINDOW_NAME, img_display)
            if cv2.waitKey(1) & 0xFF == ord("q"):
                break
    finally:
        cv2.destroyAllWindows()


if __name__ == "__main__":
    main()
|