Advertisement
Guest User

Untitled

a guest
Feb 24th, 2018
169
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 3.13 KB | None | 0 0
  1. import pygame
  2. import keras
  3. from keras.layers import Conv2D, MaxPooling2D, Input, Dense, Flatten, Dropout, LSTM
  4. from keras.models import Sequential
  5. import cv2
  6. import numpy as np
  7. import pickle
  8. import os
  9. from keras import backend as K
  10. from keras.models import load_model
  11. import random
  12. user_train = False
  13. model_train = False
  14. if user_train:
  15.     pygame.init()
  16.  
  17.     screen = pygame.display.set_mode([455,256])
  18.  
  19.     imgs = list(["output/"+fl for fl in os.listdir("output")])
  20.     for i in imgs:
  21.         if not i.endswith(".png"):
  22.             imgs.remove(i)
  23.     img_obs = list([pygame.image.load(img) for img in imgs])
  24.     img_index = 0
  25.     dataset = {}
  26.     hold = None
  27.     run_mainloop = True
  28.     while True:
  29.         m_pos = pygame.mouse.get_pos()
  30.         for event in pygame.event.get():
  31.             if event.type == pygame.QUIT:
  32.                 pygame.quit()
  33.                 raise SystemExit
  34.             if event.type == pygame.MOUSEBUTTONDOWN:
  35.                 dataset[imgs[img_index]] = m_pos
  36.                 img_index+=1
  37.                 clicks = []
  38.                 print(dataset)
  39.                 if img_index == len(imgs):
  40.                     pygame.quit()
  41.                     run_mainloop = False
  42.  
  43.         if not run_mainloop:
  44.             break
  45.  
  46.         screen.blit(img_obs[img_index],[0,0])
  47.         pygame.display.flip()
  48.  
  49.     with open("dataset_raw.txt","wb") as fl:
  50.         pickle.dump(dataset,fl)
  51. else:
  52.     with open("dataset_raw.txt","rb") as fl:
  53.         dataset = pickle.load(fl)
  54.  
  55. train_x = []
  56. train_y = []
  57. i = 0
  58. for imgfile in dataset:
  59.     img_src = cv2.imread(imgfile,0).flatten()/255
  60.  
  61.     train_x.append(img_src)
  62.     pos = dataset[imgfile]
  63.     train_y.append(np.array([pos[0]/455,pos[1]/256]))
  64. if model_train:
  65.     train_x = np.array(train_x)
  66.     train_y = np.array(train_y)
  67.     print(train_x)
  68.     print(train_y)
  69.  
  70.     model = Sequential()
  71.     model.add(Dense(units=512,input_dim=256*455,activation='relu'))
  72.     model.add(Dense(units=512,activation='relu'))
  73.     model.add(Dense(units=512,activation='relu'))
  74.     model.add(Dense(units=256,activation='relu'))
  75.     model.add(Dense(units=256,activation='relu'))
  76.     model.add(Dense(units=2,activation='tanh'))
  77.     model.compile(loss="mean_squared_error",
  78.                   optimizer='sgd',
  79.                   metrics=['accuracy'])
  80.     last_loss = 99999999999999 # Don't question it
  81.     while 1:
  82.         model.fit(train_x,train_y,epochs=5)
  83.         loss = model.evaluate(train_x, train_y, batch_size=128,verbose=False)[0]
  84.  
  85.         model.save("desklamp.h5")
  86.         #cv2.imwrite("tmp.png",model.predict(train_x[0].reshape(1,256,455,1))[0])
  87. else:
  88.     model = load_model("desklamp.h5")
  89. cam = cv2.VideoCapture(0)
  90. while True:
  91.     ret_val, img = cam.read()
  92.     img = cv2.flip(img,1)
  93.     arr = cv2.cvtColor(cv2.resize(img, (455, 256), interpolation = cv2.INTER_CUBIC), cv2.COLOR_BGR2GRAY).flatten()/255
  94.     output = model.predict(arr.reshape((1,455*256)))[0]
  95.     pos = (int(output[0]*1280),int(output[1]*720))
  96.     cv2.circle(img,pos,16,[255,0,0],-1)
  97.     cv2.imshow("hackathon",img)
  98.     key = cv2.waitKey(1)
  99.     if key == 27:
  100.         cv2.destroyAllWindows()
  101.         break
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement