Str1k3rch0

Digit Detector App

Jul 21st, 2025
213
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 10.68 KB | Software | 0 0
  1. import cv2
  2. import pytesseract
  3. import os
  4. import re
  5. import time
  6. import joblib
  7. import numpy as np
  8. from datetime import datetime
  9. import tkinter as tk
  10. from tkinter import ttk
  11. from PIL import Image, ImageTk
  12. from sklearn.ensemble import RandomForestClassifier
  13.  
  14. # Constants
  15. IMG_SIZE = (32, 32)
  16. COOLDOWN = 5  # seconds
  17. MIN_OCR_CONF = 60
  18. ACTIVE_LEARNING_BATCH = 10
  19. SCAN_INTERVAL = 1.0  # seconds between rescans of frozen frame
  20. MIN_WHITE_PIXELS = 500  # minimum pixels to consider frame valid
  21. DATASET_DIR = "dataset"
  22. MODEL_PATH = "model.pkl"
  23.  
  24. class DigitDetectorApp:
  25.     def __init__(self, model_path=MODEL_PATH):
  26.         if os.path.exists(model_path):
  27.             self.model = joblib.load(model_path)
  28.         else:
  29.             self.model = RandomForestClassifier(n_estimators=100)
  30.         self.new_data = []
  31.         pytesseract.pytesseract.tesseract_cmd = r"C:\\Program Files\\Tesseract-OCR\\tesseract.exe"
  32.         self.ocr_config = r'--oem 3 --psm 6 -c tessedit_char_whitelist=0123456789'
  33.         os.makedirs("detections", exist_ok=True)
  34.         os.makedirs(DATASET_DIR, exist_ok=True)
  35.         self.feedback_file = "feedback_log.csv"
  36.         if not os.path.exists(self.feedback_file):
  37.             with open(self.feedback_file, "w") as f:
  38.                 f.write("number,confidence,label,timestamp\n")
  39.         self.pending_feedback = False
  40.         self.freeze_mode = False
  41.         self.last_detected = None
  42.         self.last_time = 0
  43.         self.last_scan_time = 0
  44.         self.history = []
  45.         self.root = tk.Tk()
  46.         self.root.title("Digit Detector App")
  47.         self.root.geometry("1280x650")
  48.         self._build_gui()
  49.         self.cap = cv2.VideoCapture(0)
  50.         self.update_frames()
  51.         self.root.mainloop()
  52.  
  53.     def _build_gui(self):
  54.         frame = self.root
  55.         frame.grid_columnconfigure(0, weight=1)
  56.         frame.grid_columnconfigure(1, weight=1)
  57.         frame.grid_rowconfigure(0, weight=1)
  58.         self.label_cam = ttk.Label(frame)
  59.         self.label_cam.grid(row=0, column=0, padx=5, pady=5, sticky="nsew")
  60.         self.label_proc = ttk.Label(frame)
  61.         self.label_proc.grid(row=0, column=1, padx=5, pady=5, sticky="nsew")
  62.         self.label_info = ttk.Label(frame, text="🧠 Waiting...", font=('Arial', 14))
  63.         self.label_info.grid(row=1, column=0, columnspan=2)
  64.         btn_frame = ttk.Frame(frame)
  65.         btn_frame.grid(row=2, column=0, columnspan=2)
  66.         ttk.Button(btn_frame, text="✅ Correct", command=lambda: self.save_feedback(True)).pack(side="left", padx=5)
  67.         ttk.Button(btn_frame, text="❌ Wrong", command=lambda: self.save_feedback(False)).pack(side="left", padx=5)
  68.         ttk.Button(btn_frame, text="⟲ Undo", command=self.undo).pack(side="left", padx=5)
  69.         ttk.Button(btn_frame, text="🛑 Cancel Freeze", command=self.cancel_freeze).pack(side="left", padx=5)
  70.         ttk.Button(btn_frame, text="đŸŒĄī¸ Heatmap", command=self.toggle_heatmap).pack(side="left", padx=5)
  71.         self.entry = ttk.Entry(frame)
  72.         self.btn_correct = ttk.Button(frame, text="Submit", command=lambda: self.manual_correction(self.entry.get()))
  73.         self.entry.grid(row=3, column=0, padx=5, pady=5)
  74.         self.btn_correct.grid(row=3, column=1, padx=5, pady=5)
  75.         self.entry.grid_remove(); self.btn_correct.grid_remove()
  76.         self.show_heat = False
  77.  
  78.     def cancel_freeze(self):
  79.         self.freeze_mode = False
  80.         self.pending_feedback = False
  81.         self.entry.grid_remove(); self.btn_correct.grid_remove()
  82.         self.label_info.config(text="âšī¸ Scan canceled. Resuming live detection...")
  83.  
  84.     def preprocess(self, frame):
  85.         gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
  86.         blur = cv2.GaussianBlur(gray, (3, 3), 0)
  87.         _, thresh = cv2.threshold(blur, 170, 255, cv2.THRESH_BINARY)
  88.         return thresh
  89.  
  90.     def predict_ocr(self, img):
  91.         data = pytesseract.image_to_data(img, config=self.ocr_config, output_type=pytesseract.Output.DICT)
  92.         for i, txt in enumerate(data['text']):
  93.             txt = txt.strip()
  94.             try: conf = int(data['conf'][i])
  95.             except: conf = 0
  96.             if re.fullmatch(r'\d{1,3}', txt) and conf >= MIN_OCR_CONF:
  97.                 return txt, conf
  98.         return None, 0
  99.  
  100.     def predict_ml(self, img):
  101.         resized = cv2.resize(img, IMG_SIZE).flatten().reshape(1, -1)
  102.         pred = self.model.predict(resized)[0]
  103.         conf = int(max(self.model.predict_proba(resized)[0]) * 100) if hasattr(self.model, 'predict_proba') else 0
  104.         return str(pred), conf
  105.  
  106.     def ensemble_predict(self, img):
  107.         if cv2.countNonZero(img) < MIN_WHITE_PIXELS:
  108.             return None, 0
  109.         num_o, conf_o = self.predict_ocr(img)
  110.         num_m, conf_m = self.predict_ml(img)
  111.         if num_o and conf_o >= conf_m:
  112.             return num_o, conf_o
  113.         return num_m, conf_m
  114.  
  115.     def toggle_heatmap(self):
  116.         self.show_heat = not self.show_heat
  117.  
  118.     def save_feedback(self, correct):
  119.         if not hasattr(self, 'current') or self.current is None:
  120.             return
  121.         num, conf, img = self.current
  122.         ts = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
  123.         if correct:
  124.             with open(self.feedback_file, 'a') as f:
  125.                 f.write(f"{num},{conf},correct,{ts}\n")
  126.             folder = num
  127.             path = os.path.join(DATASET_DIR, folder)
  128.             os.makedirs(path, exist_ok=True)
  129.             cv2.imwrite(os.path.join(path, f"{ts}.png"), img)
  130.             self.new_data.append((img, int(num)))
  131.             if len(self.new_data) >= ACTIVE_LEARNING_BATCH:
  132.                 self.retrain_model()
  133.             self.pending_feedback = False
  134.             self.freeze_mode = False
  135.             self.entry.grid_remove(); self.btn_correct.grid_remove()
  136.             self.label_info.config(text=f"✅ Saved: {num} as correct")
  137.         else:
  138.             with open(self.feedback_file, 'a') as f:
  139.                 f.write(f"{num},{conf},wrong,{ts}\n")
  140.             self.pending_feedback = True
  141.             self.freeze_mode = True
  142.             self.entry.grid(); self.btn_correct.grid()
  143.             self.label_info.config(text="âš ī¸ Marked wrong. Please correct the number.")
  144.  
  145.     def manual_correction(self, val):
  146.         val = val.strip()
  147.         if not (val.isdigit() or val.lower()=='nothing'):
  148.             self.label_info.config(text="âš ī¸ Enter valid number or 'nothing'")
  149.             return
  150.         num_old, conf_old, img = self.current
  151.         ts = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
  152.         folder = val if val.isdigit() else 'unlabeled'
  153.         path = os.path.join(DATASET_DIR, folder)
  154.         os.makedirs(path, exist_ok=True)
  155.         cv2.imwrite(os.path.join(path, f"{ts}.png"), img)
  156.         with open(self.feedback_file, 'a') as f:
  157.             f.write(f"{val},{conf_old},manual,{ts}\n")
  158.         self.pending_feedback = False
  159.         self.freeze_mode = False
  160.         self.entry.grid_remove(); self.btn_correct.grid_remove()
  161.         self.label_info.config(text=f"📤 Correction saved: {val}")
  162.  
  163.     def undo(self):
  164.         if not self.history:
  165.             self.label_info.config(text="⟲ Nothing to undo")
  166.             return
  167.         num, ts, img = self.history.pop()
  168.         self.label_info.config(text=f"⟲ Undone: {num} at {ts}")
  169.         self.display_images(img)
  170.  
  171.     def retrain_model(self):
  172.         X, y = [], []
  173.         for label in os.listdir(DATASET_DIR):
  174.             if label == 'unlabeled': continue
  175.             for fname in os.listdir(os.path.join(DATASET_DIR, label)):
  176.                 img = cv2.imread(os.path.join(DATASET_DIR, label, fname), cv2.IMREAD_GRAYSCALE)
  177.                 resized = cv2.resize(img, IMG_SIZE).flatten()
  178.                 X.append(resized); y.append(int(label))
  179.         for img_val, lbl in self.new_data:
  180.             resized = cv2.resize(img_val, IMG_SIZE).flatten()
  181.             X.append(resized); y.append(lbl)
  182.         if X:
  183.             self.model = RandomForestClassifier(n_estimators=100)
  184.             self.model.fit(np.array(X), np.array(y))
  185.             joblib.dump(self.model, MODEL_PATH)
  186.             self.new_data.clear()
  187.             self.label_info.config(text="🤖 Model retrained with feedback")
  188.  
  189.     def display_images(self, proc):
  190.         frame = cv2.resize(self.last_frame, (600,450))
  191.         img_proc = proc
  192.         if self.show_heat:
  193.             img_proc = cv2.applyColorMap(proc, cv2.COLORMAP_JET)
  194.         cam_p = ImageTk.PhotoImage(Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)))
  195.         pr_p = ImageTk.PhotoImage(Image.fromarray(img_proc))
  196.         self.label_cam.config(image=cam_p); self.label_cam.imgtk = cam_p
  197.         self.label_proc.config(image=pr_p); self.label_proc.imgtk = pr_p
  198.  
  199.     def update_frames(self):
  200.         if self.freeze_mode and hasattr(self, 'current') and self.current:
  201.             now = time.time()
  202.             if now - self.last_scan_time > SCAN_INTERVAL:
  203.                 num_old, conf_old, img = self.current
  204.                 num_new, conf_new = self.ensemble_predict(img)
  205.                 if num_new and num_new != num_old:
  206.                     self.current = (num_new, conf_new, img)
  207.                     self.label_info.config(text=f"🔁 Retry: Was {num_new} correct?")
  208.                 elif num_new is None:
  209.                     self.label_info.config(text="📷 Weak image, no digit detected.")
  210.                 self.last_scan_time = now
  211.             self.display_images(self.current[2])
  212.             self.root.after(30, self.update_frames)
  213.             return
  214.  
  215.         ret, frame = self.cap.read()
  216.         if not ret:
  217.             self.root.after(30, self.update_frames); return
  218.         self.last_frame = frame.copy(); proc = self.preprocess(frame); now = time.time()
  219.         if cv2.countNonZero(proc) < MIN_WHITE_PIXELS:
  220.             self.label_info.config(text="📷 Weak image, no digit detected.")
  221.             self.display_images(proc); self.root.after(30, self.update_frames); return
  222.         if not self.pending_feedback and not self.freeze_mode:
  223.             num, conf = self.ensemble_predict(proc)
  224.             if num and (num != self.last_detected or now - self.last_time > COOLDOWN):
  225.                 ts = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
  226.                 cv2.imwrite(f"detections/detected_{ts}.png", frame)
  227.                 self.history.append((num, ts, proc))
  228.                 self.current = (num, conf, proc)
  229.                 self.last_detected, self.last_time = num, now
  230.                 self.pending_feedback = True; self.freeze_mode = True
  231.                 self.label_info.config(text=f"❓ Was {num} correct?")
  232.                 self.entry.delete(0, tk.END); self.entry.grid(); self.btn_correct.grid()
  233.         self.display_images(proc); self.root.after(30, self.update_frames)
  234.  
  235. if __name__ == "__main__":
  236.     DigitDetectorApp()
  237.  
Advertisement
Add Comment
Please, Sign In to add comment