# Define VoVNet
# If you use a TPU, this is only compatible with TensorFlow 1.13 or earlier
# (it relies on the tf.contrib TPU APIs removed in later versions).
import os
import pickle

import tensorflow as tf
from tensorflow.contrib.tpu.python.tpu import keras_support
from tensorflow.keras.layers import (Conv2D, BatchNormalization, Activation,
                                     MaxPooling2D, Dense, GlobalAveragePooling2D,
                                     Concatenate, Input)
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.callbacks import LearningRateScheduler

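# Each architecture config below has three rows, unpacked by VoVNet.__init__
# as (channels, bottlenecks, repetitions):
#   row 0: channels of the 3x3 convs inside each OSA module, per stage
#   row 1: output channels of the 1x1 conv applied after concatenation, per stage
#   row 2: how many OSA modules are stacked in each stage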
VOV_27_ARCHITECURE_CONFIG = [
    [64, 80, 96, 112],
    [128, 256, 384, 512],
    [1, 1, 1, 1]
]
VOV_39_ARCHITECURE_CONFIG = [
    [128, 160, 192, 224],
    [256, 512, 768, 1024],
    [1, 1, 2, 2]
]
VOV_57_ARCHITECURE_CONFIG = [
    [128, 160, 192, 224],
    [256, 512, 768, 1024],
    [1, 1, 4, 3]
]


class VoVNet:
    def __init__(self, architecture_config):
        self.channels, self.bottlenecks, self.repetitions = architecture_config
        self.model = self.make_model()

    def conv_bn_relu(self, input_tensor, channels, strides=1, kernel=3):
        # Conv -> BatchNorm -> ReLU building block
        x = input_tensor
        x = Conv2D(channels, kernel, strides=strides, padding='same')(x)
        x = BatchNormalization()(x)
        x = Activation("relu")(x)
        return x

    def OSAModule(self, input_tensor, channel, bottleneck, aggr_times=5):
        # One-Shot Aggregation: chain `aggr_times` 3x3 convs, concatenate all
        # of their outputs once, then project with a 1x1 conv
        x = input_tensor
        aggr = []
        for i in range(aggr_times):
            x = self.conv_bn_relu(x, channel)
            aggr.append(x)

        x = Concatenate()(aggr)
        x = self.conv_bn_relu(x, bottleneck, kernel=1)
        return x

    def make_model(self):
        # Adjust the input resolution to match the data you want to train on
        inputs = Input(shape=(300, 300, 3))
        x = inputs

        # stem stage
        x = self.conv_bn_relu(x, 64, strides=2)
        x = self.conv_bn_relu(x, 64, strides=1)
        x = self.conv_bn_relu(x, 128, strides=1)
        x = MaxPooling2D((3, 3), 2)(x)

        # OSA stages: each stage stacks `rep` OSA modules, then downsamples
        for chan, bottleneck, rep in zip(self.channels, self.bottlenecks, self.repetitions):
            for _ in range(rep):
                x = self.OSAModule(x, chan, bottleneck)
            x = MaxPooling2D((3, 3), 2)(x)

        # 10-class classification head
        x = GlobalAveragePooling2D()(x)
        x = Dense(10, activation="softmax")(x)
        outputs = x
        model = Model(inputs, outputs)
        model.summary()
        return model

    def train(self, X_train, y_train, X_val, y_val, epochs=1, use_tpu=False, batch_size=128):
        self.model.compile(
            optimizer=Adam(), loss='categorical_crossentropy', metrics=["acc"])

        if use_tpu:
            tpu_grpc_url = "grpc://" + os.environ["COLAB_TPU_ADDR"]
            tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver(
                tpu_grpc_url)
            strategy = keras_support.TPUDistributionStrategy(
                tpu_cluster_resolver)
            self.model = tf.contrib.tpu.keras_to_tpu_model(
                self.model, strategy=strategy)

        datagen = ImageDataGenerator(
            rotation_range=20,
            width_shift_range=0.2,
            height_shift_range=0.2,
            channel_shift_range=50,
            horizontal_flip=True)

        def lr_scheduler(epoch):
            # Keep the learning rate fixed at Adam's default initial value
            return 0.001

        # define callbacks
        scheduler = LearningRateScheduler(lr_scheduler)

        # Training without augmentation (for benchmarking the number of epochs):
        # history = self.model.fit(X_train, y_train, batch_size=128, epochs=epochs,
        #                          validation_data=(X_val, y_val)).history
        # Training with data augmentation
        history = self.model.fit_generator(
            datagen.flow(X_train, y_train, batch_size=batch_size),
            steps_per_epoch=len(X_train) // batch_size,
            validation_data=(X_val, y_val), epochs=epochs,
            callbacks=[scheduler]).history

        # Save the training history
        with open("history.dat", "wb") as fp:
            pickle.dump(history, fp)

    def evaluate(self, *args, **kwargs):
        # return self.model.sync_to_cpu().evaluate(*args, **kwargs)
        return self.model.evaluate(*args, **kwargs)

    def predict(self, *args, **kwargs):
        # return self.model.sync_to_cpu().predict(*args, **kwargs)
        return self.model.predict(*args, **kwargs)

# Define a model like this:
#   model = VoVNet(VOV_27_ARCHITECURE_CONFIG)

# Train it with:
#   model.train(train_images, train_labels, test_images, test_labels, epochs=100)
# or, on a Colab TPU, something like:
#   model.train(train_images, train_labels, test_images, test_labels, epochs=100, use_tpu=True, batch_size=10000)
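
# A minimal smoke-test sketch, assuming TensorFlow 1.13 as noted at the top of
# the file: it trains on random dummy arrays for one epoch on CPU/GPU
# (use_tpu=False). The dummy shapes simply follow the Input(shape=(300, 300, 3))
# and Dense(10) layers above; swap in your own dataset in practice.
if __name__ == "__main__":
    import numpy as np

    num_train, num_val, num_classes = 32, 8, 10
    # Random pixel-valued images and random one-hot labels
    X_train = np.random.uniform(0, 255, (num_train, 300, 300, 3)).astype("float32")
    X_val = np.random.uniform(0, 255, (num_val, 300, 300, 3)).astype("float32")
    y_train = tf.keras.utils.to_categorical(
        np.random.randint(0, num_classes, num_train), num_classes)
    y_val = tf.keras.utils.to_categorical(
        np.random.randint(0, num_classes, num_val), num_classes)

    model = VoVNet(VOV_27_ARCHITECURE_CONFIG)
    model.train(X_train, y_train, X_val, y_val, epochs=1, batch_size=8)
    print(model.evaluate(X_val, y_val))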