Advertisement
Guest User

Untitled

a guest
Nov 20th, 2017
172
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 10.00 KB | None | 0 0
import os
import random
import sys
import threading
import time

import cv2
import matplotlib.pyplot as plt
import numpy as np
from skimage import feature
from sklearn.preprocessing import StandardScaler
from sklearn.svm import LinearSVC

try:
    from sklearn.model_selection import train_test_split
except ImportError:  # old scikit-learn releases only ship the legacy module
    from sklearn.cross_validation import train_test_split
  12.  
  13. path = './DATASET/'
  14. hog_chanel = 'ALL'
  15. spacial_size = (32, 32)
  16. hist_bins = 32
  17. spatial_feat = True
  18. hist_feat = True
  19. hog_feat = True
  20. y_start_stop =[400, 650]
  21.  
  22. listPositiveHog =[]
  23. listNegativeHog =[]
  24.  
  25. #[300,300] # [600,600]
  26.  
  27. class OpenCVFrame():
  28.     imagePreparation = None #SVMFeature
  29.     binaryClassifier = None
  30.  
  31.     frame = None
  32.     refPt = []
  33.  
  34.     def __init__(self,svc,scaler,imagePreparation):
  35.         self.binaryClassifier = BinaryClassifier(svc, scaler)
  36.         self.imagePreparation= imagePreparation
  37.  
  38.     # def __init__(self):
  39.     #     pass
  40.  
  41.     def prepareImage(self,img):
  42.         # self.imagePreparation.image_feature(img)
  43.         return self.imagePreparation.image_feature(img)
  44.  
  45.     def ifIsCar(self, img):
  46.         features = self.prepareImage(img)
  47.         if (self.binaryClassifier.predict(features)):
  48.             return 1
  49.         else:
  50.             return 0
  51.     def click_and_crop(self, event, x, y, flags, param):
  52.         if event == cv2.EVENT_LBUTTONDOWN:
  53.             self.refPt = [x, y]
  54.         elif event == cv2.EVENT_LBUTTONUP:
  55.  
  56.             # draw a rectangle around the region of interest
  57.             cv2.imshow('Frame', self.frame)
  58.  
  59.     def slide_window_helper(self,mat, xStartStop = [None,None], yStartStop = [None,None],step=16,rectSd=128):
  60.  
  61.         pixYStart = yStartStop[0]
  62.         pixYStop = yStartStop[1]
  63.         pixXStop = xStartStop[1]
  64.  
  65.         rectangleBoxList = []
  66.  
  67.         while pixYStart <= pixYStop:
  68.             pixXStart = xStartStop[0]
  69.             while pixXStart <= pixXStop:
  70.                 rectangleBoxList.append([(pixXStart,pixYStart),(pixXStart+rectSd,pixYStart+rectSd)])
  71.                 pixXStart+=step
  72.             pixYStart+=step
  73.  
  74.         return rectangleBoxList
  75.  
  76.  
  77.     def slide_widow(self,mat, xStarStop = [None,None],yStartStop = [None,None]):
  78.         windows_a = self.slide_window_helper(mat, xStarStop, yStartStop)
  79.         # windows_b = self.slide_window_helper(mat, xStarStop, yStartStop,step = 8, rectSd=64)
  80.  
  81.         return windows_a#+windows_b
  82.  
  83.     def drawBox(self,rectFrame,boxListToDraw):
  84.         roi = None
  85.         for i in boxListToDraw:
  86.             roi = rectFrame[i[0][1]:i[1][1], i[0][0]:i[1][0]]
  87.             if(self.ifIsCar(roi)):
  88.                 cv2.rectangle(rectFrame, i[0], i[1], (0, 255, 0), 2)
  89.         return rectFrame
  90.  
  91.     def runWindowSlider(self):
  92.         cap = cv2.VideoCapture('car1.avi')
  93.         rectangleSide = 64
  94.         boxListToDraw = []
  95.         if (cap.isOpened() == False):
  96.             print("Error opening video stream or file")
  97.         else:
  98.             while (cap.isOpened()):
  99.                 ret, self.frame = cap.read()
  100.                 boxListToDraw = self.slide_widow(self.frame,xStarStop = [200,600],yStartStop = [200,400])
  101.  
  102.                 rectFrame = np.copy(self.frame)# TODO dać tutaj wątek
  103.  
  104.                 rectFrame = self.drawBox(rectFrame,boxListToDraw)
  105.                 cv2.imshow('Frame', rectFrame)
  106.                 if cv2.waitKey(1) & 0xFF == ord('q'):
  107.                     break;
  108.             cap.release()
  109.             cv2.destroyAllWindows()
  110.  
  111.     def runShoter(self):
  112.         cap = cv2.VideoCapture('car1.avi')
  113.         rectangleSide = 64
  114.  
  115.         if (cap.isOpened() == False):
  116.             print("Error opening video stream or file")
  117.         else:
  118.             while (cap.isOpened()):
  119.                 ret, self.frame = cap.read()
  120.                 clone = self.frame.copy()
  121.                 if ret == True:
  122.                     cv2.setMouseCallback('Frame', self.click_and_crop)
  123.                     if len(self.refPt) == 2:
  124.                         roi = clone[self.refPt[1] - rectangleSide:self.refPt[1] + rectangleSide,
  125.                               self.refPt[0] - rectangleSide:self.refPt[0] + rectangleSide]
  126.                         cv2.rectangle(self.frame, (self.refPt[0] - rectangleSide, self.refPt[1] - rectangleSide),
  127.                                       (self.refPt[0] + rectangleSide, self.refPt[1] + rectangleSide), (0, 255, 0), 2)
  128.                         mat  = PhotoService.resizeIMGShoot(roi)
  129.                         self.ifIsCar(mat)
  130.                         self.refPt = []
  131.                         cv2.waitKey(500)
  132.                     cv2.imshow('Frame', self.frame)
  133.                     if cv2.waitKey(30) & 0xFF == ord('q'):
  134.                         break
  135.                 else:
  136.                     break
  137.  
  138.  
  139.             cap.release()
  140.             cv2.destroyAllWindows()
  141.  
  142.  
  143.  
  144.  
  145. class PhotoService:
  146.     global path
  147.     @staticmethod
  148.     def getPhotos():
  149.         listFiles = {}
  150.         for root, dirs, files in os.walk(path):
  151.             if len(files) == 0:
  152.                 continue
  153.             else:
  154.                 listFiles[root] = files
  155.         return listFiles
  156.  
  157.     @staticmethod
  158.     def getUrlPhotos():
  159.         photos = PhotoService.getPhotos()
  160.         photoList = []
  161.         for root in photos:
  162.             photoListTemp = []
  163.             for file in photos[root]:
  164.                 photoListTemp.append(root + "/" + file)
  165.             photoList.append(photoListTemp)
  166.         return photoList
  167.  
  168.     @staticmethod
  169.     def getPositiveIndex(pathFile):
  170.         if "NEGATIVE" in pathFile:
  171.             return [1,0]
  172.         else:
  173.             return [0,1]
  174.  
  175.     @staticmethod
  176.     def getImageIMG(path):
  177.         img = cv2.imread(path)
  178.         img = cv2.resize(img,(64,64))
  179.         # img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
  180.         return img
  181.  
  182.     @staticmethod
  183.     def resizeIMGShoot(img, size=(64,64)):
  184.         img = cv2.resize(img,size)
  185.         # cv2.imshow("ROI", img)
  186.         return img
  187.  
  188.  
  189. class SVMFeature():
  190.  
  191.     def __init__(self):
  192.         pass
  193.  
  194.     def imResize(self, img, imgSize = (64,64)):
  195.         return cv2.resize(img,imgSize)
  196.  
  197.     def computeHog(self,img,vis = False ,orient =9, cellSize = 8, cellPerBlock = 2):
  198.         img = self.imResize(img)
  199.         if vis == True:
  200.             features, hog_image = feature.hog(img[:,:,0],orientations=orient, pixels_per_cell= (cellSize,cellSize),
  201.                                       cells_per_block=(cellPerBlock,cellPerBlock),
  202.                                       block_norm = 'L2-Hys', visualise = True, transform_sqrt = True,
  203.                                       feature_vector = True, normalise = None)
  204.             plt.imshow(hog_image)
  205.             plt.title('Hog image')
  206.             plt.show()
  207.  
  208.             return features
  209.         else:
  210.             features = feature.hog(img[:,:,0],orientations=orient, pixels_per_cell= (cellSize,cellSize),
  211.                                       cells_per_block=(cellPerBlock,cellPerBlock),
  212.                                       block_norm = 'L2-Hys', visualise = False, transform_sqrt = True,
  213.                                       feature_vector = True, normalise = None)
  214.             return features
  215.         pass
  216.  
  217.     def getBin(self,img, size=(32,32)):
  218.         return self.imResize(img,size).ravel()
  219.  
  220.     def color_hist(self,img,nbins =32, value=(0,256)):
  221.         channel1_hist=np.histogram(img[:, :,0], bins=nbins,range= value)
  222.         channel2_hist=np.histogram(img[:, :,1], bins=nbins,range= value)
  223.         channel3_hist=np.histogram(img[:, :,2], bins=nbins,range= value)
  224.  
  225.         return np.concatenate((channel1_hist[0],channel2_hist[0],channel3_hist[0]))
  226.  
  227.     def image_feature(self,mat):
  228.         img_feature = []
  229.         img_feature.append(self.getBin(mat))
  230.         img_feature.append(self.color_hist(mat))
  231.         img_feature.append(self.computeHog(mat))
  232.  
  233.         return np.concatenate(img_feature)
  234.  
  235.     def extractFeature(self, listOfPathImage):
  236.  
  237.         features = []
  238.  
  239.         for path in listOfPathImage:
  240.             img  = PhotoService.getImageIMG(path)
  241.             imgFeature = self.image_feature(img)
  242.             features.append(imgFeature)
  243.         return features
  244.  
  245.     def prepareForImageProcess(self,positiveListPaths,negativeListPaths):
  246.         cars = self.extractFeature(positiveListPaths)
  247.         nocars = self.extractFeature(negativeListPaths)
  248.  
  249.         unscaled_x = np.vstack((cars,nocars)).astype(np.float64)# TODO SAVE TO XML
  250.         scaler =StandardScaler().fit(unscaled_x)
  251.         X = scaler.transform(unscaled_x)
  252.         Y=np.hstack((np.ones(len(cars)), np.zeros(len(nocars))))
  253.  
  254.         return scaler,X,Y
  255.  
  256.     def prepareData(self, positiveListPaths, negativeListPaths):
  257.  
  258.         return self.prepareForImageProcess(positiveListPaths, negativeListPaths)
  259.  
  260.     def test(self):
  261.         print('test')
  262.  
  263. class SVMTrainClass():
  264.     svc = None
  265.     def __init__(self):
  266.         self.svc = LinearSVC()
  267.     def trainSVC(self,X,Y):
  268.         print('Start Trainig')
  269.         X_train, X_test, y_train, y_test= train_test_split(X, Y, test_size=0.2, random_state=random.randint(1,100))
  270.         self.svc.fit(X,Y)
  271.         accuracy = self.svc.score(X_test, y_test)
  272.         print('End Trainig')
  273.  
  274.     def getSVC(self):
  275.         return self.svc
  276.  
  277. class BinaryClassifier:
  278.     def __init__(self, svc, scaler):
  279.         self.svc, self.scaler = svc, scaler
  280.  
  281.     def predict(self, f):
  282.         f = self.scaler.transform([f])
  283.         r = self.svc.predict(f)
  284.         return np.int(r[0])
  285.  
  286. def main(argv):
  287.     # POBIERANIE ZDJĘĆ
  288.     photoList = PhotoService.getUrlPhotos()
  289.     positiveIndex = PhotoService.getPositiveIndex(photoList[0][0])
  290.     positiveListPaths = photoList[positiveIndex[0]]
  291.     negativeListPaths = photoList[positiveIndex[1]]
  292.  
  293.     #
  294.     # # UCZENIE
  295.     sVMFeature = SVMFeature()
  296.     scaler, X, Y = sVMFeature.prepareData(positiveListPaths,negativeListPaths)
  297.  
  298.     sVMTrainClass = SVMTrainClass()
  299.     sVMTrainClass.trainSVC(X,Y)
  300.  
  301.     # Otwieranie filmu ;)
  302.     openCVFrame =  OpenCVFrame(sVMTrainClass.getSVC(),scaler,sVMFeature)
  303.     # openCVFrame =  OpenCVFrame()
  304.     openCVFrame.runWindowSlider()
  305.  
  306.  
  307.  
  308.  
  309.  
  310.  
  311. if __name__ == "__main__":
  312.     main(sys.argv)
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement