Advertisement
Guest User

Untitled

a guest
Oct 24th, 2024
48
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 3.25 KB | Software | 0 0
  1. def run_inference(image_path, onnx_model_path, conf_threshold=0.5):
  2.     """
  3.    Run inference using PyTorch-exported ONNX model
  4.    
  5.    Args:
  6.        image_path: Path to input image
  7.        onnx_model_path: Path to ONNX model
  8.        conf_threshold: Confidence threshold for detections
  9.    """
  10.     session = ort.InferenceSession(onnx_model_path)
  11.    
  12.     # Load and preprocess image
  13.     image = cv2.imread(image_path)
  14.     image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
  15.    
  16.     # Keep original dimensions for later scaling
  17.     original_height, original_width = image_rgb.shape[:2]
  18.    
  19.     # Convert to float32 and normalize to [0-1] range
  20.     input_tensor = image_rgb.astype(np.float32) / 255.0
  21.    
  22.     # Convert to NCHW format (PyTorch default)
  23.     input_tensor = np.transpose(input_tensor, (2, 0, 1))
  24.     input_tensor = np.expand_dims(input_tensor, 0)
  25.    
  26.     # Get input name
  27.     input_name = session.get_inputs()[0].name
  28.    
  29.     # Run inference
  30.     outputs = session.run(None, {input_name: input_tensor})
  31.    
  32.     # Parse outputs - adjust indices based on your model's output order
  33.     boxes = outputs[0][0]  # Assuming first batch
  34.     scores = outputs[2][0]  # Scores are usually the third output
  35.     labels = outputs[1][0]  # Labels/classes usually second output
  36.     masks = outputs[3][0] if len(outputs) > 3 else None  # Masks if present
  37.    
  38.     # Filter by confidence
  39.     valid_detections = scores >= conf_threshold
  40.     filtered_boxes = boxes[valid_detections]
  41.     filtered_scores = scores[valid_detections]
  42.     filtered_labels = labels[valid_detections]
  43.     filtered_masks = masks[valid_detections] if masks is not None else None
  44.    
  45.     return {
  46.         'image': image_rgb,
  47.         'boxes': filtered_boxes,
  48.         'scores': filtered_scores,
  49.         'classes': filtered_labels,
  50.         'masks': filtered_masks,
  51.         'image_dims': (original_height, original_width)
  52.     }
  53.  
  54. def visualize_detections(results, class_names=None):
  55.     """
  56.    Visualize detection results
  57.    """
  58.     image = results['image'].copy()
  59.     boxes = results['boxes']
  60.     scores = results['scores']
  61.     labels = results['classes']
  62.     orig_height, orig_width = results['image_dims']
  63.    
  64.     # Draw each detection
  65.     for box, score, class_id in zip(boxes, scores, labels):
  66.         # Get coordinates - handle both absolute and normalized coordinates
  67.         y1, x1, y2, x2 = box
  68.        
  69.         # If coordinates are normalized (0-1), scale to image size
  70.         if max(box) <= 1.0:
  71.             x1 = int(x1 * orig_width)
  72.             y1 = int(y1 * orig_height)
  73.             x2 = int(x2 * orig_width)
  74.             y2 = int(y2 * orig_height)
  75.         else:
  76.             # If coordinates are absolute, ensure they're integers
  77.             x1, y1, x2, y2 = map(int, [x1, y1, x2, y2])
  78.        
  79.         # Draw rectangle
  80.         cv2.rectangle(image, (x1, y1), (x2, y2), (0, 255, 0), 2)
  81.        
  82.         # Add label
  83.         label = f"Class {int(class_id)}: {score:.2f}" if class_names is None else f"{class_names[int(class_id)]}: {score:.2f}"
  84.         cv2.putText(image, label, (x1, y1-10), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 2)
  85.    
  86.     # Display the result
  87.     plt.figure(figsize=(12, 8))
  88.     plt.imshow(image)
  89.     plt.axis('off')
  90.     plt.show()
  91.    
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement