Advertisement
TylerHumanCompiler

model.py

Jul 9th, 2020
80
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 6.30 KB | None | 0 0
  1. import os
  2. import numpy as np
  3. import tensorflow as tf
  4. from tensorflow.keras.layers import *
  5. from sklearn.model_selection import train_test_split
  6.  
  7.  
  8.  
  9. folder_name = "tfNX"
  10. main_path = "./" + folder_name
  11. checkpoint_path = main_path + "/chkpts/"
  12.  
  13. """Create Dirs if they don't exist:"""
  14. #if not os.path.exists(main_path):
  15. #    os.mkdir(main_path)
  16. #for folder in ['chkpts', 'graph', 'dataset', 'weights', 'model', 'frozen', 'onnx', 'plan', 'trt']:
  17. #    if not os.path.exists(main_path + '/' + folder):
  18. #        os.mkdir(main_path + '/' + folder)
  19.  
  20. loadweights, trainmodel = False, False
  21. lrate = 0.01
  22. bs, eps = 16, 39
  23. stage1, stage2 = 7, 5
  24.  
  25. g = tf.Graph()
  26. sess = tf.compat.v1.Session(graph=g)
  27. with g.as_default():
  28.     inptensor = Input(shape=(7, 96, 2, 2), dtype='float32', name='input')
  29.  
  30.     bd10 = Bidirectional(ConvLSTM2D(2, (stage1, 3),
  31.                                     return_sequences=True,
  32.                                     padding='same',
  33.                                     go_backwards=True,
  34.                                     activation='tanh',
  35.                                     recurrent_activation='sigmoid'))(inptensor)
  36.     bd20 = Bidirectional(ConvLSTM2D(2, (stage2, 3),
  37.                                     return_sequences=True,
  38.                                     padding='same',
  39.                                     go_backwards=True,
  40.                                     activation='tanh',
  41.                                     recurrent_activation='sigmoid'))(bd10)
  42.     bd30 = Bidirectional(ConvLSTM2D(2, (3, 3),
  43.                                     return_sequences=False,
  44.                                     padding='same',
  45.                                     go_backwards=True,
  46.                                     activation='tanh',
  47.                                     recurrent_activation='sigmoid'))(bd20)
  48.     dens20 = Dense(2)(bd30)
  49.     dens30 = Dense(1)(dens20)
  50.     output0 = Reshape((96, 2), name='output0')(dens30)
  51.  
  52.     bd11 = Bidirectional(ConvLSTM2D(2, (stage1, 3),
  53.                                     return_sequences=True,
  54.                                     padding='same',
  55.                                     go_backwards=True,
  56.                                     activation='tanh',
  57.                                     recurrent_activation='sigmoid'))(inptensor)
  58.     bd21 = Bidirectional(ConvLSTM2D(2, (stage2, 3),
  59.                                     return_sequences=True,
  60.                                     padding='same',
  61.                                     go_backwards=True,
  62.                                     activation='tanh',
  63.                                     recurrent_activation='sigmoid'))(bd11)
  64.     bd31 = Bidirectional(ConvLSTM2D(2, (3, 3),
  65.                                     return_sequences=False,
  66.                                     padding='same',
  67.                                     go_backwards=True,
  68.                                     activation='tanh',
  69.                                     recurrent_activation='sigmoid'))(bd21)
  70.     dens21 = Dense(2)(bd31)
  71.     dens31 = Dense(1)(dens21)
  72.     output1 = Reshape((96, 2), name='output1')(dens31)
  73.  
  74.     bd12 = Bidirectional(ConvLSTM2D(2, (stage1, 3),
  75.                                     return_sequences=True,
  76.                                     padding='same',
  77.                                     go_backwards=True,
  78.                                     activation='tanh',
  79.                                     recurrent_activation='sigmoid'))(inptensor)
  80.     bd22 = Bidirectional(ConvLSTM2D(2, (stage2, 3),
  81.                                     return_sequences=True,
  82.                                     padding='same',
  83.                                     go_backwards=True,
  84.                                     activation='tanh',
  85.                                     recurrent_activation='sigmoid'))(bd12)
  86.     bd32 = Bidirectional(ConvLSTM2D(2, (3, 3),
  87.                                     return_sequences=False,
  88.                                     padding='same',
  89.                                     go_backwards=True,
  90.                                     activation='tanh',
  91.                                     recurrent_activation='sigmoid'))(bd22)
  92.     dens22 = Dense(2)(bd32)
  93.     dens32 = Dense(1)(dens22)
  94.     output2 = Reshape((96, 2), name='output2')(dens32)
  95.  
  96.     bd13 = Bidirectional(ConvLSTM2D(2, (stage1, 3),
  97.                                     return_sequences=True,
  98.                                     padding='same',
  99.                                     go_backwards=True,
  100.                                     activation='tanh',
  101.                                     recurrent_activation='sigmoid'))(inptensor)
  102.     bd23 = Bidirectional(ConvLSTM2D(2, (stage2, 3),
  103.                                     return_sequences=True,
  104.                                     padding='same',
  105.                                     go_backwards=True,
  106.                                     activation='tanh',
  107.                                     recurrent_activation='sigmoid'))(bd13)
  108.     bd33 = Bidirectional(ConvLSTM2D(2, (3, 3),
  109.                                     return_sequences=False,
  110.                                     padding='same',
  111.                                     go_backwards=True,
  112.                                     activation='tanh',
  113.                                     recurrent_activation='sigmoid'))(bd23)
  114.     dens23 = Dense(2)(bd33)
  115.     dens33 = Dense(1)(dens23)
  116.     output3 = Reshape((96, 2), name='output3')(dens33)
  117.  
  118.     model = tf.keras.Model(inptensor, (output0, output1, output2, output3), name='Xtract0r')
  119.     model.compile(optimizer=tf.optimizers.SGD(learning_rate=lrate, decay=1e-6, momentum=0.9, nesterov=True),
  120.                   loss=tf.losses.MeanSquaredError(),
  121.                   metrics=[tf.keras.metrics.RootMeanSquaredError()])
  122.     assert output0.graph is g and output1.graph is g and output2.graph is g and output3.graph is g
  123.  
  124.     if loadweights:
  125.         model.load_weights(checkpoint_path).assert_existing_objects_matched()
  126.     if trainmodel:
  127.         history = model.fit(...)
  128.     model.save(filepath=main_path + '/model/', save_format='tf')
  129.  
  130. with sess.as_default():
  131.     with g.as_default():
  132.         init = tf.compat.v1.global_variables_initializer()
  133.         sess.run(init)
  134.         tf.io.write_graph(graph_or_graph_def=g.as_graph_def(), logdir=main_path + '/graph', name='graph.pb', as_text=False)
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement