Guest User

projection mapping

a guest
Oct 3rd, 2025
61
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 10.84 KB | Photo | 0 0
  1. #!/usr/bin/env python3
  2. """
  3. projection_prewarp.py
  4.  
  5. Interactive tool to compute a homography between a source image and a photo of a projected result,
  6. then create a pre-warped image so that after projector-perspective the result appears correct.
  7.  
  8. Features:
  9. - Manual or automatic calibration (point correspondences)
  10. - Save and load homography
  11. - Batch mode: apply existing homography to new images
  12. - Interactive viewer with zoom and pan for precise point selection
  13. - Undo function to remove last point
  14.  
  15. Usage examples:
  16. Manual calibration:
  17.   python projection_prewarp.py --source source.png --photo projected.jpg --output prewarped.png
  18.  
  19. Automatic feature matching:
  20.   python projection_prewarp.py --source source.png --photo projected.jpg --output prewarped.png --mode auto
  21.  
  22. Apply precomputed homography:
  23.   python projection_prewarp.py --apply result_H.npy --input new.png --output new_prewarped.png --outsize 1920x1080
  24.  
  25. """
  26.  
  27. import argparse
  28. import json
  29. import os
  30. import sys
  31. from datetime import datetime
  32.  
  33. import cv2
  34. import numpy as np
  35.  
  36. # --- Globals used by mouse callbacks ---
  37. src_points = []
  38. dst_points = []
  39.  
  40. class ImageViewer:
  41.     def __init__(self, name, img, points_list):
  42.         self.name = name
  43.         self.img = img
  44.         self.points_list = points_list
  45.         self.zoom = 1.0
  46.         self.offset = np.array([0, 0], dtype=np.float32)
  47.         self.dragging = False
  48.         self.drag_start = None
  49.         cv2.namedWindow(self.name, cv2.WINDOW_NORMAL)
  50.         cv2.setMouseCallback(self.name, self.on_mouse)
  51.  
  52.     def on_mouse(self, event, x, y, flags, param=None):
  53.         if event == cv2.EVENT_LBUTTONDOWN:
  54.             orig_x = int((x + self.offset[0]) / self.zoom)
  55.             orig_y = int((y + self.offset[1]) / self.zoom)
  56.             self.points_list.append((orig_x, orig_y))
  57.         elif event == cv2.EVENT_RBUTTONDOWN:
  58.             self.dragging = True
  59.             self.drag_start = np.array([x, y], dtype=np.float32)
  60.         elif event == cv2.EVENT_MOUSEMOVE and self.dragging:
  61.             delta = np.array([x, y], dtype=np.float32) - self.drag_start
  62.             self.offset -= delta
  63.             self.offset = np.maximum(self.offset, 0)
  64.             self.drag_start = np.array([x, y], dtype=np.float32)
  65.         elif event == cv2.EVENT_RBUTTONUP:
  66.             self.dragging = False
  67.         elif event == cv2.EVENT_MOUSEWHEEL:
  68.             if flags > 0:
  69.                 self.zoom *= 1.2
  70.             else:
  71.                 self.zoom /= 1.2
  72.             self.zoom = max(0.2, min(self.zoom, 10.0))
  73.  
  74.     def show(self):
  75.         h, w = self.img.shape[:2]
  76.         view = cv2.resize(self.img, (int(w * self.zoom), int(h * self.zoom)))
  77.         # Crop window to max 1200x800
  78.         x0 = int(self.offset[0])
  79.         y0 = int(self.offset[1])
  80.         x1 = x0 + min(view.shape[1], 1200)
  81.         y1 = y0 + min(view.shape[0], 800)
  82.         crop = view[y0:y1, x0:x1].copy()
  83.         for i, (px, py) in enumerate(self.points_list):
  84.             sx = int(px * self.zoom - self.offset[0])
  85.             sy = int(py * self.zoom - self.offset[1])
  86.             if 0 <= sx < crop.shape[1] and 0 <= sy < crop.shape[0]:
  87.                 cv2.circle(crop, (sx, sy), 5, (0, 255, 0), -1)
  88.                 cv2.putText(crop, str(i + 1), (sx + 8, sy - 8),
  89.                             cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 255, 0), 1)
  90.         cv2.imshow(self.name, crop)
  91.  
  92. def save_points(outfile, src_pts, dst_pts):
  93.     data = {
  94.         'timestamp': datetime.utcnow().isoformat() + 'Z',
  95.         'src_points': [[float(x), float(y)] for (x, y) in src_pts],
  96.         'dst_points': [[float(x), float(y)] for (x, y) in dst_pts],
  97.     }
  98.     with open(outfile, 'w', encoding='utf-8') as f:
  99.         json.dump(data, f, indent=2)
  100.     print(f"Saved points to {outfile}")
  101.  
  102. def try_auto_match(src_img_gray, dst_img_gray, max_features=2000, ratio_test=True):
  103.     orb = cv2.ORB_create(nfeatures=max_features)
  104.     kp1, des1 = orb.detectAndCompute(src_img_gray, None)
  105.     kp2, des2 = orb.detectAndCompute(dst_img_gray, None)
  106.     print(f"ORB: found {len(kp1)} keypoints in source, {len(kp2)} in photo")
  107.     if des1 is None or des2 is None or len(kp1) < 4 or len(kp2) < 4:
  108.         print("Not enough keypoints/descriptors for automatic matching")
  109.         return None
  110.     bf = cv2.BFMatcher(cv2.NORM_HAMMING, crossCheck=False)
  111.     try:
  112.         matches = bf.knnMatch(des1, des2, k=2)
  113.     except Exception:
  114.         matches = [[m] for m in bf.match(des1, des2)]
  115.     good = []
  116.     if ratio_test:
  117.         for m_n in matches:
  118.             if len(m_n) >= 2:
  119.                 m, n = m_n[0], m_n[1]
  120.                 if m.distance < 0.75 * n.distance:
  121.                     good.append(m)
  122.     else:
  123.         for m_n in matches:
  124.             good.append(m_n[0])
  125.     print(f"Auto-matcher found {len(good)} good matches")
  126.     if len(good) < 8:
  127.         return None
  128.     pts_src = np.float32([kp1[m.queryIdx].pt for m in good])
  129.     pts_dst = np.float32([kp2[m.trainIdx].pt for m in good])
  130.     return pts_src, pts_dst, kp1, kp2, good
  131.  
  132. def compute_and_warp(src, photo, pts_src, pts_dst, out_size=None, save_prefix='result'):
  133.     pts_src = np.array(pts_src, dtype=np.float32)
  134.     pts_dst = np.array(pts_dst, dtype=np.float32)
  135.     if pts_src.shape[0] < 4 or pts_dst.shape[0] < 4:
  136.         raise ValueError('Need at least 4 point correspondences')
  137.     H, mask = cv2.findHomography(pts_src, pts_dst, cv2.RANSAC, 5.0)
  138.     if H is None:
  139.         raise RuntimeError('findHomography failed')
  140.     print('Homography (src -> photo):')
  141.     print(H)
  142.     np.save(save_prefix + '_H.npy', H)
  143.     with open(save_prefix + '_H.txt', 'w') as f:
  144.         f.write('\n'.join([' '.join(map(str, row)) for row in H]))
  145.     try:
  146.         H_inv = np.linalg.inv(H)
  147.     except np.linalg.LinAlgError:
  148.         ok, H_inv = cv2.invert(H)
  149.         if not ok:
  150.             raise RuntimeError('Homography is singular')
  151.     if out_size is None:
  152.         h_out, w_out = photo.shape[:2]
  153.     else:
  154.         w_out, h_out = out_size
  155.     prewarped = cv2.warpPerspective(src, H_inv, (w_out, h_out), flags=cv2.INTER_CUBIC, borderMode=cv2.BORDER_CONSTANT)
  156.     cv2.imwrite(save_prefix + '_prewarped.png', prewarped)
  157.     return H, H_inv, prewarped
  158.  
  159. def apply_existing_homography(h_file, input_file, output_file, out_size=None):
  160.     H = np.load(h_file)
  161.     try:
  162.         H_inv = np.linalg.inv(H)
  163.     except np.linalg.LinAlgError:
  164.         ok, H_inv = cv2.invert(H)
  165.         if not ok:
  166.             raise RuntimeError('Homography is singular')
  167.     img = cv2.imread(input_file)
  168.     if img is None:
  169.         raise RuntimeError(f'Failed to load {input_file}')
  170.     if out_size is None:
  171.         out_w, out_h = img.shape[1], img.shape[0]
  172.     else:
  173.         out_w, out_h = out_size
  174.     prewarped = cv2.warpPerspective(img, H_inv, (out_w, out_h), flags=cv2.INTER_LINEAR)
  175.     cv2.imwrite(output_file, prewarped)
  176.     print(f"Applied homography from {h_file} to {input_file}, saved {output_file}")
  177.  
  178. def main():
  179.     parser = argparse.ArgumentParser(description='Compute or apply homography for projector prewarp')
  180.     parser.add_argument('--source', '-s', help='Source image (for calibration)')
  181.     parser.add_argument('--photo', '-p', help='Photo of projected image (for calibration)')
  182.     parser.add_argument('--output', '-o', default='prewarped.png', help='Output filename for prewarped image')
  183.     parser.add_argument('--mode', choices=['manual', 'auto'], default='manual', help='Calibration mode')
  184.     parser.add_argument('--outsize', help='Output size WIDTHxHEIGHT for prewarped image')
  185.     parser.add_argument('--save-prefix', default='result', help='Prefix for saved calibration files')
  186.     parser.add_argument('--apply', help='Use existing homography file (.npy)')
  187.     parser.add_argument('--input', help='Input image to prewarp using existing homography')
  188.     args = parser.parse_args()
  189.  
  190.     out_size = None
  191.     if args.outsize:
  192.         try:
  193.             w, h = args.outsize.split('x')
  194.             out_size = (int(w), int(h))
  195.         except Exception:
  196.             print('Invalid --outsize, expected WIDTHxHEIGHT')
  197.             sys.exit(1)
  198.  
  199.     if args.apply and args.input:
  200.         apply_existing_homography(args.apply, args.input, args.output, out_size)
  201.         return
  202.  
  203.     if not args.source or not args.photo:
  204.         print('For calibration you must provide --source and --photo')
  205.         sys.exit(1)
  206.  
  207.     src_img = cv2.imread(args.source, cv2.IMREAD_COLOR)
  208.     dst_img = cv2.imread(args.photo, cv2.IMREAD_COLOR)
  209.     if src_img is None or dst_img is None:
  210.         print('Failed to load images')
  211.         sys.exit(1)
  212.  
  213.     if args.mode == 'auto':
  214.         gray1 = cv2.cvtColor(src_img, cv2.COLOR_BGR2GRAY)
  215.         gray2 = cv2.cvtColor(dst_img, cv2.COLOR_BGR2GRAY)
  216.         auto = try_auto_match(gray1, gray2)
  217.         if auto:
  218.             pts_src, pts_dst, kp1, kp2, good_matches = auto
  219.             matches_img = cv2.drawMatches(src_img, kp1, dst_img, kp2, good_matches[:60], None,
  220.                                           flags=cv2.DrawMatchesFlags_NOT_DRAW_SINGLE_POINTS)
  221.             cv2.imshow('Matches (auto)', matches_img)
  222.             key = cv2.waitKey(0)
  223.             cv2.destroyWindow('Matches (auto)')
  224.             if key != 27:
  225.                 H, H_inv, prewarped = compute_and_warp(src_img, dst_img, pts_src, pts_dst, out_size, args.save_prefix)
  226.                 cv2.imwrite(args.output, prewarped)
  227.                 save_points(args.save_prefix + '_points.json', pts_src.tolist(), pts_dst.tolist())
  228.                 return
  229.         print('Auto failed or cancelled, switching to manual')
  230.  
  231.     src_view = ImageViewer('Source', src_img, src_points)
  232.     dst_view = ImageViewer('Photo', dst_img, dst_points)
  233.  
  234.     while True:
  235.         src_view.show()
  236.         dst_view.show()
  237.         key = cv2.waitKey(20) & 0xFF
  238.         if key == ord('r'):
  239.             src_points.clear()
  240.             dst_points.clear()
  241.         elif key == ord('s'):
  242.             save_points(args.save_prefix + '_points.json', src_points, dst_points)
  243.         elif key == ord('c'):
  244.             if len(src_points) >= 4 and len(src_points) == len(dst_points):
  245.                 H, H_inv, prewarped = compute_and_warp(src_img, dst_img, src_points, dst_points, out_size, args.save_prefix)
  246.                 cv2.imwrite(args.output, prewarped)
  247.                 save_points(args.save_prefix + '_points.json', src_points, dst_points)
  248.             else:
  249.                 print('Need >=4 pairs and equal count')
  250.             break
  251.         elif key == ord('u') or key == 8:  # Undo last point (u or Backspace)
  252.             if src_points and (len(src_points) > len(dst_points)):
  253.                 removed = src_points.pop()
  254.                 print(f"Cofnięto punkt źródłowy {removed}")
  255.             elif dst_points:
  256.                 removed = dst_points.pop()
  257.                 print(f"Cofnięto punkt docelowy {removed}")
  258.         elif key == ord('q') or key == 27:
  259.             break
  260.     cv2.destroyAllWindows()
  261.  
  262. if __name__ == '__main__':
  263.     main()
  264.  
Advertisement
Add Comment
Please, Sign In to add comment