Str1k3rch0

DDA 0.01

Jul 21st, 2025
214
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 5.27 KB | Software | 0 0
  1. import cv2
  2. import pytesseract
  3. import os
  4. import re
  5. import time
  6. import torch
  7. import torch.nn as nn
  8. import torchvision.transforms as transforms
  9. from PIL import Image, ImageTk
  10. import tkinter as tk
  11. from tkinter import ttk
  12.  
  13. # ---------------------------
  14. # CRNN Model Definition
  15. # ---------------------------
  16. class BidirectionalLSTM(nn.Module):
  17.     def __init__(self, nIn, nHidden, nOut):
  18.         super(BidirectionalLSTM, self).__init__()
  19.         self.rnn = nn.LSTM(nIn, nHidden, bidirectional=True)
  20.         self.linear = nn.Linear(nHidden * 2, nOut)
  21.  
  22.     def forward(self, input):
  23.         recurrent, _ = self.rnn(input)
  24.         T, b, h = recurrent.size()
  25.         t_rec = recurrent.view(T * b, h)
  26.         output = self.linear(t_rec)
  27.         output = output.view(T, b, -1)
  28.         return output
  29.  
  30. class CRNN(nn.Module):
  31.     def __init__(self, imgH, nc, nclass, nh):
  32.         super(CRNN, self).__init__()
  33.         assert imgH % 16 == 0, 'imgH must be a multiple of 16'
  34.         ks = [3, 3, 3, 3, 3, 3, 2]
  35.         ps = [1, 1, 1, 1, 1, 1, 0]
  36.         ss = [1, 1, 1, 1, 1, 1, 1]
  37.         nm = [64, 128, 256, 256, 512, 512, 512]
  38.  
  39.         cnn = nn.Sequential()
  40.         def convRelu(i, batchNormalization=False):
  41.             nIn = nc if i == 0 else nm[i - 1]
  42.             nOut = nm[i]
  43.             cnn.add_module('conv{0}'.format(i), nn.Conv2d(nIn, nOut, ks[i], ss[i], ps[i]))
  44.             if batchNormalization:
  45.                 cnn.add_module('batchnorm{0}'.format(i), nn.BatchNorm2d(nOut))
  46.             cnn.add_module('relu{0}'.format(i), nn.ReLU(True))
  47.  
  48.         convRelu(0)
  49.         cnn.add_module('pooling0', nn.MaxPool2d(2, 2))  # 64x(H/2)x(W/2)
  50.         convRelu(1)
  51.         cnn.add_module('pooling1', nn.MaxPool2d(2, 2))  # 128x(H/4)x(W/4)
  52.         convRelu(2, True)
  53.         convRelu(3)
  54.         cnn.add_module('pooling2', nn.MaxPool2d((2, 2), (2, 1), (0, 1)))  # 256x(H/8)x(W/4)
  55.         convRelu(4, True)
  56.         convRelu(5)
  57.         cnn.add_module('pooling3', nn.MaxPool2d((2, 2), (2, 1), (0, 1)))  # 512x(H/16)x(W/4)
  58.         convRelu(6, True)
  59.  
  60.         self.cnn = cnn
  61.         self.rnn = nn.Sequential(
  62.             BidirectionalLSTM(512, nh, nh),
  63.             BidirectionalLSTM(nh, nh, nclass))
  64.  
  65.     def forward(self, input):
  66.         conv = self.cnn(input)
  67.         b, c, h, w = conv.size()
  68.         assert h == 1, "the height of conv must be 1"
  69.         conv = conv.squeeze(2)  # b x c x w
  70.         conv = conv.permute(2, 0, 1)  # w x b x c
  71.         output = self.rnn(conv)  # w x b x nclass
  72.         return output
  73.  
  74. # ---------------------------
  75. # Utility Functions
  76. # ---------------------------
  77. CHARS = '0123456789'
  78. NUM_CLASSES = len(CHARS) + 1  # for CTC blank
  79.  
  80. def decode_preds(preds):
  81.     # CTC greedy decode
  82.     preds = preds.softmax(2)
  83.     max_probs, idx = preds.max(2)
  84.     idx = idx.transpose(1, 0).contiguous().view(-1)
  85.     raw_chars = [CHARS[i] if i < len(CHARS) else '' for i in idx]
  86.     # collapse repeats
  87.     result = []
  88.     prev = ''
  89.     for ch in raw_chars:
  90.         if ch != prev:
  91.             result.append(ch)
  92.             prev = ch
  93.     return ''.join(result).replace('', '')
  94.  
  95. # ---------------------------
  96. # GUI Application
  97. # ---------------------------
  98. class OCRApp:
  99.     def __init__(self, model_path='crnn.pth'):
  100.         self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
  101.         self.model = CRNN(32, 1, NUM_CLASSES, 256).to(self.device)
  102.         if os.path.exists(model_path):
  103.             self.model.load_state_dict(torch.load(model_path, map_location=self.device))
  104.         self.model.eval()
  105.  
  106.         self.transform = transforms.Compose([
  107.             transforms.Resize((32, 100)),
  108.             transforms.ToTensor(),
  109.             transforms.Normalize((0.5,), (0.5,))
  110.         ])
  111.  
  112.         self.cap = cv2.VideoCapture(0)
  113.         self.root = tk.Tk()
  114.         self.root.title('CRNN OCR')
  115.         self._build_gui()
  116.         self.pending = False
  117.         self.last_frame = None
  118.         self.root.after(30, self.update)
  119.         self.root.mainloop()
  120.  
  121.     def _build_gui(self):
  122.         self.frame_cam = ttk.Label(self.root)
  123.         self.frame_cam.grid(row=0, column=0)
  124.         self.frame_proc = ttk.Label(self.root)
  125.         self.frame_proc.grid(row=0, column=1)
  126.  
  127.         self.info_label = ttk.Label(self.root, text='Waiting...')
  128.         self.info_label.grid(row=1, column=0, columnspan=2)
  129.  
  130.         btn_frame = ttk.Frame(self.root)
  131.         btn_frame.grid(row=2, column=0, columnspan=2)
  132.         ttk.Button(btn_frame, text='✅', command=lambda: self.feedback(True)).pack(side='left')
  133.         ttk.Button(btn_frame, text='❌', command=lambda: self.feedback(False)).pack(side='left')
  134.         self.entry = ttk.Entry(self.root)
  135.         self.submit = ttk.Button(self.root, text='Submit', command=self.manual)
  136.         self.entry.grid(row=3, column=0)
  137.         self.submit.grid(row=3, column=1)
  138.         self.entry.grid_remove(); self.submit.grid_remove()
  139.  
  140.     def feedback(self, correct):
  141.         # placeholder for active learning
  142.         self.pending = not correct
  143.         if not correct:
  144.             self.entry.grid(); self.submit.grid()
  145.         else:
  146.             self.entry.grid_remove(); self.submit.grid_remove()
  147.  
  148.     def manual(self):
  149.         val = self.entry.get()
  150.         # save correction...
  151.         self.entry.delete(0, 'end')
  152.         self.entry.grid
Advertisement
Add Comment
Please, Sign In to add comment