Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
class Net2:
    """Factory for a two-headed recurrent sequence model built from Keras layers.

    The full model applies BatchNormalization -> LSTM (returning the whole
    sequence) -> BatchNormalization to the input, then feeds the result into
    two per-timestep heads:

    * ``"cat_output"``: a softmax over ``category_size`` classes.
    * ``"t_output"``: a single ReLU-activated (hence non-negative) value.
    """

    @staticmethod
    def build_cat_branch(inputs, category_size):
        """Build the per-timestep categorical (softmax) head.

        Args:
            inputs: Sequence tensor. In ``build_full_model`` this is the
                batch-normalized output of an LSTM with
                ``return_sequences=True``.
            category_size: Number of classes in the softmax.

        Returns:
            The output tensor of a layer named ``"cat_output"``.
        """
        x = TimeDistributed(Dense(category_size))(inputs)
        # Softmax is its own named layer so the model output is "cat_output".
        x = Activation('softmax', name="cat_output")(x)
        return x

    @staticmethod
    def build_t_branch(inputs):
        """Build the per-timestep scalar regression head.

        Args:
            inputs: Sequence tensor (same input as ``build_cat_branch``
                receives in ``build_full_model``).

        Returns:
            The output tensor of a layer named ``"t_output"``: one
            ReLU-activated value per timestep.
        """
        # The name goes on the outermost (TimeDistributed) layer. Keras names
        # model outputs after the layer that produces them, so naming only the
        # wrapped Dense would leave this output auto-named and unreachable as
        # "t_output" in loss/metric dicts.
        x = TimeDistributed(Dense(1, activation='relu'), name="t_output")(inputs)
        return x

    @staticmethod
    def build_full_model(timestep_len, hidden_size, category_size, num_features,
                         dropout, rec_drop):
        """Assemble the complete two-output model.

        Args:
            timestep_len: Number of timesteps per input sequence.
            hidden_size: Number of LSTM units.
            category_size: Number of classes for the categorical head.
            num_features: Number of features per timestep.
            dropout: LSTM ``dropout`` rate, applied to the layer inputs.
            rec_drop: LSTM ``recurrent_dropout`` rate, applied to the
                recurrent state.

        Returns:
            A Keras ``Model`` named ``"Net2"`` with input shape
            ``(timestep_len, num_features)`` and outputs
            ``[cat_output, t_output]``, in that order.
        """
        inputs = Input(shape=(timestep_len, num_features), name="Input")
        bn = BatchNormalization()(inputs)
        # return_sequences=True keeps one hidden vector per timestep, so both
        # heads produce per-timestep predictions.
        lstm = LSTM(hidden_size, return_sequences=True, dropout=dropout,
                    recurrent_dropout=rec_drop, name="LSTM")(bn)
        bn2 = BatchNormalization()(lstm)
        cat_branch = Net2.build_cat_branch(bn2, category_size)
        t_branch = Net2.build_t_branch(bn2)
        model = Model(inputs=inputs, outputs=[cat_branch, t_branch], name="Net2")
        return model
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement