Advertisement
Guest User

Untitled

a guest
Oct 15th, 2019
138
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 3.28 KB | None | 0 0
  1. from keras.models import Sequential
  2. from keras.layers import Dense, Dropout, Flatten
  3. from keras.layers.convolutional import Conv2D, MaxPooling2D
  4. from keras.layers.normalization import BatchNormalization
  5. from keras.utils import np_utils
  6. from skimage import transform
  7. from skimage.io import imread
  8. from os import listdir
  9. from os.path import join
  10. import numpy as np
  11.  
imgSize = 50  # side length (pixels): every image is resized to imgSize x imgSize before entering the network
  13.  
  14. def train_detector(train_gt, train_img_dir, fast_train=False):
  15.     inputData = np.zeros((len(train_gt), imgSize, imgSize, 3))
  16.     checkData = np.zeros((len(train_gt), 28)).astype(float)
  17.    
  18.     for i in train_gt:
  19.         numbI = int(i.split('.')[0])
  20.         img = imread(join(train_img_dir, i)).astype(float)
  21.         checkData[numbI] = np.array(train_gt[i])
  22.         checkData[numbI][::2] *= (imgSize / img.shape[1])
  23.         checkData[numbI][1::2] *= (imgSize / img.shape[0])
  24.         img = transform.resize(img, [imgSize, imgSize, 3])
  25.         inputData[numbI] = img
  26.    
  27.     expectedValue = inputData.mean(axis=(0,1,2))
  28.     std = inputData.var(axis=(0,1,2))
  29.     inputData = (inputData - expectedValue) / std
  30.  
  31.     model = Sequential()
  32.     stSize = imgSize
  33.     model.add(Conv2D(stSize, (3, 3), padding='same', input_shape=(imgSize, imgSize, 3), activation='relu'))
  34.     model.add(BatchNormalization())
  35.     model.add(Conv2D(stSize, (3, 3), activation='relu', padding='same'))
  36.     model.add(MaxPooling2D(pool_size=(2, 2)))
  37.     model.add(Dropout(0.25))
  38.  
  39.     model.add(Conv2D(stSize * 2, (3, 3), padding='same', activation='relu'))
  40.     model.add(BatchNormalization())
  41.     model.add(Conv2D(stSize * 2, (3, 3), activation='relu'))
  42.     model.add(BatchNormalization())
  43.     model.add(MaxPooling2D(pool_size=(2, 2)))
  44.     model.add(Dropout(0.25))
  45.  
  46.     model.add(Conv2D(stSize * 4, (3, 3), padding='same', activation='relu'))
  47.     model.add(BatchNormalization())
  48.     model.add(Conv2D(stSize * 4, (3, 3), activation='relu'))
  49.     model.add(BatchNormalization())
  50.     model.add(MaxPooling2D(pool_size=(2, 2)))
  51.     model.add(Dropout(0.25))
  52.    
  53.     model.add(Flatten())
  54.     model.add(Dense(1800, activation='relu'))
  55.     model.add(Dense(900, activation='relu'))
  56.     model.add(Dropout(0.5))
  57.     model.add(Dense(28))
  58.    
  59.     model.compile(loss="mse", optimizer="adam", metrics=["mae"])
  60.     print(model.summary())
  61.    
  62.     num_epochs = 100
  63.     if (fast_train):
  64.         num_epochs = 1
  65.     model.fit(inputData, checkData, batch_size=30, epochs=num_epochs, validation_split=0.1, shuffle=True)
  66.    
  67.     model.save("facepoints_model.hdf5")
  68.  
  69.  
  70. def detect(model, test_img_dir):
  71.     dir_files_list = listdir(test_img_dir)
  72.    
  73.     inputData = np.zeros((len(dir_files_list), imgSize, imgSize, 3))
  74.     changeSize = np.zeros((len(dir_files_list), 2))
  75.  
  76.     for i, filename in enumerate(dir_files_list):
  77.         img = imread(join(test_img_dir, filename)).astype(float)
  78.         changeSize[i] = np.array([img.shape[0], img.shape[1]])
  79.         img = transform.resize(img, [imgSize, imgSize, 3])
  80.         inputData[i] = img
  81.    
  82.     expectedValue = inputData.mean(axis=(0,1,2))
  83.     std = inputData.var(axis=(0,1,2))
  84.     inputData = (inputData - expectedValue) / std
  85.    
  86.     y = model.predict(inputData)
  87.     for i in range(y.shape[0]):
  88.         y[i, ::2] *= (changeSize[i][1] / imgSize)
  89.         y[i, 1::2] *= (changeSize[i][0] / imgSize)
  90.  
  91.     return {filename: list(map(int, y[i])) for i, filename in enumerate(dir_files_list)}
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement