Guest User

Untitled

a guest
Jul 14th, 2020
532
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 5.98 KB | None | 0 0
  1. import keras_metrics as km
  2. import keras.metrics as metrics
  3. from keras import optimizers
  4. from keras.models import Model
  5. from keras.layers import Conv3D, BatchNormalization, Activation, Input, MaxPooling3D, concatenate, UpSampling3D
  6. from losses_and_metrics.keras_weighted_categorical_crossentropy import weighted_categorical_crossentropy
  7. from losses_and_metrics.dsc import dice_coef_label
  8.  
  9.  
  10. def detection_unet(filters, kernel_size, weights, learning_rate):
  11.  
  12. # Input
  13. main_input = Input(shape=(None, None, None, 1))
  14.  
  15. # 64 x 64 x 80
  16. step_down_1 = Conv3D(filters, kernel_size=kernel_size, strides=(1, 1, 1), padding="same")(main_input)
  17. step_down_1 = BatchNormalization(momentum=0.1)(step_down_1)
  18. step_down_1 = Activation("relu")(step_down_1)
  19. step_down_1 = Conv3D(filters, kernel_size=kernel_size, strides=(1, 1, 1), padding="same")(step_down_1)
  20. step_down_1 = BatchNormalization(momentum=0.1)(step_down_1)
  21. step_down_1 = Activation("relu")(step_down_1)
  22.  
  23. # 32 x 32 x 40
  24. step_down_2 = MaxPooling3D(pool_size=(2, 2, 2), strides=(2, 2, 2))(step_down_1)
  25. step_down_2 = Conv3D(2 * filters, kernel_size=kernel_size, strides=(1, 1, 1), padding="same")(step_down_2)
  26. step_down_2 = BatchNormalization(momentum=0.1)(step_down_2)
  27. step_down_2 = Activation("relu")(step_down_2)
  28. step_down_2 = Conv3D(2 * filters, kernel_size=kernel_size, strides=(1, 1, 1), padding="same")(step_down_2)
  29. step_down_2 = BatchNormalization(momentum=0.1)(step_down_2)
  30. step_down_2 = Activation("relu")(step_down_2)
  31.  
  32. # 16 x 16 x 20
  33. step_down_3 = MaxPooling3D(pool_size=(2, 2, 2), strides=(2, 2, 2))(step_down_2)
  34. step_down_3 = Conv3D(4 * filters, kernel_size=kernel_size, strides=(1, 1, 1), padding="same")(step_down_3)
  35. step_down_3 = BatchNormalization(momentum=0.1)(step_down_3)
  36. step_down_3 = Activation("relu")(step_down_3)
  37. step_down_3 = Conv3D(4 * filters, kernel_size=kernel_size, strides=(1, 1, 1), padding="same")(step_down_3)
  38. step_down_3 = BatchNormalization(momentum=0.1)(step_down_3)
  39. step_down_3 = Activation("relu")(step_down_3)
  40.  
  41. # 8 x 8 x 10
  42. step_down_4 = MaxPooling3D(pool_size=(2, 2, 2), strides=(2, 2, 2))(step_down_3)
  43. step_down_4 = Conv3D(8 * filters, kernel_size=kernel_size, strides=(1, 1, 1), padding="same")(step_down_4)
  44. step_down_4 = BatchNormalization(momentum=0.1)(step_down_4)
  45. step_down_4 = Activation("relu")(step_down_4)
  46. step_down_4 = Conv3D(8 * filters, kernel_size=kernel_size, strides=(1, 1, 1), padding="same")(step_down_4)
  47. step_down_4 = BatchNormalization(momentum=0.1)(step_down_4)
  48. step_down_4 = Activation("relu")(step_down_4)
  49.  
  50. # 4 x 4 x 5
  51. floor = MaxPooling3D(pool_size=(2, 2, 2), strides=(2, 2, 2))(step_down_4)
  52. floor = Conv3D(16 * filters, kernel_size=kernel_size, strides=(1, 1, 1), padding="same")(floor)
  53. floor = BatchNormalization(momentum=0.1)(floor)
  54. floor = Activation("relu")(floor)
  55. floor = Conv3D(16 * filters, kernel_size=kernel_size, strides=(1, 1, 1), padding="same")(floor)
  56. floor = BatchNormalization(momentum=0.1)(floor)
  57. floor = Activation("relu")(floor)
  58.  
  59. # 8 x 8 x 10
  60. step_up_4 = UpSampling3D(size=(2, 2, 2))(floor)
  61. step_up_4 = concatenate([step_down_4, step_up_4], axis=-1)
  62. step_up_4 = Conv3D(8 * filters, kernel_size=kernel_size, strides=(1, 1, 1), padding="same")(step_up_4)
  63. step_up_4 = BatchNormalization(momentum=0.1)(step_up_4)
  64. step_up_4 = Activation("relu")(step_up_4)
  65. step_up_4 = Conv3D(8 * filters, kernel_size=kernel_size, strides=(1, 1, 1), padding="same")(step_up_4)
  66. step_up_4 = BatchNormalization(momentum=0.1)(step_up_4)
  67. step_up_4 = Activation("relu")(step_up_4)
  68.  
  69. # 16 x 16 x 20
  70. step_up_3 = UpSampling3D(size=(2, 2, 2))(step_up_4)
  71. step_up_3 = concatenate([step_down_3, step_up_3], axis=-1)
  72. step_up_3 = Conv3D(4 * filters, kernel_size=kernel_size, strides=(1, 1, 1), padding="same")(step_up_3)
  73. step_up_3 = BatchNormalization(momentum=0.1)(step_up_3)
  74. step_up_3 = Activation("relu")(step_up_3)
  75. step_up_3 = Conv3D(4 * filters, kernel_size=kernel_size, strides=(1, 1, 1), padding="same")(step_up_3)
  76. step_up_3 = BatchNormalization(momentum=0.1)(step_up_3)
  77. step_up_3 = Activation("relu")(step_up_3)
  78.  
  79. # 32 x 32 x 40
  80. step_up_2 = UpSampling3D(size=(2, 2, 2))(step_up_3)
  81. step_up_2 = concatenate([step_down_2, step_up_2], axis=-1)
  82. step_up_2 = Conv3D(2 * filters, kernel_size=kernel_size, strides=(1, 1, 1), padding="same")(step_up_2)
  83. step_up_2 = BatchNormalization(momentum=0.1)(step_up_2)
  84. step_up_2 = Activation("relu")(step_up_2)
  85. step_up_2 = Conv3D(2 * filters, kernel_size=kernel_size, strides=(1, 1, 1), padding="same")(step_up_2)
  86. step_up_2 = BatchNormalization(momentum=0.1)(step_up_2)
  87. step_up_2 = Activation("relu")(step_up_2)
  88.  
  89. # 64 x 64 x 80
  90. step_up_1 = UpSampling3D(size=(2, 2, 2))(step_up_2)
  91. step_up_1 = concatenate([step_down_1, step_up_1], axis=-1)
  92. step_up_1 = Conv3D(filters, kernel_size=kernel_size, strides=(1, 1, 1), padding="same")(step_up_1)
  93. step_up_1 = BatchNormalization(momentum=0.1)(step_up_1)
  94. step_up_1 = Activation("relu")(step_up_1)
  95. step_up_1 = Conv3D(filters, kernel_size=kernel_size, strides=(1, 1, 1), padding="same")(step_up_1)
  96. step_up_1 = BatchNormalization(momentum=0.1)(step_up_1)
  97. step_up_1 = Activation("relu")(step_up_1)
  98.  
  99. main_output = Conv3D(2, kernel_size=kernel_size, strides=(1, 1, 1), padding="same",
  100. activation='softmax')(step_up_1)
  101.  
  102. model = Model(inputs=main_input, outputs=main_output)
  103.  
  104. # define optimizer
  105. adam = optimizers.Adam(lr=learning_rate, beta_1=0.9, beta_2=0.999, epsilon=None, decay=1e-6)
  106.  
  107. # define loss function
  108. loss_function = weighted_categorical_crossentropy(weights)
  109.  
  110. # define metrics
  111. dsc = dice_coef_label(label=1)
  112. recall_background = km.binary_recall(label=0)
  113. recall_vertebrae = km.binary_recall(label=1)
  114. cat_accuracy = metrics.categorical_accuracy
  115.  
  116. model.compile(optimizer=adam, loss=loss_function, metrics=[dsc, recall_background, recall_vertebrae, cat_accuracy])
  117.  
  118. return model
Add Comment
Please, Sign In to add comment