Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
# Defines the VoVNet model.
# TPU training uses tf.contrib APIs, so it is only compatible with TensorFlow 1.13 or earlier.
- import os
- import tensorflow as tf
- from tensorflow.contrib.tpu.python.tpu import keras_support
- from tensorflow.keras.layers import Conv2D, BatchNormalization, Activation, MaxPooling2D, Dense, GlobalAveragePooling2D, Concatenate, Input
- from tensorflow.keras.models import Model
- from tensorflow.keras.optimizers import Adam
- from tensorflow.keras.preprocessing.image import ImageDataGenerator
- from tensorflow.keras.callbacks import History, LearningRateScheduler
- import pickle
# Architecture presets consumed by VoVNet(architecture_config).
#
# Every preset is a list of three equally long rows, one entry per OSA stage:
#   row 0 - channels of each 3x3 convolution inside the stage's OSA modules
#   row 1 - channels of the 1x1 convolution that closes each OSA module
#   row 2 - how many OSA modules are stacked in the stage
#
# NOTE: the "ARCHITECURE" spelling is kept on purpose; callers refer to
# these names as they are.
VOV_27_ARCHITECURE_CONFIG = [
    [64, 80, 96, 112],       # 3x3 conv channels per stage
    [128, 256, 384, 512],    # 1x1 conv (module output) channels per stage
    [1, 1, 1, 1],            # OSA modules per stage
]

VOV_39_ARCHITECURE_CONFIG = [
    [128, 160, 192, 224],    # 3x3 conv channels per stage
    [256, 512, 768, 1024],   # 1x1 conv (module output) channels per stage
    [1, 1, 2, 2],            # OSA modules per stage
]

VOV_57_ARCHITECURE_CONFIG = [
    [128, 160, 192, 224],    # 3x3 conv channels per stage
    [256, 512, 768, 1024],   # 1x1 conv (module output) channels per stage
    [1, 1, 4, 3],            # OSA modules per stage
]
class VoVNet:
    """Keras VoVNet classifier: a conv stem followed by stages of OSA modules.

    The network is built once, at construction time, and stored in
    ``self.model``. ``train``, ``evaluate`` and ``predict`` are thin wrappers
    around that Keras model.

    Attributes:
        channels: Per-stage channel count of the 3x3 convolutions inside an
            OSA module.
        bottlenecks: Per-stage channel count of the 1x1 convolution that
            closes an OSA module.
        bottleneck: Alias of ``bottlenecks`` (the attribute name used by
            earlier versions of this class).
        repetitions: Per-stage number of stacked OSA modules.
        input_shape: Shape of one input image, ``(height, width, channels)``.
        num_classes: Number of units in the softmax output layer.
        model: The Keras ``Model`` (replaced by a TPU model once ``train`` is
            called with ``use_tpu=True``).
    """

    def __init__(self, architecture_config, input_shape=(300, 300, 3), num_classes=10):
        """Unpack the architecture config and build the Keras model.

        Args:
            architecture_config: Sequence of three equally long sequences:
                ``(channels, bottlenecks, repetitions)``, one entry per stage.
            input_shape: Shape of one input image. It must be large enough to
                survive the stride-2 stem convolution and the 3x3 / stride-2
                poolings (one in the stem, one after every stage).
            num_classes: Size of the softmax output.

        Raises:
            ValueError: If the config does not hold exactly three rows, or if
                the rows differ in length.
        """
        channels, bottlenecks, repetitions = architecture_config
        # zip() in make_model would silently drop stages on a length
        # mismatch, so reject a malformed config before building any layer.
        if not len(channels) == len(bottlenecks) == len(repetitions):
            raise ValueError(
                "architecture_config rows must have equal length, got "
                "%d channels, %d bottlenecks, %d repetitions"
                % (len(channels), len(bottlenecks), len(repetitions)))

        self.channels = channels
        self.bottlenecks = bottlenecks
        # Older code stored this under the singular name; keep it readable.
        self.bottleneck = bottlenecks
        self.repetitions = repetitions
        self.input_shape = input_shape
        self.num_classes = num_classes
        self.model = self.make_model()

    def conv_bn_relu(self, input_tensor, channels, strides=1, kernel=3):
        """Apply Conv2D ('same' padding) -> BatchNormalization -> ReLU.

        Args:
            input_tensor: Tensor to transform.
            channels: Number of convolution filters.
            strides: Convolution stride.
            kernel: Convolution kernel size.

        Returns:
            The activated output tensor.
        """
        x = Conv2D(channels, kernel, strides=strides, padding='same')(input_tensor)
        x = BatchNormalization()(x)
        return Activation("relu")(x)

    def OSAModule(self, input_tensor, channel, bottleneck, aggr_times=5):
        """Build one OSA module (presumably "one-shot aggregation", as in VoVNet).

        ``aggr_times`` 3x3 conv blocks are chained; the output of every block
        is collected, all of them are concatenated once, and a 1x1 conv block
        reduces the result to ``bottleneck`` channels. The module input
        itself is not part of the concatenation.

        Args:
            input_tensor: Tensor entering the module.
            channel: Filters of each chained 3x3 conv block.
            bottleneck: Filters of the closing 1x1 conv block.
            aggr_times: Number of chained 3x3 conv blocks.

        Returns:
            The output tensor of the closing 1x1 conv block.
        """
        x = input_tensor
        aggr = []
        for _ in range(aggr_times):
            x = self.conv_bn_relu(x, channel)
            aggr.append(x)
        x = Concatenate()(aggr)
        return self.conv_bn_relu(x, bottleneck, kernel=1)

    def make_model(self):
        """Assemble the network and print its summary.

        Layout: stem (three conv blocks, the first with stride 2, then a
        3x3 / stride-2 max pooling), then for every configured stage its OSA
        modules followed by a 3x3 / stride-2 max pooling, then global
        average pooling and a softmax dense layer.

        Returns:
            The uncompiled Keras ``Model``.
        """
        # Adjust input_shape (constructor argument) to the resolution of the
        # data you want to train on.
        inputs = Input(shape=self.input_shape)

        # stem stage
        x = self.conv_bn_relu(inputs, 64, strides=2)
        x = self.conv_bn_relu(x, 64, strides=1)
        x = self.conv_bn_relu(x, 128, strides=1)
        x = MaxPooling2D((3, 3), 2)(x)

        # OSA stages
        for chan, bottleneck, rep in zip(self.channels, self.bottlenecks, self.repetitions):
            for _ in range(rep):
                x = self.OSAModule(x, chan, bottleneck)
            x = MaxPooling2D((3, 3), 2)(x)

        x = GlobalAveragePooling2D()(x)
        outputs = Dense(self.num_classes, activation="softmax")(x)

        model = Model(inputs, outputs)
        # summary() prints by itself and returns None, so it is not wrapped
        # in print() (that would emit an extra "None" line).
        model.summary()
        return model

    def train(self, X_train, y_train, X_val, y_val, epochs=1, use_tpu=False, batch_size=128):
        """Compile the model, train it with data augmentation, save the history.

        The per-epoch history dict returned by Keras is pickled to
        ``history.dat`` in the current working directory.

        Args:
            X_train: Training images.
            y_train: Training labels (the loss is categorical cross-entropy,
                so presumably one-hot encoded - confirm against the caller).
            X_val: Validation images.
            y_val: Validation labels.
            epochs: Number of training epochs.
            use_tpu: If true, convert the model to a TPU model first. Reads
                the ``COLAB_TPU_ADDR`` environment variable and relies on
                ``tf.contrib`` (TensorFlow 1.13 or earlier).
            batch_size: Batch size of the augmented training generator.
        """
        self.model.compile(
            optimizer=Adam(), loss='categorical_crossentropy', metrics=["acc"])

        if use_tpu:
            tpu_grpc_url = "grpc://" + os.environ["COLAB_TPU_ADDR"]
            tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver(
                tpu_grpc_url)
            strategy = keras_support.TPUDistributionStrategy(
                tpu_cluster_resolver)
            self.model = tf.contrib.tpu.keras_to_tpu_model(
                self.model, strategy=strategy)

        datagen = ImageDataGenerator(
            rotation_range=20,
            width_shift_range=0.2,
            height_shift_range=0.2,
            channel_shift_range=50,
            horizontal_flip=True)

        def lr_scheduler(epoch):
            # Constant rate, the same as the initial learning rate of Adam.
            return 0.001

        # define callbacks
        scheduler = LearningRateScheduler(lr_scheduler)

        # At least one step per epoch, even when batch_size exceeds the
        # number of training samples (floor division would give 0).
        steps_per_epoch = max(1, len(X_train) // batch_size)

        # Training with augmentation. For benchmarking epoch time without
        # augmentation, a plain self.model.fit(X_train, y_train, ...) with
        # the same validation data can be used instead.
        history = self.model.fit_generator(
            datagen.flow(X_train, y_train, batch_size=batch_size),
            steps_per_epoch=steps_per_epoch,
            validation_data=(X_val, y_val),
            epochs=epochs,
            callbacks=[scheduler]).history

        # Save the training history.
        with open("history.dat", "wb") as fp:
            pickle.dump(history, fp)

    def evaluate(self, *args, **kwargs):
        """Forward all arguments to ``self.model.evaluate`` and return its result."""
        # NOTE(review): an earlier variant evaluated a TPU-converted model via
        # self.model.sync_to_cpu().evaluate(...) - confirm whether that is
        # needed after training with use_tpu=True.
        return self.model.evaluate(*args, **kwargs)

    def predict(self, *args, **kwargs):
        """Forward all arguments to ``self.model.predict`` and return its result."""
        # NOTE(review): as with evaluate(), a TPU-converted model may need
        # self.model.sync_to_cpu() first - confirm.
        return self.model.predict(*args, **kwargs)
# Define your model with:
#     model = VoVNet(VOV_27_ARCHITECURE_CONFIG)
# Train it with:
#     model.train(train_images, train_labels, test_images, test_labels, epochs=100)
# or, on a Colab TPU, with something like:
#     model.train(train_images, train_labels, test_images, test_labels, epochs=100, use_tpu=True, batch_size=10000)
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement