SHARE
TWEET

Untitled

a guest Jun 27th, 2019 71 Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
  1. from pelee_net import PeleeNet
  2. import tensorflow as tf
  3. import tensorflow.keras as keras
  4. import numpy as np
  5. from PIL import Image
  6. import pickle
  7. import os
  8. from tensorflow.contrib.tpu.python.tpu import keras_support
  9.  
  10. def generator(X, y, batch_size, use_augmentation, shuffle, scale):
  11.     if use_augmentation:
  12.         base_gen = keras.preprocessing.image.ImageDataGenerator(
  13.             horizontal_flip=True,
  14.             width_shift_range=4.0/32.0,
  15.             height_shift_range=4.0/32.0)
  16.     else:
  17.         base_gen = keras.preprocessing.image.ImageDataGenerator()
  18.     for X_base, y_base in base_gen.flow(X, y, batch_size=batch_size, shuffle=shuffle):
  19.         if scale != 1:
  20.             X_batch = np.zeros((X_base.shape[0], X_base.shape[1]*scale,
  21.                                 X_base.shape[2]*scale, X_base.shape[3]), np.float32)
  22.             for i in range(X_base.shape[0]):
  23.                 with Image.fromarray(X_base[i].astype(np.uint8)) as img:
  24.                     img = img.resize((X_base.shape[1]*scale, X_base.shape[2]*scale), Image.LANCZOS)
  25.                     X_batch[i] = np.asarray(img, np.float32) / 255.0
  26.         else:
  27.             X_batch = X_base / 255.0
  28.         yield X_batch, y_base
  29.  
  30. def lr_scheduler(epoch):
  31.     x = 0.4
  32.     if epoch >= 70: x /= 5.0
  33.     if epoch >= 120: x /= 5.0
  34.     if epoch >= 170: x /= 5.0
  35.     return x
  36.  
  37. def train(use_augmentation, use_stem_block):
  38.     tf.logging.set_verbosity(tf.logging.FATAL)
  39.     (X_train, y_train), (X_test, y_test) = keras.datasets.cifar10.load_data()
  40.     y_train = keras.utils.to_categorical(y_train)
  41.     y_test = keras.utils.to_categorical(y_test)
  42.  
  43.     # generator
  44.     batch_size = 512
  45.     scale = 7 if use_stem_block else 1
  46.     train_gen = generator(X_train, y_train, batch_size=batch_size,
  47.                           use_augmentation=use_augmentation, shuffle=True, scale=scale)
  48.     test_gen = generator(X_test, y_test, batch_size=1000,
  49.                          use_augmentation=False, shuffle=False, scale=scale)
  50.    
  51.     # network
  52.     input_shape = (224,224,3) if use_stem_block else (32,32,3)
  53.     model = PeleeNet(input_shape=input_shape, use_stem_block=use_stem_block, n_classes=10)
  54.     model.compile(keras.optimizers.SGD(0.4, 0.9), "categorical_crossentropy", ["acc"])
  55.  
  56.     tpu_grpc_url = "grpc://"+os.environ["COLAB_TPU_ADDR"]
  57.     tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver(tpu_grpc_url)
  58.     strategy = keras_support.TPUDistributionStrategy(tpu_cluster_resolver)
  59.     model = tf.contrib.tpu.keras_to_tpu_model(model, strategy=strategy)
  60.  
  61.     scheduler = keras.callbacks.LearningRateScheduler(lr_scheduler)
  62.     hist = keras.callbacks.History()
  63.  
  64.     model.fit_generator(train_gen, steps_per_epoch=X_train.shape[0]//batch_size,
  65.                         validation_data=test_gen, validation_steps=X_test.shape[0]//1000,
  66.                         callbacks=[scheduler, hist], epochs=1, max_queue_size=1)
  67.     history = hist.history
  68.     with open(f"pelee_aug_{use_augmentation}_stem_{use_stem_block}.pkl", "wb") as fp:
  69.         pickle.dump(history, fp)
  70.  
  71. if __name__ == "__main__":
  72.     train(True, True)
RAW Paste Data
We use cookies for various purposes including analytics. By continuing to use Pastebin, you agree to our use of cookies as described in the Cookies Policy. OK, I Understand
 
Top