Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
def run_inference(image_path, onnx_model_path, conf_threshold=0.5):
    """
    Run inference using a PyTorch-exported ONNX model.

    Args:
        image_path: Path to the input image file.
        onnx_model_path: Path to the ONNX model file.
        conf_threshold: Minimum confidence score a detection must reach
            to be kept (default 0.5).

    Returns:
        dict with keys:
            'image': the loaded image converted to RGB.
            'boxes', 'scores', 'classes': detections filtered by
                conf_threshold.
            'masks': filtered masks, or None if the model has no mask output.
            'image_dims': (height, width) of the original image.

    Raises:
        FileNotFoundError: If the image cannot be read from image_path.
    """
    session = ort.InferenceSession(onnx_model_path)

    # Load and preprocess image. cv2.imread returns None (no exception) on a
    # missing or unreadable file, so fail loudly here instead of with a
    # cryptic cvtColor error on the next line.
    image = cv2.imread(image_path)
    if image is None:
        raise FileNotFoundError(f"Could not read image: {image_path}")
    image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    # Keep original dimensions for later scaling
    original_height, original_width = image_rgb.shape[:2]

    # Convert to float32, normalize to [0, 1], then reorder to NCHW
    # (PyTorch's default layout) and add a leading batch dimension.
    input_tensor = image_rgb.astype(np.float32) / 255.0
    input_tensor = np.transpose(input_tensor, (2, 0, 1))
    input_tensor = np.expand_dims(input_tensor, 0)

    # Run inference
    input_name = session.get_inputs()[0].name
    outputs = session.run(None, {input_name: input_tensor})

    # Parse outputs - adjust indices based on your model's output order.
    # NOTE(review): the boxes/labels/scores ordering below is an assumption
    # about the exported model — confirm against your model's output spec.
    boxes = outputs[0][0]   # assuming first batch
    scores = outputs[2][0]  # scores are usually the third output
    labels = outputs[1][0]  # labels/classes usually second output
    masks = outputs[3][0] if len(outputs) > 3 else None  # masks if present

    # Keep only detections at or above the confidence threshold
    valid_detections = scores >= conf_threshold
    filtered_boxes = boxes[valid_detections]
    filtered_scores = scores[valid_detections]
    filtered_labels = labels[valid_detections]
    filtered_masks = masks[valid_detections] if masks is not None else None

    return {
        'image': image_rgb,
        'boxes': filtered_boxes,
        'scores': filtered_scores,
        'classes': filtered_labels,
        'masks': filtered_masks,
        'image_dims': (original_height, original_width)
    }
def visualize_detections(results, class_names=None):
    """
    Draw the detection results on the image and display it.

    Args:
        results: Dict as produced by run_inference(), with keys 'image',
            'boxes', 'scores', 'classes', and 'image_dims'.
        class_names: Optional sequence mapping class ids to display names;
            when None, the numeric class id is shown instead.
    """
    canvas = results['image'].copy()
    orig_height, orig_width = results['image_dims']

    detections = zip(results['boxes'], results['scores'], results['classes'])
    for box, score, class_id in detections:
        # Coordinates may be normalized (0-1) or absolute pixel values.
        y1, x1, y2, x2 = box
        if max(box) <= 1.0:
            # Normalized: scale up to pixel positions.
            x1, x2 = int(x1 * orig_width), int(x2 * orig_width)
            y1, y2 = int(y1 * orig_height), int(y2 * orig_height)
        else:
            # Absolute: just coerce to integers for drawing.
            x1, y1, x2, y2 = (int(v) for v in (x1, y1, x2, y2))

        # Bounding box
        cv2.rectangle(canvas, (x1, y1), (x2, y2), (0, 255, 0), 2)

        # Label text above the box
        if class_names is None:
            label = f"Class {int(class_id)}: {score:.2f}"
        else:
            label = f"{class_names[int(class_id)]}: {score:.2f}"
        cv2.putText(canvas, label, (x1, y1-10), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 2)

    # Show the annotated image
    plt.figure(figsize=(12, 8))
    plt.imshow(canvas)
    plt.axis('off')
    plt.show()
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement