Advertisement
Guest User

Untitled

a guest
Nov 14th, 2018
98
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 2.14 KB | None | 0 0
  1. import keras
  2. from .mobilenet_custom import MobileNet
  3. from .densenet_keras import DenseNet201
  4. from .densenet import DenseNetImageNet161, DenseNetImageNet264
  5. from .mobilenet4x4_custom import MobileNet4x4
  6. #from .resnet import ResnetBuilder
  7. from keras.models import Model, load_model
  8. from keras.regularizers import l2
  9. from keras.layers import Dense, GlobalAveragePooling2D, Dropout, Flatten, Input, Concatenate
  10.  
  11. _ALLOWED_FE = {
  12.     'resnet50': keras.applications.resnet50.ResNet50,
  13.     'inception_v3': keras.applications.inception_v3.InceptionV3,
  14.     #'nasnet_large': keras.applications.nasnet.NASNetLarge,
  15.     #'nasnet_mobile': keras.applications.nasnet.NASNetMobile,
  16.     'xception': keras.applications.xception.Xception,
  17.     'mobilenet': MobileNet,
  18.     'densenet161': DenseNetImageNet161,
  19.     'densenet201': DenseNet201,
  20.     'densenet264': DenseNetImageNet264,
  21.     #'mobilenet4x4': MobileNet4x4
  22. }
  23.  
  24.  
  25. def build_model(fe, input_size=(512, 512), hinge=False, data_format='channels_last'):
  26.     if fe not in _ALLOWED_FE.keys():
  27.         raise ValueError('{} feature extractor is not supported'.format(fe))
  28.     fe_class = _ALLOWED_FE[fe]
  29.     weights = 'imagenet'
  30.     if fe.startswith('nasnet') or fe == 'mobilenet4x4':
  31.         weights = None
  32.     aux_input = Input(shape=[64], name='aux_input')
  33.     base_model = fe_class(weights=weights, include_top=False)
  34.     #base_model = load_model('mobilenet4x4.h5', custom_objects={'relu6': keras.applications.mobilenet.relu6, 'DepthwiseConv2D': keras.applications.mobilenet.DepthwiseConv2D})
  35. #   for layer in base_model.layers:
  36. #       layer.trainable = False
  37.     # print(base_model.summary())
  38.     x = base_model.output
  39.     #import ipdb; ipdb.set_trace()
  40.     #if not fe.startswith('densenet'):
  41.     x = GlobalAveragePooling2D()(x)
  42.     x = Dropout(0.5)(x)
  43.     x = Concatenate()([x, aux_input])
  44.     x = Dense(2048, activation='relu')(x)
  45.     if hinge:
  46.         predictions = Dense(10, W_regularizer=l2(0.01), activation='linear')(x)
  47.     else:
  48.         predictions = Dense(10, activation='softmax')(x)
  49.     model = Model(inputs=[base_model.input, aux_input], outputs=predictions)
  50.     return model
  51.     # return base_model
Advertisement
Add Comment
Please sign in to add a comment.
Advertisement