Need a unique gift idea?
A Pastebin account makes a great Christmas gift
SHARE
TWEET

Untitled

a guest Nov 14th, 2018 75 Never
Upgrade to PRO!
ENDING IN 00 days 00 hours 00 mins 00 secs
 
  1. import keras
  2. from .mobilenet_custom import MobileNet
  3. from .densenet_keras import DenseNet201
  4. from .densenet import DenseNetImageNet161, DenseNetImageNet264
  5. from .mobilenet4x4_custom import MobileNet4x4
  6. #from .resnet import ResnetBuilder
  7. from keras.models import Model, load_model
  8. from keras.regularizers import l2
  9. from keras.layers import Dense, GlobalAveragePooling2D, Dropout, Flatten, Input, Concatenate
  10.  
  # Registry of feature-extractor (backbone) constructors accepted by
  # build_model, keyed by the name callers pass as `fe`.
  #
  # Currently disabled backbones (re-enable by adding them back here):
  #   'nasnet_large':  keras.applications.nasnet.NASNetLarge
  #   'nasnet_mobile': keras.applications.nasnet.NASNetMobile
  #   'mobilenet4x4':  MobileNet4x4
  _ALLOWED_FE = dict(
      resnet50=keras.applications.resnet50.ResNet50,
      inception_v3=keras.applications.inception_v3.InceptionV3,
      xception=keras.applications.xception.Xception,
      mobilenet=MobileNet,
      densenet161=DenseNetImageNet161,
      densenet201=DenseNet201,
      densenet264=DenseNetImageNet264,
  )
  23.  
  24.  
  def build_model(fe, input_size=(512, 512), hinge=False,
                  data_format='channels_last', num_classes=10, aux_dim=64):
      """Build a two-input classifier on top of a feature-extractor backbone.

      The backbone's output feature map is global-average-pooled, passed
      through dropout (rate 0.5), concatenated with an auxiliary feature
      vector, and fed through a 2048-unit ReLU layer before the output head.

      Args:
          fe: Name of the backbone; must be a key of ``_ALLOWED_FE``.
          input_size: Currently unused -- it is not forwarded to the
              backbone. Kept so existing callers keep working.
          hinge: If True, the head is a linear ``Dense`` layer with L2(0.01)
              kernel regularization (presumably meant for a hinge-type
              loss); otherwise the head uses a softmax activation.
          data_format: Currently unused. Kept for backward compatibility.
          num_classes: Number of units in the output head (default 10).
          aux_dim: Length of the auxiliary input vector (default 64).

      Returns:
          A ``keras.models.Model`` whose inputs are
          ``[base_model.input, aux_input]`` and whose output is the head.

      Raises:
          ValueError: If ``fe`` is not a supported feature extractor.
      """
      if fe not in _ALLOWED_FE:
          raise ValueError('{} feature extractor is not supported'.format(fe))
      fe_class = _ALLOWED_FE[fe]

      # NASNet variants and 'mobilenet4x4' are built without pretrained
      # weights; every other backbone loads ImageNet weights.
      # NOTE(review): these names are not currently enabled in _ALLOWED_FE,
      # so this branch only matters if they are re-enabled.
      if fe.startswith('nasnet') or fe == 'mobilenet4x4':
          weights = None
      else:
          weights = 'imagenet'

      aux_input = Input(shape=[aux_dim], name='aux_input')
      # include_top=False drops the backbone's own classifier so a custom
      # head can be attached to its feature map.
      base_model = fe_class(weights=weights, include_top=False)

      x = GlobalAveragePooling2D()(base_model.output)
      x = Dropout(0.5)(x)
      x = Concatenate()([x, aux_input])
      x = Dense(2048, activation='relu')(x)

      if hinge:
          # `kernel_regularizer` is the Keras 2 name for Keras 1's
          # `W_regularizer`; the old keyword is deprecated in Keras 2 and
          # not accepted by tf.keras.
          predictions = Dense(num_classes, kernel_regularizer=l2(0.01),
                              activation='linear')(x)
      else:
          predictions = Dense(num_classes, activation='softmax')(x)

      return Model(inputs=[base_model.input, aux_input], outputs=predictions)
RAW Paste Data
We use cookies for various purposes including analytics. By continuing to use Pastebin, you agree to our use of cookies as described in the Cookies Policy. OK, I Understand
 
Top