SHARE
TWEET

Untitled

a guest Feb 24th, 2018 83 Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
  1. import pygame
  2. import keras
  3. from keras.layers import Conv2D, MaxPooling2D, Input, Dense, Flatten, Dropout, LSTM
  4. from keras.models import Sequential
  5. import cv2
  6. import numpy as np
  7. import pickle
  8. import os
  9. from keras import backend as K
  10. from keras.models import load_model
  11. import random
  12. user_train = False
  13. model_train = False
  14. if user_train:
  15.     pygame.init()
  16.  
  17.     screen = pygame.display.set_mode([455,256])
  18.  
  19.     imgs = list(["output/"+fl for fl in os.listdir("output")])
  20.     for i in imgs:
  21.         if not i.endswith(".png"):
  22.             imgs.remove(i)
  23.     img_obs = list([pygame.image.load(img) for img in imgs])
  24.     img_index = 0
  25.     dataset = {}
  26.     hold = None
  27.     run_mainloop = True
  28.     while True:
  29.         m_pos = pygame.mouse.get_pos()
  30.         for event in pygame.event.get():
  31.             if event.type == pygame.QUIT:
  32.                 pygame.quit()
  33.                 raise SystemExit
  34.             if event.type == pygame.MOUSEBUTTONDOWN:
  35.                 dataset[imgs[img_index]] = m_pos
  36.                 img_index+=1
  37.                 clicks = []
  38.                 print(dataset)
  39.                 if img_index == len(imgs):
  40.                     pygame.quit()
  41.                     run_mainloop = False
  42.  
  43.         if not run_mainloop:
  44.             break
  45.  
  46.         screen.blit(img_obs[img_index],[0,0])
  47.         pygame.display.flip()
  48.  
  49.     with open("dataset_raw.txt","wb") as fl:
  50.         pickle.dump(dataset,fl)
  51. else:
  52.     with open("dataset_raw.txt","rb") as fl:
  53.         dataset = pickle.load(fl)
  54.  
  55. train_x = []
  56. train_y = []
  57. i = 0
  58. for imgfile in dataset:
  59.     img_src = cv2.imread(imgfile,0).flatten()/255
  60.  
  61.     train_x.append(img_src)
  62.     pos = dataset[imgfile]
  63.     train_y.append(np.array([pos[0]/455,pos[1]/256]))
  64. if model_train:
  65.     train_x = np.array(train_x)
  66.     train_y = np.array(train_y)
  67.     print(train_x)
  68.     print(train_y)
  69.  
  70.     model = Sequential()
  71.     model.add(Dense(units=512,input_dim=256*455,activation='relu'))
  72.     model.add(Dense(units=512,activation='relu'))
  73.     model.add(Dense(units=512,activation='relu'))
  74.     model.add(Dense(units=256,activation='relu'))
  75.     model.add(Dense(units=256,activation='relu'))
  76.     model.add(Dense(units=2,activation='tanh'))
  77.     model.compile(loss="mean_squared_error",
  78.                   optimizer='sgd',
  79.                   metrics=['accuracy'])
  80.     last_loss = 99999999999999 # Don't question it
  81.     while 1:
  82.         model.fit(train_x,train_y,epochs=5)
  83.         loss = model.evaluate(train_x, train_y, batch_size=128,verbose=False)[0]
  84.  
  85.         model.save("desklamp.h5")
  86.         #cv2.imwrite("tmp.png",model.predict(train_x[0].reshape(1,256,455,1))[0])
  87. else:
  88.     model = load_model("desklamp.h5")
  89. cam = cv2.VideoCapture(0)
  90. while True:
  91.     ret_val, img = cam.read()
  92.     img = cv2.flip(img,1)
  93.     arr = cv2.cvtColor(cv2.resize(img, (455, 256), interpolation = cv2.INTER_CUBIC), cv2.COLOR_BGR2GRAY).flatten()/255
  94.     output = model.predict(arr.reshape((1,455*256)))[0]
  95.     pos = (int(output[0]*1280),int(output[1]*720))
  96.     cv2.circle(img,pos,16,[255,0,0],-1)
  97.     cv2.imshow("hackathon",img)
  98.     key = cv2.waitKey(1)
  99.     if key == 27:
  100.         cv2.destroyAllWindows()
  101.         break
RAW Paste Data
We use cookies for various purposes including analytics. By continuing to use Pastebin, you agree to our use of cookies as described in the Cookies Policy. OK, I Understand
Top