Guest User

Model Parallelism

a guest
Jul 5th, 2018
425
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 9.68 KB | None | 0 0
  1. def conv2d_bn(x,
  2.               filters,
  3.               num_row,
  4.               num_col,
  5.               padding='same',
  6.               strides=(1, 1),
  7.               name=None):
  8.     """Utility function to apply conv + BN.
  9.  
  10.    # Arguments
  11.        x: input tensor.
  12.        filters: filters in `Conv2D`.
  13.        num_row: height of the convolution kernel.
  14.        num_col: width of the convolution kernel.
  15.        padding: padding mode in `Conv2D`.
  16.        strides: strides in `Conv2D`.
  17.        name: name of the ops; will become `name + '_conv'`
  18.            for the convolution and `name + '_bn'` for the
  19.            batch norm layer.
  20.  
  21.    # Returns
  22.        Output tensor after applying `Conv2D` and `BatchNormalization`.
  23.    """
  24.     if name is not None:
  25.         bn_name = name + '_bn'
  26.         conv_name = name + '_conv'
  27.     else:
  28.         bn_name = None
  29.         conv_name = None
  30.    
  31.     bn_axis = 3
  32.     x = Conv2D(
  33.         filters, (num_row, num_col),
  34.         strides=strides,
  35.         padding=padding,
  36.         use_bias=False,
  37.         name=conv_name)(x)
  38.     x = BatchNormalization(axis=bn_axis, scale=False, name=bn_name)(x)
  39.     x = Activation('relu', name=name)(x)
  40.     return x
  41.  
  42.  
  43. def Gender_InceptionV3(input_shape=(512, 512, 3)):
  44.     img_input = Input(shape=input_shape)
  45.  
  46.     channel_axis = 3
  47.    
  48.     with tf.device('/gpu:0'):
  49.         x = conv2d_bn(img_input, 32, 3, 3, strides=(2, 2), padding='valid')
  50.         x = conv2d_bn(x, 32, 3, 3, padding='valid')
  51.         x = conv2d_bn(x, 64, 3, 3)
  52.         x = MaxPooling2D((3, 3), strides=(2, 2))(x)
  53.  
  54.         x = conv2d_bn(x, 80, 1, 1, padding='valid')
  55.         x = conv2d_bn(x, 192, 3, 3, padding='valid')
  56.         x = MaxPooling2D((3, 3), strides=(2, 2))(x)
  57.  
  58.         # mixed 0, 1, 2: 35 x 35 x 256
  59.         branch1x1 = conv2d_bn(x, 64, 1, 1)
  60.  
  61.         branch5x5 = conv2d_bn(x, 48, 1, 1)
  62.         branch5x5 = conv2d_bn(branch5x5, 64, 5, 5)
  63.  
  64.         branch3x3dbl = conv2d_bn(x, 64, 1, 1)
  65.         branch3x3dbl = conv2d_bn(branch3x3dbl, 96, 3, 3)
  66.         branch3x3dbl = conv2d_bn(branch3x3dbl, 96, 3, 3)
  67.  
  68.         branch_pool = AveragePooling2D((3, 3), strides=(1, 1), padding='same')(x)
  69.         branch_pool = conv2d_bn(branch_pool, 32, 1, 1)
  70.         x = layers.concatenate(
  71.             [branch1x1, branch5x5, branch3x3dbl, branch_pool],
  72.             axis=channel_axis,
  73.             name='mixed0')
  74.         print(x)
  75.  
  76.     with tf.device('/gpu:1'):
  77.         # mixed 1: 35 x 35 x 256
  78.         branch1x1 = conv2d_bn(x, 64, 1, 1)
  79.  
  80.         branch5x5 = conv2d_bn(x, 48, 1, 1)
  81.         branch5x5 = conv2d_bn(branch5x5, 64, 5, 5)
  82.  
  83.         branch3x3dbl = conv2d_bn(x, 64, 1, 1)
  84.         branch3x3dbl = conv2d_bn(branch3x3dbl, 96, 3, 3)
  85.         branch3x3dbl = conv2d_bn(branch3x3dbl, 96, 3, 3)
  86.  
  87.         branch_pool = AveragePooling2D((3, 3), strides=(1, 1), padding='same')(x)
  88.         branch_pool = conv2d_bn(branch_pool, 64, 1, 1)
  89.         x = layers.concatenate(
  90.             [branch1x1, branch5x5, branch3x3dbl, branch_pool],
  91.             axis=channel_axis,
  92.             name='mixed1')
  93.  
  94.         # mixed 2: 35 x 35 x 256
  95.         branch1x1 = conv2d_bn(x, 64, 1, 1)
  96.  
  97.         branch5x5 = conv2d_bn(x, 48, 1, 1)
  98.         branch5x5 = conv2d_bn(branch5x5, 64, 5, 5)
  99.  
  100.         branch3x3dbl = conv2d_bn(x, 64, 1, 1)
  101.         branch3x3dbl = conv2d_bn(branch3x3dbl, 96, 3, 3)
  102.         branch3x3dbl = conv2d_bn(branch3x3dbl, 96, 3, 3)
  103.  
  104.         branch_pool = AveragePooling2D((3, 3), strides=(1, 1), padding='same')(x)
  105.         branch_pool = conv2d_bn(branch_pool, 64, 1, 1)
  106.         x = layers.concatenate(
  107.             [branch1x1, branch5x5, branch3x3dbl, branch_pool],
  108.             axis=channel_axis,
  109.             name='mixed2')
  110.         print(x)
  111.        
  112.     with tf.device('/gpu:2'):
  113.         # mixed 3: 17 x 17 x 768
  114.         branch3x3 = conv2d_bn(x, 384, 3, 3, strides=(2, 2), padding='valid')
  115.  
  116.         branch3x3dbl = conv2d_bn(x, 64, 1, 1)
  117.         branch3x3dbl = conv2d_bn(branch3x3dbl, 96, 3, 3)
  118.         branch3x3dbl = conv2d_bn(
  119.             branch3x3dbl, 96, 3, 3, strides=(2, 2), padding='valid')
  120.  
  121.         branch_pool = MaxPooling2D((3, 3), strides=(2, 2))(x)
  122.         x = layers.concatenate(
  123.             [branch3x3, branch3x3dbl, branch_pool], axis=channel_axis, name='mixed3')
  124.  
  125.         # mixed 4: 17 x 17 x 768
  126.         branch1x1 = conv2d_bn(x, 192, 1, 1)
  127.  
  128.         branch7x7 = conv2d_bn(x, 128, 1, 1)
  129.         branch7x7 = conv2d_bn(branch7x7, 128, 1, 7)
  130.         branch7x7 = conv2d_bn(branch7x7, 192, 7, 1)
  131.  
  132.         branch7x7dbl = conv2d_bn(x, 128, 1, 1)
  133.         branch7x7dbl = conv2d_bn(branch7x7dbl, 128, 7, 1)
  134.         branch7x7dbl = conv2d_bn(branch7x7dbl, 128, 1, 7)
  135.         branch7x7dbl = conv2d_bn(branch7x7dbl, 128, 7, 1)
  136.         branch7x7dbl = conv2d_bn(branch7x7dbl, 192, 1, 7)
  137.  
  138.         branch_pool = AveragePooling2D((3, 3), strides=(1, 1), padding='same')(x)
  139.         branch_pool = conv2d_bn(branch_pool, 192, 1, 1)
  140.         x = layers.concatenate(
  141.             [branch1x1, branch7x7, branch7x7dbl, branch_pool],
  142.             axis=channel_axis,
  143.             name='mixed4')
  144.         print(x)
  145.  
  146.         # mixed 5, 6: 17 x 17 x 768
  147.         for i in range(2):
  148.             branch1x1 = conv2d_bn(x, 192, 1, 1)
  149.  
  150.             branch7x7 = conv2d_bn(x, 160, 1, 1)
  151.             branch7x7 = conv2d_bn(branch7x7, 160, 1, 7)
  152.             branch7x7 = conv2d_bn(branch7x7, 192, 7, 1)
  153.  
  154.             branch7x7dbl = conv2d_bn(x, 160, 1, 1)
  155.             branch7x7dbl = conv2d_bn(branch7x7dbl, 160, 7, 1)
  156.             branch7x7dbl = conv2d_bn(branch7x7dbl, 160, 1, 7)
  157.             branch7x7dbl = conv2d_bn(branch7x7dbl, 160, 7, 1)
  158.             branch7x7dbl = conv2d_bn(branch7x7dbl, 192, 1, 7)
  159.  
  160.             branch_pool = AveragePooling2D(
  161.                 (3, 3), strides=(1, 1), padding='same')(x)
  162.             branch_pool = conv2d_bn(branch_pool, 192, 1, 1)
  163.             x = layers.concatenate(
  164.                 [branch1x1, branch7x7, branch7x7dbl, branch_pool],
  165.                 axis=channel_axis,
  166.                 name='mixed' + str(5 + i))
  167.         print(x)
  168.            
  169.            
  170.     with tf.device('/gpu:3'):
  171.  
  172.         # mixed 7: 17 x 17 x 768
  173.         branch1x1 = conv2d_bn(x, 192, 1, 1)
  174.  
  175.         branch7x7 = conv2d_bn(x, 192, 1, 1)
  176.         branch7x7 = conv2d_bn(branch7x7, 192, 1, 7)
  177.         branch7x7 = conv2d_bn(branch7x7, 192, 7, 1)
  178.  
  179.         branch7x7dbl = conv2d_bn(x, 192, 1, 1)
  180.         branch7x7dbl = conv2d_bn(branch7x7dbl, 192, 7, 1)
  181.         branch7x7dbl = conv2d_bn(branch7x7dbl, 192, 1, 7)
  182.         branch7x7dbl = conv2d_bn(branch7x7dbl, 192, 7, 1)
  183.         branch7x7dbl = conv2d_bn(branch7x7dbl, 192, 1, 7)
  184.  
  185.         branch_pool = AveragePooling2D((3, 3), strides=(1, 1), padding='same')(x)
  186.         branch_pool = conv2d_bn(branch_pool, 192, 1, 1)
  187.         x = layers.concatenate(
  188.             [branch1x1, branch7x7, branch7x7dbl, branch_pool],
  189.             axis=channel_axis,
  190.             name='mixed7')
  191.         print(x)
  192.        
  193.     with tf.device('/gpu:4'):
  194.  
  195.         # mixed 8: 8 x 8 x 1280
  196.         branch3x3 = conv2d_bn(x, 192, 1, 1)
  197.         branch3x3 = conv2d_bn(branch3x3, 320, 3, 3,
  198.                               strides=(2, 2), padding='valid')
  199.  
  200.         branch7x7x3 = conv2d_bn(x, 192, 1, 1)
  201.         branch7x7x3 = conv2d_bn(branch7x7x3, 192, 1, 7)
  202.         branch7x7x3 = conv2d_bn(branch7x7x3, 192, 7, 1)
  203.         branch7x7x3 = conv2d_bn(branch7x7x3, 192, 3, 3, strides=(2, 2), padding='valid')
  204.  
  205.         branch_pool = MaxPooling2D((3, 3), strides=(2, 2))(x)
  206.         x = layers.concatenate(
  207.             [branch3x3, branch7x7x3, branch_pool], axis=channel_axis, name='mixed8')
  208.         print(x)
  209.        
  210.     with tf.device('/gpu:5'):
  211.         # mixed 9: 8 x 8 x 2048
  212.         for i in range(2):
  213.             branch1x1 = conv2d_bn(x, 320, 1, 1)
  214.  
  215.             branch3x3 = conv2d_bn(x, 384, 1, 1)
  216.             branch3x3_1 = conv2d_bn(branch3x3, 384, 1, 3)
  217.             branch3x3_2 = conv2d_bn(branch3x3, 384, 3, 1)
  218.             branch3x3 = layers.concatenate(
  219.                 [branch3x3_1, branch3x3_2], axis=channel_axis, name='mixed9_' + str(i))
  220.  
  221.             branch3x3dbl = conv2d_bn(x, 448, 1, 1)
  222.             branch3x3dbl = conv2d_bn(branch3x3dbl, 384, 3, 3)
  223.             branch3x3dbl_1 = conv2d_bn(branch3x3dbl, 384, 1, 3)
  224.             branch3x3dbl_2 = conv2d_bn(branch3x3dbl, 384, 3, 1)
  225.             branch3x3dbl = layers.concatenate(
  226.                 [branch3x3dbl_1, branch3x3dbl_2], axis=channel_axis)
  227.  
  228.             branch_pool = AveragePooling2D(
  229.                 (3, 3), strides=(1, 1), padding='same')(x)
  230.             branch_pool = conv2d_bn(branch_pool, 192, 1, 1)
  231.             x = layers.concatenate(
  232.                 [branch1x1, branch3x3, branch3x3dbl, branch_pool],
  233.                 axis=channel_axis,
  234.                 name='mixed' + str(9 + i))
  235.         print(x)
  236.    
  237.     with tf.device('/gpu:6'):
  238.         x = Flatten()(x)
  239.         y = layers.Input(shape=(1,))
  240.         gender_dense = Dense(32, activation='relu',name='gender_dense')(y)
  241.    
  242.     x_merged = layers.concatenate([x,gender_dense])
  243.     print(x_merged)
  244.         # text_branch = Sequential()
  245.         # text_branch.add(Dense(32, input_shape=(1,), activation='relu'))
  246.         # merged = keras.layers.concatenate([x, text_branch.output])
  247.         # This is taking the gender as input
  248.         #
  249.     age = Dense(1000, activation='relu')(x_merged)
  250.     age = Dense(1000, activation='relu')(age)
  251.        
  252.    
  253.     age = Dense(1)(age)
  254.     print(age)
  255.         # gender_dense = Dense(32, activation='relu',name='gender_dense')(y)
  256.         #x = layers.concatenate([gender_dense,x],axis=1)
  257.         #x = Dense(1000, activation='relu', name='dense_1')(x)
  258.         #x = Dense(1000, activation='relu', name='dense_2')(x)    
  259.     # Create model.
  260.     model = Model(inputs=[img_input, y], outputs=age, name='gender_inception_v3')
  261.  
  262.  
  263.     return model
Add Comment
Please, Sign In to add comment