Str1k3rch0

Digit Detector App 0.5

Jul 21st, 2025 (edited)
154
0
Never
Not a member of Pastebin yet? Sign up — it unlocks many cool features!
Python 11.39 KB | Software | 0 0
  1. import cv2
  2. import pytesseract
  3. import os
  4. import re
  5. import time
  6. import joblib
  7. import numpy as np
  8. from datetime import datetime
  9. import tkinter as tk
  10. from tkinter import ttk
  11. from PIL import Image, ImageTk
  12. from sklearn.ensemble import RandomForestClassifier
  13.  
  14. # Constants
  15. IMG_SIZE = (32, 32)
  16. COOLDOWN = 5  # seconds
  17. MIN_OCR_CONF = 60
  18. ACTIVE_LEARNING_BATCH = 10
  19. SCAN_INTERVAL = 2.0  # seconds between rescans of frozen frame
  20. DATASET_DIR = "dataset"
  21. MODEL_PATH = "model.pkl"
  22.  
  23. class DigitDetectorApp:
  24.     def __init__(self, model_path=MODEL_PATH):
  25.         # Load or initialize model
  26.         if os.path.exists(model_path):
  27.             self.model = joblib.load(model_path)
  28.         else:
  29.             self.model = RandomForestClassifier(n_estimators=100)
  30.         self.new_data = []
  31.  
  32.         # Tesseract config
  33.         pytesseract.pytesseract.tesseract_cmd = r"C:\\Program Files\\Tesseract-OCR\\tesseract.exe"
  34.         self.ocr_config = r'--oem 3 --psm 6 -c tessedit_char_whitelist=0123456789'
  35.  
  36.         # Output dirs
  37.         os.makedirs("detections", exist_ok=True)
  38.         os.makedirs(DATASET_DIR, exist_ok=True)
  39.         self.feedback_file = "feedback_log.csv"
  40.         if not os.path.exists(self.feedback_file):
  41.             with open(self.feedback_file, "w") as f:
  42.                 f.write("number,confidence,label,timestamp\n")
  43.  
  44.         # State
  45.         self.pending_feedback = False
  46.         self.freeze_mode = False
  47.         self.manual_correction_mode = False
  48.         self.last_detected = None
  49.         self.last_time = 0
  50.         self.last_scan_time = 0
  51.         self.history = []
  52.         self.current = None
  53.  
  54.         # GUI setup
  55.         self.root = tk.Tk()
  56.         self.root.title("Digit Detector App")
  57.         self.root.geometry("1280x650")
  58.         self._build_gui()
  59.  
  60.         # Video capture
  61.         self.cap = cv2.VideoCapture(0)
  62.         self.update_frames()
  63.         self.root.mainloop()
  64.  
  65.     def _build_gui(self):
  66.         frame = self.root
  67.         frame.grid_columnconfigure(0, weight=1)
  68.         frame.grid_columnconfigure(1, weight=1)
  69.         frame.grid_rowconfigure(0, weight=1)
  70.  
  71.         self.label_cam = ttk.Label(frame)
  72.         self.label_cam.grid(row=0, column=0, padx=5, pady=5, sticky="nsew")
  73.         self.label_proc = ttk.Label(frame)
  74.         self.label_proc.grid(row=0, column=1, padx=5, pady=5, sticky="nsew")
  75.  
  76.         self.label_info = ttk.Label(frame, text="🧠 Waiting...", font=('Arial', 14))
  77.         self.label_info.grid(row=1, column=0, columnspan=2)
  78.  
  79.         btn_frame = ttk.Frame(frame)
  80.         btn_frame.grid(row=2, column=0, columnspan=2)
  81.         ttk.Button(btn_frame, text="✅ Correct", command=lambda: self.save_feedback(True)).pack(side="left", padx=5)
  82.         ttk.Button(btn_frame, text="❌ Wrong", command=lambda: self.save_feedback(False)).pack(side="left", padx=5)
  83.         ttk.Button(btn_frame, text="⟲ Undo", command=self.undo).pack(side="left", padx=5)
  84.         ttk.Button(btn_frame, text="🛑 Cancel Freeze", command=self.cancel_freeze).pack(side="left", padx=5)
  85.         ttk.Button(btn_frame, text="🔁 Retry", command=self.force_retry).pack(side="left", padx=5)
  86.  
  87.         # Manual correction entry
  88.         self.entry = ttk.Entry(frame)
  89.         self.btn_correct = ttk.Button(frame, text="Submit", command=lambda: self.manual_correction(self.entry.get()))
  90.         self.entry.grid(row=3, column=0, padx=5, pady=5)
  91.         self.btn_correct.grid(row=3, column=1, padx=5, pady=5)
  92.         self.entry.grid_remove()
  93.         self.btn_correct.grid_remove()
  94.  
  95.     def cancel_freeze(self):
  96.         # Exit freeze mode without logging
  97.         self.freeze_mode = False
  98.         self.pending_feedback = False
  99.         self.manual_correction_mode = False
  100.         self.entry.grid_remove()
  101.         self.btn_correct.grid_remove()
  102.         self.label_info.config(text="⏹️ Freeze canceled. Resuming live detection...")
  103.  
  104.     def force_retry(self):
  105.         # Force next retry immediately
  106.         self.last_scan_time = 0
  107.         self.label_info.config(text="🔁 Retry triggered")
  108.  
  109.     def preprocess(self, frame):
  110.         gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
  111.         blur = cv2.GaussianBlur(gray, (3, 3), 0)
  112.         _, thresh = cv2.threshold(blur, 170, 255, cv2.THRESH_BINARY)
  113.         return thresh
  114.  
  115.     def predict_ocr(self, img):
  116.         data = pytesseract.image_to_data(img, config=self.ocr_config, output_type=pytesseract.Output.DICT)
  117.         for i, txt in enumerate(data['text']):
  118.             txt = txt.strip()
  119.             try:
  120.                 conf = int(data['conf'][i])
  121.             except:
  122.                 conf = 0
  123.             if re.fullmatch(r'\d{1,3}', txt) and conf >= MIN_OCR_CONF:
  124.                 return txt, conf
  125.         return None, 0
  126.  
  127.     def predict_ml(self, img):
  128.         resized = cv2.resize(img, IMG_SIZE).flatten().reshape(1, -1)
  129.         pred = self.model.predict(resized)[0]
  130.         conf = int(max(self.model.predict_proba(resized)[0]) * 100) if hasattr(self.model, 'predict_proba') else 0
  131.         return str(pred), conf
  132.  
  133.     def ensemble_predict(self, img):
  134.         # Try OCR first, then ML
  135.         num_o, conf_o = self.predict_ocr(img)
  136.         num_m, conf_m = self.predict_ml(img)
  137.         if num_o and conf_o >= conf_m:
  138.             return num_o, conf_o
  139.         return num_m, conf_m
  140.  
  141.     def save_feedback(self, correct):
  142.         if not self.current:
  143.             return
  144.         num, conf, img = self.current
  145.         ts = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
  146.         if correct:
  147.             # Log correct feedback
  148.             with open(self.feedback_file, 'a') as f:
  149.                 f.write(f"{num},{conf},correct,{ts}\n")
  150.             folder = num
  151.             os.makedirs(os.path.join(DATASET_DIR, folder), exist_ok=True)
  152.             cv2.imwrite(os.path.join(DATASET_DIR, folder, f"{ts}.png"), img)
  153.             self.new_data.append((img, int(num)))
  154.             if len(self.new_data) >= ACTIVE_LEARNING_BATCH:
  155.                 self.retrain_model()
  156.             # Unfreeze
  157.             self.freeze_mode = False
  158.             self.pending_feedback = False
  159.             self.manual_correction_mode = False
  160.             self.entry.grid_remove()
  161.             self.btn_correct.grid_remove()
  162.             self.label_info.config(text=f"✅ Saved: {num} as correct")
  163.         else:
  164.             # On wrong, freeze and request correction
  165.             self.freeze_mode = True
  166.             self.pending_feedback = True
  167.             self.manual_correction_mode = True
  168.             self.entry.grid()
  169.             self.btn_correct.grid()
  170.             self.label_info.config(text=f"⚠️ Wrong detection. Please correct.")
  171.  
  172.     def manual_correction(self, val):
  173.         if not self.current:
  174.             return
  175.         val = val.strip()
  176.         if not (val.isdigit() or val.lower() == 'nothing'):
  177.             self.label_info.config(text="⚠️ Enter a digit or 'nothing'")
  178.             return
  179.         num_old, conf_old, img = self.current
  180.         ts = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
  181.         folder = val if val.isdigit() else 'unlabeled'
  182.         os.makedirs(os.path.join(DATASET_DIR, folder), exist_ok=True)
  183.         cv2.imwrite(os.path.join(DATASET_DIR, folder, f"{ts}.png"), img)
  184.         with open(self.feedback_file, 'a') as f:
  185.             f.write(f"{val},{conf_old},manual,{ts}\n")
  186.         # Update current detection
  187.         self.current = (val, conf_old, img)
  188.         # Unfreeze after manual
  189.         self.freeze_mode = False
  190.         self.pending_feedback = False
  191.         self.manual_correction_mode = False
  192.         self.entry.grid_remove()
  193.         self.btn_correct.grid_remove()
  194.         self.label_info.config(text=f"📤 Correction saved: {val}")
  195.  
  196.     def undo(self):
  197.         if not self.history:
  198.             self.label_info.config(text="⟲ Nothing to undo")
  199.             return
  200.         num, ts, img = self.history.pop()
  201.         self.label_info.config(text=f"⟲ Undone: {num} at {ts}")
  202.         self.display_images(img)
  203.  
  204.     def retrain_model(self):
  205.         X, y = [], []
  206.         for label in os.listdir(DATASET_DIR):
  207.             if label == 'unlabeled': continue
  208.             for fname in os.listdir(os.path.join(DATASET_DIR, label)):
  209.                 img = cv2.imread(os.path.join(DATASET_DIR, label, fname), cv2.IMREAD_GRAYSCALE)
  210.                 if img is None: continue
  211.                 X.append(cv2.resize(img, IMG_SIZE).flatten())
  212.                 y.append(int(label))
  213.         for img_val, lbl in self.new_data:
  214.             X.append(cv2.resize(img_val, IMG_SIZE).flatten())
  215.             y.append(lbl)
  216.         if X:
  217.             self.model = RandomForestClassifier(n_estimators=100)
  218.             self.model.fit(np.array(X), np.array(y))
  219.             joblib.dump(self.model, MODEL_PATH)
  220.             self.new_data.clear()
  221.             self.label_info.config(text="🤖 Model retrained")
  222.  
  223.     def display_images(self, proc):
  224.         frame = cv2.resize(self.last_frame, (600,450))
  225.         cam_p = ImageTk.PhotoImage(Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)))
  226.         pr_p = ImageTk.PhotoImage(Image.fromarray(proc))
  227.         self.label_cam.config(image=cam_p)
  228.         self.label_cam.imgtk = cam_p
  229.         self.label_proc.config(image=pr_p)
  230.         self.label_proc.imgtk = pr_p
  231.  
  232.     def update_frames(self):
  233.         now = time.time()
  234.         # Freeze mode: retry on last frame
  235.         if self.freeze_mode and self.current:
  236.             if now - self.last_scan_time > SCAN_INTERVAL:
  237.                 num_old, conf_old, img = self.current
  238.                 num_new, conf_new = self.ensemble_predict(img)
  239.                 self.current = (num_new, conf_new, img)
  240.                 self.label_info.config(text=f"🔁 Retry: Was {num_new} correct?")
  241.                 self.last_scan_time = now
  242.                 if self.manual_correction_mode:
  243.                     self.entry.grid()
  244.                     self.btn_correct.grid()
  245.             self.display_images(self.current[2])
  246.             self.root.after(30, self.update_frames)
  247.             return
  248.  
  249.         # Live capture mode
  250.         ret, frame = self.cap.read()
  251.         if not ret:
  252.             self.root.after(30, self.update_frames)
  253.             return
  254.  
  255.         self.last_frame = frame.copy()
  256.         proc = self.preprocess(frame)
  257.         now = time.time()
  258.  
  259.         # Skip low-content frames
  260.         if cv2.countNonZero(proc) < 500 and not self.pending_feedback:
  261.             self.label_info.config(text="📷 Weak image, no digit detected.")
  262.             self.display_images(proc)
  263.             self.root.after(30, self.update_frames)
  264.             return
  265.  
  266.         # Start new detection
  267.         if not self.pending_feedback:
  268.             num, conf = self.ensemble_predict(proc)
  269.             if num and (num != self.last_detected or now - self.last_time > COOLDOWN):
  270.                 ts = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
  271.                 cv2.imwrite(f"detections/detected_{ts}.png", frame)
  272.                 self.history.append((num, ts, proc))
  273.                 self.current = (num, conf, proc)
  274.                 self.last_detected, self.last_time = num, now
  275.                 self.pending_feedback = True
  276.                 self.freeze_mode = True
  277.                 self.manual_correction_mode = False
  278.                 self.label_info.config(text=f"❓ Was {num} correct?")
  279.                 self.entry.delete(0, tk.END)
  280.                 self.entry.grid()
  281.                 self.btn_correct.grid()
  282.  
  283.         self.display_images(proc)
  284.         self.root.after(30, self.update_frames)
  285.  
  286. if __name__ == "__main__":
  287.     DigitDetectorApp()
  288.  
Advertisement
Add Comment
Please, Sign In to add comment