Str1k3rch0

Digit Detector App 1.1

Jul 21st, 2025
223
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 11.09 KB | Software | 0 0
  1. import cv2
  2. import pytesseract
  3. import os
  4. import re
  5. import time
  6. import joblib
  7. import numpy as np
  8. from datetime import datetime
  9. import tkinter as tk
  10. from tkinter import ttk
  11. from PIL import Image, ImageTk
  12. from sklearn.ensemble import RandomForestClassifier
  13.  
  14. # Constants
  15. IMG_SIZE = (32, 32)
  16. COOLDOWN = 5  # seconds
  17. MIN_OCR_CONF = 60
  18. MIN_ML_CONF = 50
  19. ACTIVE_LEARNING_BATCH = 10
  20. SCAN_INTERVAL = 2.0  # seconds between retries
  21. MIN_WHITE_PIXELS = 500  # skip low-content frames
  22. DATASET_DIR = "dataset"
  23. MODEL_PATH = "model.pkl"
  24.  
  25. class DigitDetectorApp:
  26.     def __init__(self, model_path=MODEL_PATH):
  27.         # Load or initialize model
  28.         if os.path.exists(model_path):
  29.             self.model = joblib.load(model_path)
  30.         else:
  31.             self.model = RandomForestClassifier(n_estimators=100)
  32.         self.new_data = []
  33.  
  34.         # Tesseract config
  35.         pytesseract.pytesseract.tesseract_cmd = r"C:\\Program Files\\Tesseract-OCR\\tesseract.exe"
  36.         self.ocr_config = r'--oem 3 --psm 6 -c tessedit_char_whitelist=0123456789'
  37.  
  38.         # Output dirs
  39.         os.makedirs("detections", exist_ok=True)
  40.         os.makedirs(DATASET_DIR, exist_ok=True)
  41.         self.feedback_file = "feedback_log.csv"
  42.         if not os.path.exists(self.feedback_file):
  43.             with open(self.feedback_file, "w") as f:
  44.                 f.write("number,confidence,label,timestamp\n")
  45.  
  46.         # State
  47.         self.pending_feedback = False
  48.         self.freeze_mode = False
  49.         self.manual_mode = False
  50.         self.last_detected = None
  51.         self.last_time = 0
  52.         self.last_scan_time = 0
  53.         self.history = []
  54.         self.current = None
  55.         self.current_frame_raw = None
  56.         self.current_thresh = None
  57.  
  58.         # Setup GUI
  59.         self.root = tk.Tk()
  60.         self.root.title("Digit Detector App")
  61.         self.root.geometry("1280x650")
  62.         self._build_gui()
  63.  
  64.         # Video capture
  65.         self.cap = cv2.VideoCapture(0)
  66.         self.update_frames()
  67.         self.root.mainloop()
  68.  
  69.     def _build_gui(self):
  70.         frame = self.root
  71.         frame.grid_columnconfigure(0, weight=1)
  72.         frame.grid_columnconfigure(1, weight=1)
  73.         frame.grid_rowconfigure(0, weight=1)
  74.  
  75.         self.label_cam = ttk.Label(frame)
  76.         self.label_cam.grid(row=0, column=0, padx=5, pady=5, sticky="nsew")
  77.         self.label_proc = ttk.Label(frame)
  78.         self.label_proc.grid(row=0, column=1, padx=5, pady=5, sticky="nsew")
  79.  
  80.         self.label_info = ttk.Label(frame, text="🧠 Waiting for valid frame...", font=('Arial', 14))
  81.         self.label_info.grid(row=1, column=0, columnspan=2)
  82.  
  83.         btn_frame = ttk.Frame(frame)
  84.         btn_frame.grid(row=2, column=0, columnspan=2)
  85.         ttk.Button(btn_frame, text="✅ Correct", command=lambda: self.save_feedback(True)).pack(side="left", padx=5)
  86.         ttk.Button(btn_frame, text="❌ Wrong", command=lambda: self.save_feedback(False)).pack(side="left", padx=5)
  87.         ttk.Button(btn_frame, text="🛑 Cancel Freeze", command=self.cancel_freeze).pack(side="left", padx=5)
  88.         ttk.Button(btn_frame, text="🔁 Retry", command=self.force_retry).pack(side="left", padx=5)
  89.         ttk.Button(btn_frame, text="⟲ Undo", command=self.undo).pack(side="left", padx=5)
  90.  
  91.         self.entry = ttk.Entry(frame)
  92.         self.btn_submit = ttk.Button(frame, text="Submit", command=lambda: self.manual_correction(self.entry.get()))
  93.         self.entry.grid(row=3, column=0, padx=5, pady=5)
  94.         self.btn_submit.grid(row=3, column=1, padx=5, pady=5)
  95.         self.entry.grid_remove(); self.btn_submit.grid_remove()
  96.  
  97.     def cancel_freeze(self):
  98.         self.freeze_mode = False
  99.         self.pending_feedback = False
  100.         self.manual_mode = False
  101.         self.entry.grid_remove(); self.btn_submit.grid_remove()
  102.         self.label_info.config(text="âšī¸ Freeze canceled. Resuming live detection.")
  103.  
  104.     def force_retry(self):
  105.         self.last_scan_time = 0
  106.         self.label_info.config(text="🔁 Retry triggered.")
  107.  
  108.     def preprocess(self, frame):
  109.         gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
  110.         blur = cv2.GaussianBlur(gray, (3,3),0)
  111.         _, thresh = cv2.threshold(blur,170,255,cv2.THRESH_BINARY)
  112.         return thresh
  113.  
  114.     def predict_ocr(self, img):
  115.         data = pytesseract.image_to_data(img, config=self.ocr_config, output_type=pytesseract.Output.DICT)
  116.         for i, txt in enumerate(data['text']):
  117.             txt = txt.strip()
  118.             try: conf = int(data['conf'][i])
  119.             except: conf = 0
  120.             if re.fullmatch(r'\d{1,3}', txt) and conf >= MIN_OCR_CONF:
  121.                 return txt, conf
  122.         return None, 0
  123.  
  124.     def predict_ml(self, img):
  125.         resized = cv2.resize(img, IMG_SIZE).flatten().reshape(1,-1)
  126.         pred = self.model.predict(resized)[0]
  127.         conf = int(max(self.model.predict_proba(resized)[0]) *100) if hasattr(self.model,'predict_proba') else 0
  128.         return str(pred), conf
  129.  
  130.     def ensemble_predict(self, img):
  131.         # Only use OCR; disable ML fallback to avoid false positives on empty or noisy frames
  132.         num_o, conf_o = self.predict_ocr(img)
  133.         # Accept only OCR predictions with sufficient confidence
  134.         if num_o and conf_o >= MIN_OCR_CONF:
  135.             return num_o, conf_o
  136.         return None, 0
  137.  
  138.     def save_feedback(self, correct):
  139.         if not self.current: return
  140.         num, conf, thresh = self.current
  141.         ts = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
  142.         if correct:
  143.             # log and unfreeze
  144.             with open(self.feedback_file,'a') as f:
  145.                 f.write(f"{num},{conf},correct,{ts}\n")
  146.             folder=os.path.join(DATASET_DIR,num)
  147.             os.makedirs(folder,exist_ok=True)
  148.             cv2.imwrite(os.path.join(folder,f"{ts}.png"), self.current_frame_raw)
  149.             self.new_data.append((thresh,int(num)))
  150.             if len(self.new_data)>=ACTIVE_LEARNING_BATCH: self.retrain_model()
  151.             self.freeze_mode=False; self.pending_feedback=False; self.manual_mode=False
  152.             self.entry.grid_remove(); self.btn_submit.grid_remove()
  153.             self.label_info.config(text=f"✅ Saved correct: {num}")
  154.         else:
  155.             # freeze and enter manual mode
  156.             self.freeze_mode=True; self.pending_feedback=True; self.manual_mode=False
  157.             self.label_info.config(text=f"❌ Wrong. Will retry or await correction.")
  158.  
  159.     def manual_correction(self,val):
  160.         if not self.current: return
  161.         val=val.strip()
  162.         if not(val.isdigit() or val.lower()=='nothing'):
  163.             self.label_info.config(text="âš ī¸ Enter digit or 'nothing'")
  164.             return
  165.         num_old,conf_old,thresh=self.current
  166.         ts=datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
  167.         folder=val if val.isdigit() else 'unlabeled'
  168.         path=os.path.join(DATASET_DIR,folder)
  169.         os.makedirs(path,exist_ok=True)
  170.         cv2.imwrite(os.path.join(path,f"{ts}.png"),self.current_frame_raw)
  171.         with open(self.feedback_file,'a') as f:
  172.             f.write(f"{val},{conf_old},manual,{ts}\n")
  173.         # update
  174.         self.current=(val,conf_old,self.current_thresh)
  175.         self.freeze_mode=False; self.pending_feedback=False; self.manual_mode=False
  176.         self.entry.grid_remove(); self.btn_submit.grid_remove()
  177.         self.label_info.config(text=f"📤 Correction saved: {val}")
  178.  
  179.     def undo(self):
  180.         if not self.history: return
  181.         num,ts,thresh=self.history.pop()
  182.         self.label_info.config(text=f"⟲ Undone: {num} at {ts}")
  183.         self.display_images(thresh)
  184.  
  185.     def retrain_model(self):
  186.         X,y=[],[]
  187.         for label in os.listdir(DATASET_DIR):
  188.             if label=='unlabeled': continue
  189.             for fn in os.listdir(os.path.join(DATASET_DIR,label)):
  190.                 img=cv2.imread(os.path.join(DATASET_DIR,label,fn),cv2.IMREAD_GRAYSCALE)
  191.                 if img is None: continue
  192.                 X.append(cv2.resize(img,IMG_SIZE).flatten()); y.append(int(label))
  193.         for img,lbl in self.new_data:
  194.             X.append(img.flatten()); y.append(lbl)
  195.         if X:
  196.             self.model=RandomForestClassifier(n_estimators=100)
  197.             self.model.fit(np.array(X),np.array(y))
  198.             joblib.dump(self.model,MODEL_PATH)
  199.             self.new_data.clear()
  200.             self.label_info.config(text="🤖 Model retrained.")
  201.  
  202.     def display_images(self,proc):
  203.         frame=cv2.resize(self.last_frame,(600,450))
  204.         cam=ImageTk.PhotoImage(Image.fromarray(cv2.cvtColor(frame,cv2.COLOR_BGR2RGB)))
  205.         pr=ImageTk.PhotoImage(Image.fromarray(proc))
  206.         self.label_cam.config(image=cam); self.label_cam.imgtk=cam
  207.         self.label_proc.config(image=pr); self.label_proc.imgtk=pr
  208.  
  209.     def update_frames(self):
  210.         now=time.time()
  211.         # Freeze & retry
  212.         if self.freeze_mode and self.current and not self.manual_mode:
  213.             if now-self.last_scan_time>SCAN_INTERVAL:
  214.                 # retry detection on saved thresh
  215.                 num,conf=self.ensemble_predict(self.current_thresh)
  216.                 self.current=(num,conf,self.current_thresh)
  217.                 self.label_info.config(text=f"🔁 Retry: Was {num} correct?")
  218.                 self.last_scan_time=now
  219.                 # after first retry allow manual
  220.                 self.manual_mode=True
  221.                 self.entry.grid(); self.btn_submit.grid()
  222.             self.display_images(self.current[2])
  223.             self.root.after(30,self.update_frames)
  224.             return
  225.  
  226.         # live capture
  227.         ret,frame=self.cap.read()
  228.         if not ret:
  229.             self.root.after(30,self.update_frames); return
  230.         self.last_frame=frame
  231.         proc=self.preprocess(frame)
  232.         self.current_frame_raw=frame.copy()
  233.         self.current_thresh=proc
  234.  
  235.         # skip low-content
  236.         if cv2.countNonZero(proc)<MIN_WHITE_PIXELS:
  237.             self.label_info.config(text="📷 Weak image, skipping.")
  238.             self.display_images(proc)
  239.             self.root.after(30,self.update_frames)
  240.             return
  241.  
  242.         # initial detection
  243.         if not self.pending_feedback:
  244.             # validate detection
  245.             num_o,conf_o=self.predict_ocr(proc)
  246.             num_m,conf_m=self.predict_ml(proc)
  247.             valid_o=num_o and conf_o>=MIN_OCR_CONF
  248.             valid_m=num_m and conf_m>=MIN_ML_CONF
  249.             if not(valid_o or valid_m):
  250.                 self.root.after(30,self.update_frames); return
  251.             num,conf=(num_o,conf_o) if valid_o and conf_o>=conf_m else (num_m,conf_m)
  252.             # freeze and prompt
  253.             ts=datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
  254.             cv2.imwrite(f"detections/detected_{ts}.png",frame)
  255.             self.history.append((num,ts,proc))
  256.             self.current=(num,conf,proc)
  257.             self.last_detected, self.last_time = num, now
  258.             self.pending_feedback=True; self.freeze_mode=True; self.manual_mode=False
  259.             self.label_info.config(text=f"❓ Was {num} correct?")
  260.             self.entry.delete(0,tk.END); self.entry.grid(); self.btn_submit.grid()
  261.  
  262.         self.display_images(proc)
  263.         self.root.after(30,self.update_frames)
  264.  
  265. if __name__=="__main__":
  266.     DigitDetectorApp()
  267.  
Advertisement
Add Comment
Please, Sign In to add comment