from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np
np.random.seed(2 ** 10)

# Prevent hitting the maximum recursion depth in `theano.tensor.grad`
# import sys
# sys.setrecursionlimit(2 ** 20)

from six.moves import range

from keras.datasets import cifar10
from keras.layers import Input, Dense, Layer, merge, Activation, Flatten, Lambda
from keras.layers.convolutional import Convolution2D, AveragePooling2D
from keras.layers.normalization import BatchNormalization
from keras.models import Model
from keras.optimizers import SGD
from keras.regularizers import l2
from keras.callbacks import Callback, LearningRateScheduler
from keras.preprocessing.image import ImageDataGenerator
from keras.utils import np_utils
import keras.backend as K

batch_size = 64
nb_classes = 10
nb_epoch = 500
N = 18                      # residual blocks per stage
weight_decay = 1e-4
lr_schedule = [0.5, 0.75]   # fractions of nb_epoch at which the learning rate drops

death_mode = "lin_decay"    # or "uniform"
death_rate = 0.5            # death rate of the deepest block (1 - p_L in the paper's notation)

img_rows, img_cols = 32, 32
img_channels = 3

add_tables = []

# Input follows Theano dim ordering: (channels, rows, cols)
inputs = Input(shape=(img_channels, img_rows, img_cols))

# 3x3 convolution stem before the residual stages
net = Convolution2D(16, 3, 3, border_mode="same", W_regularizer=l2(weight_decay))(inputs)
net = BatchNormalization(axis=1)(net)
net = Activation("relu")(net)

def residual_drop(x, input_shape, output_shape, strides=(1, 1)):
    """A residual block that can be dropped (replaced by its shortcut) at training time."""
    global add_tables

    nb_filter = output_shape[0]
    conv = Convolution2D(nb_filter, 3, 3, subsample=strides,
                         border_mode="same", W_regularizer=l2(weight_decay))(x)
    conv = BatchNormalization(axis=1)(conv)
    conv = Activation("relu")(conv)
    conv = Convolution2D(nb_filter, 3, 3,
                         border_mode="same", W_regularizer=l2(weight_decay))(conv)
    conv = BatchNormalization(axis=1)(conv)

    # Down-sample the shortcut when the residual branch is strided
    if strides[0] >= 2:
        x = AveragePooling2D(strides)(x)

    # Zero-pad the shortcut's channels when the block widens; the padding is
    # tiled to the symbolic batch size (this relies on the Theano backend)
    if (output_shape[0] - input_shape[0]) > 0:
        pad_shape = (1,
                     output_shape[0] - input_shape[0],
                     output_shape[1],
                     output_shape[2])
        padding = K.zeros(pad_shape)
        padding = K.repeat_elements(padding, K.shape(x)[0], axis=0)
        x = Lambda(lambda y: K.concatenate([y, padding], axis=1),
                   output_shape=output_shape)(x)

    # At test time the residual branch is scaled by its survival probability
    # (1 - death_rate); at training time it passes through unscaled and is
    # instead dropped whole via the gate below.
    _death_rate = K.variable(death_rate)
    scale = K.ones_like(conv) - _death_rate
    conv = Lambda(lambda c: K.in_test_phase(scale * c, c),
                  output_shape=output_shape)(conv)

    out = merge([conv, x], mode="sum")
    out = Activation("relu")(out)

    # gate == 1 keeps the block; gate == 0 short-circuits it to the shortcut.
    # GatesUpdate resamples all gates at the start of every batch.
    gate = K.variable(1, dtype="uint8")
    add_tables += [{"death_rate": _death_rate, "gate": gate}]
    return Lambda(lambda tensors: K.switch(gate, tensors[0], tensors[1]),
                  output_shape=output_shape)([out, x])

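# Illustration (added; not in the original paste): how a gate short-circuits a
# block. K.switch picks the first tensor while the gate is 1 and falls back to
# the second when it is 0. Assumes the Theano backend used throughout.
_demo_gate = K.variable(1, dtype="uint8")
_demo = K.switch(_demo_gate, K.variable(1.0), K.variable(0.0))
print(K.eval(_demo))        # 1.0: block active
K.set_value(_demo_gate, 0)
print(K.eval(_demo))        # 0.0: block skipped, shortcut passes through
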
# Stage 1: N blocks at 16 channels, 32x32
for i in range(N):
    net = residual_drop(net, input_shape=(16, 32, 32), output_shape=(16, 32, 32))

# Stage 2: down-sample to 16x16 and widen to 32 channels
net = residual_drop(
    net,
    input_shape=(16, 32, 32),
    output_shape=(32, 16, 16),
    strides=(2, 2)
)
for i in range(N - 1):
    net = residual_drop(
        net,
        input_shape=(32, 16, 16),
        output_shape=(32, 16, 16)
    )

# Stage 3: down-sample to 8x8 and widen to 64 channels
net = residual_drop(
    net,
    input_shape=(32, 16, 16),
    output_shape=(64, 8, 8),
    strides=(2, 2)
)
for i in range(N - 1):
    net = residual_drop(
        net,
        input_shape=(64, 8, 8),
        output_shape=(64, 8, 8)
    )

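# Added note: 3 stages x N blocks x 2 convolutions, plus the stem convolution
# and the final dense layer, give 6N + 2 = 110 weighted layers for N = 18,
# matching the CIFAR-10 configuration of the stochastic depth paper
# (Huang et al., 2016).
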
# Global average pooling and softmax classifier
pool = AveragePooling2D((8, 8))(net)
flatten = Flatten()(pool)

predictions = Dense(10, activation="softmax", W_regularizer=l2(weight_decay))(flatten)
model = Model(input=inputs, output=predictions)

sgd = SGD(lr=0.1, momentum=0.9, nesterov=True)
model.compile(optimizer=sgd, loss="categorical_crossentropy")

def open_all_gates():
    """Re-enable every residual block."""
    for t in add_tables:
        K.set_value(t["gate"], 1)

# Set up per-block death rates
for i, tb in enumerate(add_tables, start=1):
    if death_mode == "uniform":
        K.set_value(tb["death_rate"], death_rate)
    elif death_mode == "lin_decay":
        # Linear decay: block l of L dies with probability (l / L) * death_rate
        K.set_value(tb["death_rate"], i / len(add_tables) * death_rate)
    else:
        raise ValueError("Unknown death_mode: %s" % death_mode)

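# Worked example (added): with lin_decay, death_rate = 0.5 and 3 * N = 54
# blocks, the first block dies with probability 1/54 * 0.5 ~= 0.009 and the
# last with probability 0.5, so early blocks almost always survive.
print("first/last death rates: %.3f / %.3f"
      % (K.get_value(add_tables[0]["death_rate"]),
         K.get_value(add_tables[-1]["death_rate"])))
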
class GatesUpdate(Callback):
    """Resamples the survival gate of every block at the start of each batch."""

    def on_batch_begin(self, batch, logs={}):
        open_all_gates()

        rands = np.random.uniform(size=len(add_tables))
        for t, rand in zip(add_tables, rands):
            if rand < K.get_value(t["death_rate"]):
                K.set_value(t["gate"], 0)

    def on_batch_end(self, batch, logs={}):
        open_all_gates()  # all blocks active for validation

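# Optional smoke test (added; not in the original paste): one simulated batch
# should close roughly a quarter of the gates in expectation, since the mean
# lin_decay death rate is about 0.25 when death_rate = 0.5.
_cb = GatesUpdate()
_cb.on_batch_begin(batch=0)
print(sum(int(K.get_value(t["gate"]) == 0) for t in add_tables), "gates closed")
_cb.on_batch_end(batch=0)  # reopen everything before real training
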
def schedule(epoch_idx):
    """Step learning-rate schedule: 0.1 -> 0.01 -> 0.001."""
    if (epoch_idx + 1) < (nb_epoch * lr_schedule[0]):
        return 0.1
    elif (epoch_idx + 1) < (nb_epoch * lr_schedule[1]):
        return 0.01

    return 0.001

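# Worked example (added): with nb_epoch = 500 and lr_schedule = [0.5, 0.75],
# epochs 0-248 train at lr 0.1, epochs 249-373 at 0.01, and the rest at 0.001.
assert schedule(248) == 0.1 and schedule(249) == 0.01 and schedule(374) == 0.001
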
# Save the architecture so it can be reloaded without rebuilding the graph
with open('model.yaml', 'w') as f:
    # f.write(model.to_json())
    f.write(model.to_yaml())
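
# A minimal training sketch (added; the original paste stops after saving the
# architecture, even though cifar10, np_utils, ImageDataGenerator and
# LearningRateScheduler are all imported above). Preprocessing and
# augmentation settings here are assumptions, not the author's exact setup.
# cifar10.load_data() yields channels-first arrays under the Theano image
# dim ordering assumed by the model above.
(X_train, y_train), (X_test, y_test) = cifar10.load_data()
X_train = X_train.astype("float32") / 255.0
X_test = X_test.astype("float32") / 255.0
Y_train = np_utils.to_categorical(y_train, nb_classes)
Y_test = np_utils.to_categorical(y_test, nb_classes)

datagen = ImageDataGenerator(width_shift_range=0.125,   # assumed augmentation
                             height_shift_range=0.125,
                             horizontal_flip=True)

model.fit_generator(datagen.flow(X_train, Y_train, batch_size=batch_size),
                    samples_per_epoch=X_train.shape[0],
                    nb_epoch=nb_epoch,
                    validation_data=(X_test, Y_test),
                    callbacks=[GatesUpdate(), LearningRateScheduler(schedule)])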