daily pastebin goal
93%
SHARE
TWEET

Untitled

a guest Nov 14th, 2018 76 Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
  1. import keras
  2. from .mobilenet_custom import MobileNet
  3. from .densenet_keras import DenseNet201
  4. from .densenet import DenseNetImageNet161, DenseNetImageNet264
  5. from .mobilenet4x4_custom import MobileNet4x4
  6. #from .resnet import ResnetBuilder
  7. from keras.models import Model, load_model
  8. from keras.regularizers import l2
  9. from keras.layers import Dense, GlobalAveragePooling2D, Dropout, Flatten, Input, Concatenate
  10.  
  11. _ALLOWED_FE = {
  12.     'resnet50': keras.applications.resnet50.ResNet50,
  13.     'inception_v3': keras.applications.inception_v3.InceptionV3,
  14.     #'nasnet_large': keras.applications.nasnet.NASNetLarge,
  15.     #'nasnet_mobile': keras.applications.nasnet.NASNetMobile,
  16.     'xception': keras.applications.xception.Xception,
  17.     'mobilenet': MobileNet,
  18.     'densenet161': DenseNetImageNet161,
  19.     'densenet201': DenseNet201,
  20.     'densenet264': DenseNetImageNet264,
  21.     #'mobilenet4x4': MobileNet4x4
  22. }
  23.  
  24.  
  25. def build_model(fe, input_size=(512, 512), hinge=False, data_format='channels_last'):
  26.     if fe not in _ALLOWED_FE.keys():
  27.         raise ValueError('{} feature extractor is not supported'.format(fe))
  28.     fe_class = _ALLOWED_FE[fe]
  29.     weights = 'imagenet'
  30.     if fe.startswith('nasnet') or fe == 'mobilenet4x4':
  31.         weights = None
  32.     aux_input = Input(shape=[64], name='aux_input')
  33.     base_model = fe_class(weights=weights, include_top=False)
  34.     #base_model = load_model('mobilenet4x4.h5', custom_objects={'relu6': keras.applications.mobilenet.relu6, 'DepthwiseConv2D': keras.applications.mobilenet.DepthwiseConv2D})
  35. #   for layer in base_model.layers:
  36. #       layer.trainable = False
  37.     # print(base_model.summary())
  38.     x = base_model.output
  39.     #import ipdb; ipdb.set_trace()
  40.     #if not fe.startswith('densenet'):
  41.     x = GlobalAveragePooling2D()(x)
  42.     x = Dropout(0.5)(x)
  43.     x = Concatenate()([x, aux_input])
  44.     x = Dense(2048, activation='relu')(x)
  45.     if hinge:
  46.         predictions = Dense(10, W_regularizer=l2(0.01), activation='linear')(x)
  47.     else:
  48.         predictions = Dense(10, activation='softmax')(x)
  49.     model = Model(inputs=[base_model.input, aux_input], outputs=predictions)
  50.     return model
  51.     # return base_model
RAW Paste Data
We use cookies for various purposes including analytics. By continuing to use Pastebin, you agree to our use of cookies as described in the Cookies Policy. OK, I Understand
 
Top