Advertisement
Guest User

Untitled

a guest
Feb 21st, 2019
82
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 0.94 KB | None | 0 0
  class Net2:
      """Two-headed sequence model built on a shared LSTM encoder.

      The encoder emits one vector per timestep (``return_sequences=True``).
      Two heads are attached to it:

      * ``cat_output`` -- per-timestep softmax over ``category_size`` classes.
      * ``t_output``   -- per-timestep single ReLU-activated value.
      """

      @staticmethod
      def build_cat_branch(inputs, category_size):
          """Attach the per-timestep categorical head.

          Args:
              inputs: Sequence tensor to classify. As called from
                  ``build_full_model`` this is the batch-normalised LSTM
                  output of shape (batch, timestep_len, hidden_size).
              category_size: Number of classes scored at every timestep.

          Returns:
              Per-timestep class probabilities, produced by a layer named
              ``"cat_output"``.
          """
          logits = TimeDistributed(Dense(category_size))(inputs)
          # Softmax is a separate, named layer so that this model output is
          # addressable as "cat_output" (e.g. in compile(loss={...})).
          return Activation('softmax', name="cat_output")(logits)

      @staticmethod
      def build_t_branch(inputs):
          """Attach the per-timestep scalar regression head.

          Args:
              inputs: Sequence tensor; as called from ``build_full_model`` this
                  is the batch-normalised LSTM output of shape
                  (batch, timestep_len, hidden_size).

          Returns:
              Tensor with one ReLU-activated (hence non-negative) value per
              timestep, produced by a layer named ``"t_output"``.
          """
          # The name must be on the TimeDistributed wrapper, not the wrapped
          # Dense. Keras names model outputs after the outermost layer, so
          # naming only the inner Dense left this output under an
          # auto-generated name and broke lookups keyed by "t_output".
          return TimeDistributed(
              Dense(1, activation='relu'), name="t_output"
          )(inputs)

      @staticmethod
      def build_full_model(timestep_len, hidden_size, category_size,
                           num_features, dropout, rec_drop):
          """Build the complete two-output ``Net2`` model.

          Args:
              timestep_len: Number of timesteps per input sequence.
              hidden_size: Number of LSTM units.
              category_size: Number of classes for the ``cat_output`` head.
              num_features: Number of features per timestep.
              dropout: Input dropout rate passed to the LSTM (``dropout``).
              rec_drop: Recurrent-state dropout rate passed to the LSTM
                  (``recurrent_dropout``).

          Returns:
              A Keras ``Model`` named ``"Net2"``, taking inputs of shape
              (timestep_len, num_features) and returning
              ``[cat_output, t_output]`` in that order.
          """
          inputs = Input(shape=(timestep_len, num_features), name="Input")
          normed_in = BatchNormalization()(inputs)
          # return_sequences=True keeps one hidden vector per timestep so both
          # heads can make a prediction at every step, not just the last one.
          encoded = LSTM(hidden_size, return_sequences=True, dropout=dropout,
                         recurrent_dropout=rec_drop, name="LSTM")(normed_in)
          normed_enc = BatchNormalization()(encoded)
          cat_branch = Net2.build_cat_branch(normed_enc, category_size)
          t_branch = Net2.build_t_branch(normed_enc)
          return Model(inputs=inputs, outputs=[cat_branch, t_branch],
                       name="Net2")
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement