Not a member of Pastebin yet?
Sign up — it unlocks many cool features!
import os
import time
import tkinter as tk
import warnings
from tkinter import ttk

import cv2
import joblib
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from PIL import Image, ImageTk
# ---------------------------
# CRNN Model Definition
# ---------------------------
class BidirectionalLSTM(nn.Module):
    """Bidirectional LSTM followed by a per-time-step linear projection.

    Args:
        nIn: size of each input feature vector.
        nHidden: hidden size of each LSTM direction.
        nOut: size of each projected output vector.

    The LSTM is built without ``batch_first``, so input is sequence-first,
    ``(T, b, nIn)``, and the output is ``(T, b, nOut)``.
    """

    def __init__(self, nIn, nHidden, nOut):
        super().__init__()
        # These attribute names become state_dict keys ("rnn.*", "embedding.*");
        # renaming them would break loading of saved checkpoints.
        self.rnn = nn.LSTM(nIn, nHidden, bidirectional=True)
        # Forward and backward hidden states are concatenated: 2 * nHidden.
        self.embedding = nn.Linear(nHidden * 2, nOut)

    def forward(self, input):
        seq_out = self.rnn(input)[0]
        steps, batch, feat = seq_out.size()
        # Fold time and batch into one axis so the Linear layer sees a 2-D matrix.
        flat = seq_out.view(steps * batch, feat)
        projected = self.embedding(flat)
        return projected.view(steps, batch, -1)
class CRNN(nn.Module):
    """Convolutional recurrent network (CNN + BiLSTM) producing per-column scores.

    A 7-layer convolutional stack reduces the image to a feature map of
    height 1. Each remaining column becomes one time step, which is fed
    through two stacked ``BidirectionalLSTM`` layers.

    Args:
        imgH: nominal input height. It is not used to build any layer. The
            fixed pooling schedule below reduces only some heights to 1
            (32, as used by the app, works; 64 does not), and ``forward``
            rejects the others.
        nc: number of input channels (the app passes 1 for grayscale).
        nclass: number of output scores per time step.
        nh: hidden size of the LSTM layers.

    Module names (``cnn.conv0``, ``cnn.batchnorm2``, ``rnn.0`` ...) are the
    ``state_dict`` keys of saved checkpoints, so they must not change.
    """

    def __init__(self, imgH, nc, nclass, nh):
        super().__init__()
        # Per-conv-layer kernel size, padding, stride and output channels.
        ks = [3, 3, 3, 3, 3, 3, 2]
        ps = [1, 1, 1, 1, 1, 1, 0]
        ss = [1, 1, 1, 1, 1, 1, 1]
        nm = [64, 128, 256, 256, 512, 512, 512]

        cnn = nn.Sequential()

        def convRelu(i, batchNorm=False):
            """Append conv{i}, optionally batchnorm{i}, then relu{i} to ``cnn``."""
            nIn = nc if i == 0 else nm[i - 1]
            nOut = nm[i]
            cnn.add_module(f'conv{i}', nn.Conv2d(nIn, nOut, ks[i], ss[i], ps[i]))
            if batchNorm:
                cnn.add_module(f'batchnorm{i}', nn.BatchNorm2d(nOut))
            cnn.add_module(f'relu{i}', nn.ReLU(True))

        convRelu(0)
        cnn.add_module('pool0', nn.MaxPool2d(2, 2))
        convRelu(1)
        cnn.add_module('pool1', nn.MaxPool2d(2, 2))
        convRelu(2, True)
        convRelu(3)
        # From here on, pooling halves the height but grows the width by one:
        # stride (2, 1) plus width padding 1 keeps horizontal resolution,
        # i.e. the number of output time steps.
        cnn.add_module('pool2', nn.MaxPool2d((2, 2), (2, 1), (0, 1)))
        convRelu(4, True)
        convRelu(5)
        cnn.add_module('pool3', nn.MaxPool2d((2, 2), (2, 1), (0, 1)))
        convRelu(6, True)
        self.cnn = cnn

        self.rnn = nn.Sequential(
            BidirectionalLSTM(512, nh, nh),
            BidirectionalLSTM(nh, nh, nclass),
        )

    def forward(self, x):
        """Map images ``(b, nc, H, W)`` to scores ``(T, b, nclass)``.

        T is the width of the final feature map.

        Raises:
            ValueError: if the CNN does not reduce the height to exactly 1.
        """
        conv = self.cnn(x)
        h = conv.size(2)
        if h != 1:
            # Explicit check rather than `assert`, which is stripped under -O.
            # Without the check, squeeze(2) would keep the height axis and the
            # permute below would fail with an unrelated-looking rank error.
            raise ValueError(
                f'CNN output height must be 1, got {h} '
                f'(input height {x.size(2)}); resize inputs to 32 pixels high'
            )
        seq = conv.squeeze(2).permute(2, 0, 1)  # (w, b, c): width becomes time
        return self.rnn(seq)
# ---------------------------
# Decode & Utilities
# ---------------------------
CHARS = '0123456789'


def ctc_decode(pred):
    """Greedy (best-path) CTC decoding of raw network scores.

    Args:
        pred: score tensor of shape ``(T, b, C)``, time-major as returned by
            ``CRNN.forward``. Class index ``i < len(CHARS)`` maps to
            ``CHARS[i]``. Any index ``>= len(CHARS)`` is treated as the
            blank; the app builds the model with ``len(CHARS) + 1`` classes.

    Returns:
        The decoded string. When ``b > 1``, each sequence is decoded on its
        own and the results are concatenated in batch order, so symbols never
        merge across sample boundaries.
    """
    # Softmax is monotonic, so arg-max over the raw scores gives the same best
    # path. Converting to Python ints once avoids a device sync per element
    # when the tensor lives on CUDA.
    paths = pred.argmax(2).transpose(0, 1).tolist()  # b lists of T indices
    n_chars = len(CHARS)
    chars = []
    for path in paths:
        prev = None  # reset per sample
        for i in path:
            # Emit on a change of symbol, skipping blanks. `prev` is updated on
            # every step, so a blank between two equal symbols keeps both.
            if i != prev and i < n_chars:
                chars.append(CHARS[i])
            prev = i
    return ''.join(chars)
# ---------------------------
# Application
# ---------------------------
class OCRApp:
    """Tkinter front-end for live camera OCR with human-in-the-loop correction.

    Every ~30 ms a frame is grabbed from camera 0, converted to grayscale and
    shown. Unless prediction is frozen, the frame is also run through the
    CRNN. The user can confirm a prediction (✅) or reject it (❌) and type
    the correct digits. A correction saves the frame that produced the
    rejected prediction to ``dataset/<timestamp>_<digits>.png`` and appends
    ``digits,timestamp`` to ``feedback_log.csv``.

    Constructing the object builds the window and blocks in ``mainloop()``
    until the window is closed or the camera stops delivering frames.

    Args:
        model_path: path to a CRNN ``state_dict``. If the file does not exist,
            the untrained model is used and a warning is emitted.
    """

    def __init__(self, model_path='crnn.pth'):
        # Device
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        # Model: 32-px-high grayscale input and one class per character plus
        # one extra index (ctc_decode treats indices >= len(CHARS) as blank).
        self.model = CRNN(32, 1, len(CHARS) + 1, 256).to(self.device)
        if os.path.exists(model_path):
            # NOTE(review): torch.load unpickles the file. Only load
            # checkpoints from trusted sources (newer torch versions support
            # weights_only=True for plain state_dicts).
            self.model.load_state_dict(torch.load(model_path, map_location=self.device))
        else:
            warnings.warn(
                f'Model weights not found at {model_path!r}; '
                'predictions come from an untrained network.'
            )
        self.model.eval()
        # Transform: resize to the 32-px height CRNN.forward needs, then map
        # pixel values from [0, 1] to [-1, 1].
        self.transform = transforms.Compose([
            transforms.Grayscale(),
            transforms.Resize((32, 100)),
            transforms.ToTensor(),
            transforms.Normalize((0.5,), (0.5,)),
        ])
        # Camera
        self.cap = cv2.VideoCapture(0)
        # Active-learning storage
        self.dataset_dir = 'dataset'
        os.makedirs(self.dataset_dir, exist_ok=True)
        self.feedback_file = 'feedback_log.csv'
        if not os.path.exists(self.feedback_file):
            with open(self.feedback_file, 'w', encoding='utf-8') as f:
                f.write('number,timestamp\n')
        # State
        self.freeze = False
        self.current_frame = None  # grayscale frame that produced self.pred
        self.pred = ''
        # GUI
        self.root = tk.Tk()
        self.root.title('CRNN OCR')
        self._build_gui()
        # Loop. The camera is always released, whether the window is closed
        # or update() ends the loop after a failed read.
        self.root.after(30, self.update)
        try:
            self.root.mainloop()
        finally:
            self.cap.release()

    def _build_gui(self):
        """Create the widgets.

        Widgets are the live view, the model-input preview, the status line,
        the ✅/❌ buttons and the hidden correction row.
        """
        self.cam_label = ttk.Label(self.root)
        self.cam_label.grid(row=0, column=0)
        self.proc_label = ttk.Label(self.root)
        self.proc_label.grid(row=0, column=1)
        self.info = ttk.Label(self.root, text='Waiting...')
        self.info.grid(row=1, column=0, columnspan=2)
        btn_frame = ttk.Frame(self.root)
        btn_frame.grid(row=2, column=0, columnspan=2)
        ttk.Button(btn_frame, text='✅', command=lambda: self.on_feedback(True)).pack(side='left')
        ttk.Button(btn_frame, text='❌', command=lambda: self.on_feedback(False)).pack(side='left')
        # The correction widgets are placed once and then hidden. grid_remove
        # keeps their grid options, so a bare .grid() re-shows them in place.
        self.entry = ttk.Entry(self.root)
        self.submit = ttk.Button(self.root, text='Submit', command=self.on_submit)
        self.entry.grid(row=3, column=0)
        self.submit.grid(row=3, column=1)
        self._hide_correction()

    def _show_correction(self):
        """Clear and show the correction entry and its Submit button."""
        self.entry.delete(0, 'end')
        self.entry.grid()
        self.submit.grid()

    def _hide_correction(self):
        """Hide the correction entry and its Submit button."""
        self.entry.grid_remove()
        self.submit.grid_remove()

    def on_feedback(self, ok):
        """Handle the two feedback buttons.

        ✅ (ok=True) accepts the prediction and resumes. ❌ (ok=False) freezes
        the prediction and asks for a correction.
        """
        if ok:
            self.freeze = False
            # Hide the correction row in case ✅ is pressed after ❌.
            self._hide_correction()
            self.info.config(text=f'✅ Confirmed: {self.pred}')
        else:
            self.freeze = True
            self.info.config(text='❌ Wrong, enter correction:')
            self._show_correction()

    def on_submit(self):
        """Save the frozen frame under the typed label and resume live prediction."""
        val = self.entry.get().strip()
        # Only characters the model can emit are valid labels. str.isdigit()
        # alone would also accept e.g. '²' or '٣'.
        if not val or any(c not in CHARS for c in val):
            self.info.config(text='❌ Enter digits only (0-9):')
            return
        if self.current_frame is None:
            self.info.config(text='No frame captured yet')
            return
        ts = time.strftime('%Y%m%d_%H%M%S')
        fname = os.path.join(self.dataset_dir, f'{ts}_{val}.png')
        # cv2.imwrite reports failure through its return value.
        if not cv2.imwrite(fname, self.current_frame):
            self.info.config(text=f'⚠️ Could not save image to {fname}')
            return
        with open(self.feedback_file, 'a', encoding='utf-8') as f:
            f.write(f'{val},{ts}\n')
        self.info.config(text=f'📤 Saved correction: {val}')
        self._hide_correction()
        self.freeze = False

    def update(self):
        """Process one camera tick and reschedule itself.

        Grabs one frame, refreshes the live view and, unless frozen, predicts
        on the frame.
        """
        ret, frame = self.cap.read()
        if not ret:
            # Capture failed: leave mainloop (__init__ releases the camera).
            self.root.quit()
            return
        gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
        img = Image.fromarray(gray)
        tkimg = ImageTk.PhotoImage(image=img)
        self.cam_label.config(image=tkimg)
        # Keep a Python reference so the PhotoImage is not garbage-collected
        # while Tk is still displaying it.
        self.cam_label.imgtk = tkimg
        if not self.freeze:
            # Advance current_frame only while live. While frozen it must stay
            # the exact frame whose prediction is being corrected, because
            # that is the frame on_submit saves.
            self.current_frame = gray
            # preprocess ROI: here we use full frame
            proc = self.transform(img).unsqueeze(0).to(self.device)
            with torch.no_grad():
                preds = self.model(proc)
            self.pred = ctc_decode(preds)
            self.info.config(text=f'Predicted: {self.pred}')
            # Preview (at 100x32) the frame the current prediction came from.
            proc_disp = cv2.resize(self.current_frame, (100, 32))
            tkproc = ImageTk.PhotoImage(image=Image.fromarray(proc_disp))
            self.proc_label.config(image=tkproc)
            self.proc_label.imgtk = tkproc
        self.root.after(30, self.update)
if __name__ == '__main__':
    # The OCRApp constructor builds the window and runs the Tk event loop.
    # It returns once the window is closed or the camera stops delivering
    # frames. The default checkpoint path is spelled out for visibility.
    OCRApp(model_path='crnn.pth')
Advertisement
Add Comment
Please sign in to add a comment.