Str1k3rch0

Финал

Jul 21st, 2025
236
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 9.36 KB | Software | 0 0
  1. import cv2
  2. import pytesseract
  3. import os
  4. import re
  5. import time
  6. import joblib
  7. import numpy as np
  8. from datetime import datetime
  9. import tkinter as tk
  10. from tkinter import ttk
  11. from PIL import Image, ImageTk
  12.  
  13. # Load trained model
  14. from sklearn.ensemble import RandomForestClassifier
  15. model = joblib.load("model.pkl")
  16. IMG_SIZE = (32, 32)
  17.  
  18. # Tesseract config
  19. pytesseract.pytesseract.tesseract_cmd = r"C:\\Program Files\\Tesseract-OCR\\tesseract.exe"
  20. custom_config = r'--oem 3 --psm 6 -c tessedit_char_whitelist=0123456789'
  21.  
  22. # Output directories
  23. output_dir = "detections"
  24. os.makedirs(output_dir, exist_ok=True)
  25.  
  26. feedback_file = "feedback_log.csv"
  27. if not os.path.exists(feedback_file):
  28.     with open(feedback_file, "w") as f:
  29.         f.write("number,confidence,label,timestamp\n")
  30.  
  31. # Globals
  32. pending_feedback = False
  33. freeze_mode = False
  34. last_detected_number = None
  35. last_detection_time = 0
  36. cooldown_seconds = 5
  37. current_number = None
  38. current_confidence = None
  39. last_saved_image = None
  40.  
  41. # Active Learning
  42. new_data = []
  43. ACTIVE_BATCH = 10
  44.  
  45. # GUI setup
  46. root = tk.Tk()
  47. root.title("Digit Detector (Tesseract + Model)")
  48. root.geometry("1280x650")
  49. root.grid_columnconfigure(0, weight=1)
  50. root.grid_columnconfigure(1, weight=1)
  51. root.grid_rowconfigure(0, weight=1)
  52.  
  53. label_camera = ttk.Label(root)
  54. label_camera.grid(row=0, column=0, padx=5, pady=5, sticky="nsew")
  55.  
  56. label_processed = ttk.Label(root)
  57. label_processed.grid(row=0, column=1, padx=5, pady=5, sticky="nsew")
  58.  
  59. label_info = ttk.Label(root, text="🧠 Waiting for number...", font=('Arial', 14))
  60. label_info.grid(row=1, column=0, columnspan=2)
  61.  
  62. # Entry and button for manual correction
  63. entry_corrected = ttk.Entry(root)
  64. btn_submit_correction = ttk.Button(root, text="📤 Submit Correction", command=lambda: submit_manual_correction(entry_corrected.get()))
  65. entry_corrected.grid(row=2, column=0, padx=5, pady=5)
  66. btn_submit_correction.grid(row=2, column=1, padx=5, pady=5)
  67. entry_corrected.grid_remove()
  68. btn_submit_correction.grid_remove()
  69.  
  70. # Feedback buttons + Cancel & Retry
  71. btn_correct = ttk.Button(root, text="✅ Correct", command=lambda: save_feedback(True))
  72. btn_correct.grid(row=3, column=0, pady=10)
  73. btn_wrong = ttk.Button(root, text="❌ Wrong", command=lambda: save_feedback(False))
  74. btn_wrong.grid(row=3, column=1, pady=10)
  75. btn_cancel = ttk.Button(root, text="🛑 Cancel Freeze", command=lambda: cancel_freeze())
  76. btn_cancel.grid(row=4, column=0, pady=5)
  77. btn_retry = ttk.Button(root, text="🔁 Retry", command=lambda: force_retry())
  78. btn_retry.grid(row=4, column=1, pady=5)
  79.  
  80. # Capture
  81. cap = cv2.VideoCapture(0)
  82.  
  83. def predict_with_model(image):
  84.     img_resized = cv2.resize(image, IMG_SIZE)
  85.     img_flat = img_resized.flatten().reshape(1, -1)
  86.     prediction = model.predict(img_flat)[0]
  87.     return str(prediction)
  88.  
  89. def show_feedback_prompt():
  90.     label_info.config(text=f"❓ Was {current_number} correct? Use ✅ / ❌")
  91.  
  92. def cancel_freeze():
  93.     global freeze_mode, pending_feedback
  94.     freeze_mode = False
  95.     pending_feedback = False
  96.     entry_corrected.grid_remove()
  97.     btn_submit_correction.grid_remove()
  98.     label_info.config(text="⏹️ Freeze canceled. Resuming live...")
  99.  
  100. def force_retry():
  101.     global last_detection_time
  102.     # allow immediate retry on last_saved_image
  103.     last_detection_time = 0
  104.     label_info.config(text="🔁 Retry triggered.")
  105.  
  106. def save_feedback(correct):
  107.     global current_number, current_confidence, last_saved_image, pending_feedback, freeze_mode, new_data
  108.     if not current_number or last_saved_image is None:
  109.         return
  110.     ts = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
  111.     label = "correct" if correct else "wrong"
  112.     with open(feedback_file, "a") as f:
  113.         f.write(f"{current_number},{current_confidence},{label},{ts}\n")
  114.     # save color crop if correct
  115.     folder = current_number if correct else 'unlabeled'
  116.     os.makedirs(os.path.join("dataset", folder), exist_ok=True)
  117.     cv2.imwrite(os.path.join("dataset", folder, f"{ts}.png"), last_saved_image)
  118.     if correct:
  119.         # active learning
  120.         new_data.append((last_saved_image, int(current_number)))
  121.         if len(new_data) >= ACTIVE_BATCH:
  122.             retrain_model()
  123.         pending_feedback = False
  124.         freeze_mode = False
  125.         label_info.config(text=f"✅ Saved: {current_number} as {label}")
  126.     else:
  127.         # enter freeze+retry mode
  128.         freeze_mode = True
  129.         pending_feedback = True
  130.         label_info.config(text="❌ Wrong detection. Awaiting retry or correction.")
  131.         entry_corrected.delete(0, tk.END)
  132.         entry_corrected.grid()
  133.         btn_submit_correction.grid()
  134.  
  135. def submit_manual_correction(val):
  136.     global current_number, last_saved_image, pending_feedback, freeze_mode, new_data
  137.     val = val.strip()
  138.     if not (val.isdigit() or val.lower()=='nothing'):
  139.         label_info.config(text="⚠️ Enter a valid number or 'nothing'")
  140.         return
  141.     ts = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
  142.     folder = val if val.isdigit() else 'unlabeled'
  143.     os.makedirs(os.path.join("dataset", folder), exist_ok=True)
  144.     cv2.imwrite(os.path.join("dataset", folder, f"{ts}.png"), last_saved_image)
  145.     with open(feedback_file, "a") as f:
  146.         f.write(f"{val},{current_confidence},manual,{ts}\n")
  147.     # active learning
  148.     if val.isdigit():
  149.         new_data.append((last_saved_image, int(val)))
  150.         if len(new_data) >= ACTIVE_BATCH:
  151.             retrain_model()
  152.     pending_feedback = False
  153.     freeze_mode = False
  154.     entry_corrected.grid_remove()
  155.     btn_submit_correction.grid_remove()
  156.     label_info.config(text=f"📤 Manual correction saved: {val}")
  157.  
  158. def retrain_model():
  159.     global new_data, model
  160.     X, y = [], []
  161.     for lbl in os.listdir("dataset"):
  162.         if lbl == 'unlabeled': continue
  163.         p = os.path.join("dataset", lbl)
  164.         for fn in os.listdir(p):
  165.             img = cv2.imread(os.path.join(p, fn), cv2.IMREAD_GRAYSCALE)
  166.             if img is None: continue
  167.             X.append(cv2.resize(img, IMG_SIZE).flatten())
  168.             y.append(int(lbl))
  169.     for img_arr, lbl in new_data:
  170.         X.append(cv2.resize(img_arr, IMG_SIZE).flatten())
  171.         y.append(lbl)
  172.     if X:
  173.         clf = RandomForestClassifier(n_estimators=100)
  174.         clf.fit(np.array(X), np.array(y))
  175.         joblib.dump(clf, "model.pkl")
  176.         new_data = []
  177.         label_info.config(text="🤖 Model retrained.")
  178.  
  179. def update_frames():
  180.     global last_detected_number, last_detection_time
  181.     global current_number, current_confidence, last_saved_image, pending_feedback, freeze_mode
  182.  
  183.     # Freeze+Retry loop
  184.     if freeze_mode and pending_feedback and last_saved_image is not None:
  185.         # only retry when cooldown passed
  186.         if time.time() - last_detection_time > cooldown_seconds:
  187.             # re-run OCR
  188.             data = pytesseract.image_to_data(last_saved_image, config=custom_config,
  189.                                              output_type=pytesseract.Output.DICT)
  190.             num, conf = None, 0
  191.             for i, txt in enumerate(data['text']):
  192.                 txt = txt.strip()
  193.                 try: c = int(data['conf'][i])
  194.                 except: c = 0
  195.                 if re.fullmatch(r'\d{1,3}', txt) and c >= 60:
  196.                     num, conf = txt, c
  197.                     break
  198.             if num:
  199.                 current_number, current_confidence = num, conf
  200.                 label_info.config(text=f"🔁 Retry: Was {num} correct?")
  201.             last_detection_time = time.time()
  202.         root.after(100, update_frames)
  203.         return
  204.  
  205.     # Live capture
  206.     ret, frame = cap.read()
  207.     if not ret:
  208.         root.after(30, update_frames); return
  209.  
  210.     frame = cv2.resize(frame, (640, 480))
  211.     gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
  212.     blur = cv2.GaussianBlur(gray, (3,3),0)
  213.     _, thresh = cv2.threshold(blur,170,255,cv2.THRESH_BINARY)
  214.  
  215.     # skip empty
  216.     if cv2.countNonZero(thresh) < 500:
  217.         label_info.config(text="🧠 No number detected")
  218.         root.after(30, update_frames); return
  219.  
  220.     # OCR detection
  221.     data = pytesseract.image_to_data(thresh, config=custom_config,
  222.                                      output_type=pytesseract.Output.DICT)
  223.     num, conf = None, 0
  224.     for i, txt in enumerate(data['text']):
  225.         txt = txt.strip()
  226.         try: c = int(data['conf'][i])
  227.         except: c = 0
  228.         if re.fullmatch(r'\d{1,3}', txt) and c >= 60:
  229.             num, conf = txt, c
  230.             break
  231.  
  232.     if num and (num != last_detected_number or time.time() - last_detection_time > cooldown_seconds):
  233.         timestamp = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
  234.         cv2.imwrite(os.path.join(output_dir, f"detected_{timestamp}.png"), frame)
  235.         last_saved_image = thresh.copy()
  236.         last_detected_number = num
  237.         last_detection_time = time.time()
  238.         current_number, current_confidence = num, conf
  239.         pending_feedback = True
  240.         show_feedback_prompt()
  241.  
  242.     # update UI
  243.     cam_img = cv2.resize(frame, (600, 450))
  244.     proc_img = cv2.resize(thresh, (600, 450))
  245.     imgtk1 = ImageTk.PhotoImage(Image.fromarray(cv2.cvtColor(cam_img, cv2.COLOR_BGR2RGB)))
  246.     imgtk2 = ImageTk.PhotoImage(Image.fromarray(proc_img))
  247.     label_camera.config(image=imgtk1); label_camera.imgtk = imgtk1
  248.     label_processed.config(image=imgtk2); label_processed.imgtk = imgtk2
  249.  
  250.     root.after(30, update_frames)
  251.  
  252. update_frames()
  253. root.mainloop()
  254. cap.release()
  255. cv2.destroyAllWindows()
  256.  
Advertisement
Add Comment
Please, Sign In to add comment