Advertisement
Guest User

Untitled

a guest
Jun 27th, 2019
144
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 3.01 KB | None | 0 0
  1. from pelee_net import PeleeNet
  2. import tensorflow as tf
  3. import tensorflow.keras as keras
  4. import numpy as np
  5. from PIL import Image
  6. import pickle
  7. import os
  8. from tensorflow.contrib.tpu.python.tpu import keras_support
  9.  
  10. def generator(X, y, batch_size, use_augmentation, shuffle, scale):
  11. if use_augmentation:
  12. base_gen = keras.preprocessing.image.ImageDataGenerator(
  13. horizontal_flip=True,
  14. width_shift_range=4.0/32.0,
  15. height_shift_range=4.0/32.0)
  16. else:
  17. base_gen = keras.preprocessing.image.ImageDataGenerator()
  18. for X_base, y_base in base_gen.flow(X, y, batch_size=batch_size, shuffle=shuffle):
  19. if scale != 1:
  20. X_batch = np.zeros((X_base.shape[0], X_base.shape[1]*scale,
  21. X_base.shape[2]*scale, X_base.shape[3]), np.float32)
  22. for i in range(X_base.shape[0]):
  23. with Image.fromarray(X_base[i].astype(np.uint8)) as img:
  24. img = img.resize((X_base.shape[1]*scale, X_base.shape[2]*scale), Image.LANCZOS)
  25. X_batch[i] = np.asarray(img, np.float32) / 255.0
  26. else:
  27. X_batch = X_base / 255.0
  28. yield X_batch, y_base
  29.  
  30. def lr_scheduler(epoch):
  31. x = 0.4
  32. if epoch >= 70: x /= 5.0
  33. if epoch >= 120: x /= 5.0
  34. if epoch >= 170: x /= 5.0
  35. return x
  36.  
  37. def train(use_augmentation, use_stem_block):
  38. tf.logging.set_verbosity(tf.logging.FATAL)
  39. (X_train, y_train), (X_test, y_test) = keras.datasets.cifar10.load_data()
  40. y_train = keras.utils.to_categorical(y_train)
  41. y_test = keras.utils.to_categorical(y_test)
  42.  
  43. # generator
  44. batch_size = 512
  45. scale = 7 if use_stem_block else 1
  46. train_gen = generator(X_train, y_train, batch_size=batch_size,
  47. use_augmentation=use_augmentation, shuffle=True, scale=scale)
  48. test_gen = generator(X_test, y_test, batch_size=1000,
  49. use_augmentation=False, shuffle=False, scale=scale)
  50.  
  51. # network
  52. input_shape = (224,224,3) if use_stem_block else (32,32,3)
  53. model = PeleeNet(input_shape=input_shape, use_stem_block=use_stem_block, n_classes=10)
  54. model.compile(keras.optimizers.SGD(0.4, 0.9), "categorical_crossentropy", ["acc"])
  55.  
  56. tpu_grpc_url = "grpc://"+os.environ["COLAB_TPU_ADDR"]
  57. tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver(tpu_grpc_url)
  58. strategy = keras_support.TPUDistributionStrategy(tpu_cluster_resolver)
  59. model = tf.contrib.tpu.keras_to_tpu_model(model, strategy=strategy)
  60.  
  61. scheduler = keras.callbacks.LearningRateScheduler(lr_scheduler)
  62. hist = keras.callbacks.History()
  63.  
  64. model.fit_generator(train_gen, steps_per_epoch=X_train.shape[0]//batch_size,
  65. validation_data=test_gen, validation_steps=X_test.shape[0]//1000,
  66. callbacks=[scheduler, hist], epochs=1, max_queue_size=1)
  67. history = hist.history
  68. with open(f"pelee_aug_{use_augmentation}_stem_{use_stem_block}.pkl", "wb") as fp:
  69. pickle.dump(history, fp)
  70.  
  71. if __name__ == "__main__":
  72. train(True, True)
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement