Str1k3rch0

DDA 0.02

Jul 21st, 2025
6
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 6.52 KB | None | 0 0
  1. import cv2
  2. import os
  3. import time
  4. import torch
  5. import torch.nn as nn
  6. import torchvision.transforms as transforms
  7. from PIL import Image, ImageTk
  8. import tkinter as tk
  9. from tkinter import ttk
  10. import joblib
  11.  
  12. # ---------------------------
  13. # CRNN Model Definition
  14. # ---------------------------
  15. class BidirectionalLSTM(nn.Module):
  16. def __init__(self, nIn, nHidden, nOut):
  17. super(BidirectionalLSTM, self).__init__()
  18. self.rnn = nn.LSTM(nIn, nHidden, bidirectional=True)
  19. self.embedding = nn.Linear(nHidden * 2, nOut)
  20.  
  21. def forward(self, input):
  22. recurrent, _ = self.rnn(input)
  23. T, b, h = recurrent.size()
  24. t_rec = recurrent.view(T * b, h)
  25. output = self.embedding(t_rec)
  26. output = output.view(T, b, -1)
  27. return output
  28.  
  29. class CRNN(nn.Module):
  30. def __init__(self, imgH, nc, nclass, nh):
  31. super(CRNN, self).__init__()
  32. ks = [3,3,3,3,3,3,2]
  33. ps = [1,1,1,1,1,1,0]
  34. ss = [1,1,1,1,1,1,1]
  35. nm = [64,128,256,256,512,512,512]
  36. cnn = nn.Sequential()
  37. def convRelu(i, batchNorm=False):
  38. nIn = nc if i==0 else nm[i-1]
  39. nOut = nm[i]
  40. cnn.add_module(f'conv{i}', nn.Conv2d(nIn,nOut,ks[i],ss[i],ps[i]))
  41. if batchNorm: cnn.add_module(f'batchnorm{i}', nn.BatchNorm2d(nOut))
  42. cnn.add_module(f'relu{i}', nn.ReLU(True))
  43. convRelu(0)
  44. cnn.add_module('pool0', nn.MaxPool2d(2,2))
  45. convRelu(1)
  46. cnn.add_module('pool1', nn.MaxPool2d(2,2))
  47. convRelu(2, True); convRelu(3)
  48. cnn.add_module('pool2', nn.MaxPool2d((2,2),(2,1),(0,1)))
  49. convRelu(4, True); convRelu(5)
  50. cnn.add_module('pool3', nn.MaxPool2d((2,2),(2,1),(0,1)))
  51. convRelu(6, True)
  52. self.cnn = cnn
  53. self.rnn = nn.Sequential(
  54. BidirectionalLSTM(512, nh, nh),
  55. BidirectionalLSTM(nh, nh, nclass)
  56. )
  57.  
  58. def forward(self, x):
  59. x = self.cnn(x)
  60. b, c, h, w = x.size()
  61. assert h==1, 'height must be 1'
  62. x = x.squeeze(2).permute(2,0,1) # w x b x c
  63. x = self.rnn(x)
  64. return x
  65.  
  66. # ---------------------------
  67. # Decode & Utilities
  68. # ---------------------------
  69. CHARS = '0123456789'
  70. def ctc_decode(pred):
  71. pred = pred.softmax(2)
  72. max_prob, idx = pred.max(2)
  73. idx = idx.transpose(1,0).contiguous().view(-1)
  74. chars = []
  75. prev = -1
  76. for i in idx:
  77. if i!=prev and i<len(CHARS):
  78. chars.append(CHARS[i])
  79. prev = i
  80. return ''.join(chars)
  81.  
  82. # ---------------------------
  83. # Application
  84. # ---------------------------
  85. class OCRApp:
  86. def __init__(self, model_path='crnn.pth'):
  87. # Device
  88. self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
  89. # Model
  90. self.model = CRNN(32,1,len(CHARS)+1,256).to(self.device)
  91. if os.path.exists(model_path):
  92. self.model.load_state_dict(torch.load(model_path, map_location=self.device))
  93. self.model.eval()
  94. # Transform
  95. self.transform = transforms.Compose([
  96. transforms.Grayscale(),
  97. transforms.Resize((32,100)),
  98. transforms.ToTensor(),
  99. transforms.Normalize((0.5,),(0.5,))
  100. ])
  101. # Camera
  102. self.cap = cv2.VideoCapture(0)
  103. # Active Learning storage
  104. self.dataset_dir = 'dataset'
  105. os.makedirs(self.dataset_dir, exist_ok=True)
  106. self.feedback_file = 'feedback_log.csv'
  107. if not os.path.exists(self.feedback_file):
  108. open(self.feedback_file,'w').write('number,timestamp\n')
  109. # GUI
  110. self.root = tk.Tk()
  111. self.root.title('CRNN OCR')
  112. self._build_gui()
  113. # State
  114. self.freeze = False
  115. self.current_frame = None
  116. self.pred = ''
  117. # Loop
  118. self.root.after(30, self.update)
  119. self.root.mainloop()
  120.  
  121. def _build_gui(self):
  122. self.cam_label = ttk.Label(self.root)
  123. self.cam_label.grid(row=0,column=0)
  124. self.proc_label = ttk.Label(self.root)
  125. self.proc_label.grid(row=0,column=1)
  126. self.info = ttk.Label(self.root, text='Waiting...')
  127. self.info.grid(row=1,column=0,columnspan=2)
  128. btn_frame = ttk.Frame(self.root)
  129. btn_frame.grid(row=2,column=0,columnspan=2)
  130. ttk.Button(btn_frame,text='✅',command=lambda: self.on_feedback(True)).pack(side='left')
  131. ttk.Button(btn_frame,text='❌',command=lambda: self.on_feedback(False)).pack(side='left')
  132. self.entry = ttk.Entry(self.root)
  133. self.submit = ttk.Button(self.root,text='Submit',command=self.on_submit)
  134. self.entry.grid(row=3,column=0); self.submit.grid(row=3,column=1)
  135. self.entry.grid_remove(); self.submit.grid_remove()
  136.  
  137. def on_feedback(self, ok):
  138. if ok:
  139. self.freeze=False
  140. self.info.config(text=f'✅ Confirmed: {self.pred}')
  141. else:
  142. self.freeze=True
  143. self.info.config(text='❌ Wrong, enter correction:')
  144. self.entry.delete(0, 'end'); self.entry.grid(); self.submit.grid()
  145.  
  146. def on_submit(self):
  147. val = self.entry.get().strip()
  148. if val.isdigit():
  149. ts = time.strftime('%Y%m%d_%H%M%S')
  150. fname = os.path.join(self.dataset_dir,f'{ts}_{val}.png')
  151. cv2.imwrite(fname, self.current_frame)
  152. open(self.feedback_file,'a').write(f'{val},{ts}\n')
  153. self.info.config(text=f'📤 Saved correction: {val}')
  154. self.entry.grid_remove(); self.submit.grid_remove()
  155. self.freeze=False
  156.  
  157. def update(self):
  158. ret, frame = self.cap.read()
  159. if not ret:
  160. self.root.quit(); return
  161. self.current_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
  162. img = Image.fromarray(self.current_frame)
  163. tkimg = ImageTk.PhotoImage(image=img)
  164. self.cam_label.config(image=tkimg); self.cam_label.imgtk = tkimg
  165. if not self.freeze:
  166. # preprocess ROI: here we use full frame
  167. proc = self.transform(img).unsqueeze(0).to(self.device)
  168. with torch.no_grad():
  169. preds = self.model(proc)
  170. self.pred = ctc_decode(preds)
  171. self.info.config(text=f'Predicted: {self.pred}')
  172. proc_disp = cv2.resize(self.current_frame,(100,32))
  173. tkproc = ImageTk.PhotoImage(image=Image.fromarray(proc_disp))
  174. self.proc_label.config(image=tkproc); self.proc_label.imgtk = tkproc
  175. self.root.after(30, self.update)
  176.  
  177. if __name__=='__main__':
  178. OCRApp()
  179.  
Advertisement
Add Comment
Please, Sign In to add comment