Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- import cv2
- import pytesseract
- import os
- import re
- import time
- import torch
- import torch.nn as nn
- import torchvision.transforms as transforms
- from PIL import Image, ImageTk
- import tkinter as tk
- from tkinter import ttk
- # ---------------------------
- # CRNN Model Definition
- # ---------------------------
class BidirectionalLSTM(nn.Module):
    """One bidirectional LSTM layer followed by a per-time-step linear projection.

    Works time-major (``nn.LSTM`` is built without ``batch_first``):
    input ``(T, b, nIn)`` -> output ``(T, b, nOut)``.
    """

    def __init__(self, nIn, nHidden, nOut):
        super().__init__()
        # The attribute names `rnn` / `linear` become state_dict key prefixes,
        # so they must stay unchanged for saved checkpoints to keep loading.
        self.rnn = nn.LSTM(nIn, nHidden, bidirectional=True)
        # Forward and backward hidden states are concatenated -> 2 * nHidden.
        self.linear = nn.Linear(2 * nHidden, nOut)

    def forward(self, input):
        seq_out = self.rnn(input)[0]
        steps, batch, feat = seq_out.size()
        # Apply the projection to every (step, sample) pair as one 2-D batch.
        flat = seq_out.view(steps * batch, feat)
        projected = self.linear(flat)
        return projected.view(steps, batch, -1)
class CRNN(nn.Module):
    """Convolutional recurrent network (CRNN) for text-line recognition.

    A seven-layer convolutional stack reduces the image height to a single
    row. Each remaining column becomes a time step, and two stacked
    ``BidirectionalLSTM`` layers turn each step into class scores.

    Args:
        imgH: nominal input height; must be a multiple of 16. It is only
            validated here. ``forward`` additionally requires the conv output
            height to be 1, which holds for 32-pixel-high inputs.
        nc: number of input image channels.
        nclass: number of output classes per time step.
        nh: hidden size of each LSTM direction. It is also the width of the
            projection between the two LSTM layers.

    Raises:
        ValueError: if ``imgH`` is not a multiple of 16.
    """

    def __init__(self, imgH, nc, nclass, nh):
        super(CRNN, self).__init__()
        # Explicit raise instead of `assert`, so the check survives `python -O`.
        if imgH % 16 != 0:
            raise ValueError('imgH must be a multiple of 16')

        ks = [3, 3, 3, 3, 3, 3, 2]  # kernel sizes
        ps = [1, 1, 1, 1, 1, 1, 0]  # paddings
        ss = [1, 1, 1, 1, 1, 1, 1]  # strides
        nm = [64, 128, 256, 256, 512, 512, 512]  # output channels

        cnn = nn.Sequential()

        def convRelu(i, batchNormalization=False):
            # Append conv{i} [+ batchnorm{i}] + relu{i}. These module names
            # are the state_dict keys, so they must not change.
            nIn = nc if i == 0 else nm[i - 1]
            nOut = nm[i]
            cnn.add_module('conv{0}'.format(i),
                           nn.Conv2d(nIn, nOut, ks[i], ss[i], ps[i]))
            if batchNormalization:
                cnn.add_module('batchnorm{0}'.format(i), nn.BatchNorm2d(nOut))
            cnn.add_module('relu{0}'.format(i), nn.ReLU(True))

        # Shape comments are C x H x W for an H x W input (floor division).
        convRelu(0)
        cnn.add_module('pooling0', nn.MaxPool2d(2, 2))  # 64 x H/2 x W/2
        convRelu(1)
        cnn.add_module('pooling1', nn.MaxPool2d(2, 2))  # 128 x H/4 x W/4
        convRelu(2, True)
        convRelu(3)
        # The next two pools halve the height only. Width uses stride 1 with
        # padding 1, so the horizontal (future time) axis grows by one column.
        cnn.add_module('pooling2',
                       nn.MaxPool2d((2, 2), (2, 1), (0, 1)))  # 256 x H/8 x (W/4 + 1)
        convRelu(4, True)
        convRelu(5)
        cnn.add_module('pooling3',
                       nn.MaxPool2d((2, 2), (2, 1), (0, 1)))  # 512 x H/16 x (W/4 + 2)
        # 2x2 kernel, no padding: 512 x (H/16 - 1) x (W/4 + 1)
        convRelu(6, True)
        self.cnn = cnn

        self.rnn = nn.Sequential(
            BidirectionalLSTM(512, nh, nh),
            BidirectionalLSTM(nh, nh, nclass))

    def forward(self, input):
        """Map a (b, nc, H, W) image batch to (T, b, nclass) scores.

        T is the width of the conv feature map.

        Raises:
            ValueError: if the conv feature map is not exactly one row high.
        """
        conv = self.cnn(input)
        b, c, h, w = conv.size()
        # Explicit raise instead of `assert`, so the check survives `python -O`.
        if h != 1:
            raise ValueError('the height of conv must be 1, got {0}'.format(h))
        conv = conv.squeeze(2)  # b x c x w
        conv = conv.permute(2, 0, 1)  # w x b x c
        output = self.rnn(conv)  # w x b x nclass
        return output
- # ---------------------------
- # Utility Functions
- # ---------------------------
# Recognisable alphabet: the ten decimal digits, in class-index order.
CHARS = '0123456789'
# Output classes = one per character plus one extra class for the CTC blank.
NUM_CLASSES = 1 + len(CHARS)
def decode_preds(preds, chars=None, *, per_sample=False):
    """Greedy (best-path) CTC decoding of CRNN scores.

    At every time step the highest-scoring class is taken. Consecutive
    repeats are collapsed and blank steps are dropped, so a blank between
    two identical characters keeps both (e.g. ``1 1 <blank> 1`` -> ``"11"``).

    Args:
        preds: tensor of shape ``(T, b, nclass)``, time-major as returned by
            ``CRNN.forward``. Logits or probabilities both work, because only
            the per-step argmax is used.
        chars: alphabet. Class ``i < len(chars)`` decodes to ``chars[i]``;
            every higher index is treated as the CTC blank. Defaults to
            ``CHARS``.
        per_sample: if True, return a list with one string per batch element.

    Returns:
        With the default ``per_sample=False``, a single ``str``: the batch
        elements' transcriptions concatenated (for a batch of one, simply
        that image's text). With ``per_sample=True``, a ``list[str]``.
    """
    alphabet = CHARS if chars is None else chars
    blank_start = len(alphabet)
    # softmax is monotonic, so argmax over the raw scores selects the same classes.
    # Decode each batch element on its own. Flattening the batch into one
    # sequence would collapse a repeated character across two different samples.
    paths = preds.argmax(2).transpose(0, 1).tolist()  # b lists of T ints

    texts = []
    for path in paths:
        out = []
        prev = None
        for k in path:
            # Emit on a change of class, skipping blanks. `prev` is updated on
            # blanks too, so "a <blank> a" yields two characters.
            if k != prev and k < blank_start:
                out.append(alphabet[k])
            prev = k
        texts.append(''.join(out))
    return texts if per_sample else ''.join(texts)
- # ---------------------------
- # GUI Application
- # ---------------------------
class OCRApp:
    """Tkinter front-end around the CRNN digit recognizer and webcam 0.

    Constructing an instance builds the model, opens the camera, builds the
    window and then enters the Tk main loop. ``OCRApp()`` therefore does not
    return until the window is closed.
    """

    def __init__(self, model_path='crnn.pth'):
        """Set up model, preprocessing, camera and GUI, then run the Tk main loop.

        Args:
            model_path: path of a saved CRNN state_dict. If the file does not
                exist, the model silently keeps its random initial weights.
        """
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        # 32-pixel-high, single-channel model with 256 LSTM hidden units.
        self.model = CRNN(32, 1, NUM_CLASSES, 256).to(self.device)
        if os.path.exists(model_path):
            # NOTE(review): torch.load unpickles the file, which can execute
            # arbitrary code. Only load trusted checkpoints; newer PyTorch
            # versions accept weights_only=True for this.
            self.model.load_state_dict(torch.load(model_path, map_location=self.device))
        self.model.eval()
        # Resize to 32x100 (height x width), convert to a tensor, then map
        # [0, 1] to [-1, 1]. A single mean/std pair means one channel is
        # expected (presumably a grayscale PIL image; confirm at the call site).
        self.transform = transforms.Compose([
            transforms.Resize((32, 100)),
            transforms.ToTensor(),
            transforms.Normalize((0.5,), (0.5,))
        ])
        # NOTE(review): no self.cap.release() is visible in this listing.
        self.cap = cv2.VideoCapture(0)
        self.root = tk.Tk()
        self.root.title('CRNN OCR')
        self._build_gui()
        # True while the user has marked the last result as wrong (see feedback).
        self.pending = False
        self.last_frame = None
        # NOTE(review): `update` is not defined in this listing. It is presumably
        # the per-frame capture/recognition callback; confirm it exists.
        self.root.after(30, self.update)
        # Blocks until the window is closed.
        self.root.mainloop()

    def _build_gui(self):
        """Lay out the widgets on a grid.

        Row 0: two image labels side by side.
        Row 1: status text.
        Row 2: correct/incorrect feedback buttons.
        Row 3: correction entry and Submit button, hidden until
        ``feedback(False)`` shows them.
        """
        self.frame_cam = ttk.Label(self.root)  # presumably shows the raw camera frame
        self.frame_cam.grid(row=0, column=0)
        self.frame_proc = ttk.Label(self.root)  # presumably shows the processed image
        self.frame_proc.grid(row=0, column=1)
        self.info_label = ttk.Label(self.root, text='Waiting...')
        self.info_label.grid(row=1, column=0, columnspan=2)
        btn_frame = ttk.Frame(self.root)
        btn_frame.grid(row=2, column=0, columnspan=2)
        ttk.Button(btn_frame, text='✅', command=lambda: self.feedback(True)).pack(side='left')
        ttk.Button(btn_frame, text='❌', command=lambda: self.feedback(False)).pack(side='left')
        self.entry = ttk.Entry(self.root)
        self.submit = ttk.Button(self.root, text='Submit', command=self.manual)
        self.entry.grid(row=3, column=0)
        self.submit.grid(row=3, column=1)
        # grid_remove() hides the widgets but remembers their grid options, so
        # a later bare .grid() call puts them back in the same cells.
        self.entry.grid_remove(); self.submit.grid_remove()

    def feedback(self, correct):
        """Handle the correct/incorrect buttons.

        Args:
            correct: True if the user confirmed the displayed result.

        When the result is wrong, marks a correction as pending and reveals
        the correction entry. Otherwise it hides the entry again.
        """
        # placeholder for active learning
        self.pending = not correct
        if not correct:
            self.entry.grid(); self.submit.grid()
        else:
            self.entry.grid_remove(); self.submit.grid_remove()

    def manual(self):
        """Submit-button handler for a manually typed correction."""
        val = self.entry.get()
        # save correction...
        # NOTE(review): `val` is not used yet; persisting the correction is still a TODO.
        self.entry.delete(0, 'end')
        # NOTE(review): bare attribute access (no call), so this line has no
        # effect. The listing appears truncated here; presumably the intent
        # was to hide the entry/submit widgets again.
        self.entry.grid
Advertisement
Add Comment
Please, Sign In to add comment