Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- import os
- import numpy as np
- import tensorflow as tf
- from tensorflow.keras.layers import *
- from sklearn.model_selection import train_test_split
# --- Paths -------------------------------------------------------------------
# Root directory for every artifact this script reads or writes.
folder_name = "tfNX"
main_path = "./" + folder_name
# Checkpoint prefix passed to model.load_weights() when loadweights is True.
checkpoint_path = main_path + "/chkpts/"

# Sub-directories of main_path. Only 'chkpts', 'graph' and 'model' are
# referenced by the code below; the rest are reserved for later pipeline steps.
subfolders = ('chkpts', 'graph', 'dataset', 'weights', 'model', 'frozen', 'onnx', 'plan', 'trt')

# Set to True to create main_path and all sub-directories up front.
# makedirs(exist_ok=True) also creates main_path itself, is idempotent, and
# avoids the check-then-create race of exists()/mkdir().
create_dirs = False
if create_dirs:
    for folder in subfolders:
        os.makedirs(main_path + '/' + folder, exist_ok=True)

# --- Run switches & hyper-parameters -----------------------------------------
# loadweights: restore weights from checkpoint_path; trainmodel: run model.fit.
loadweights, trainmodel = False, False
lrate = 0.01  # SGD learning rate
# Presumably batch size / epoch count for the (elided) model.fit call;
# not referenced anywhere else in this file.
bs, eps = 16, 39
# Kernel rows of the 1st and 2nd ConvLSTM2D stage in every branch
# (assuming Keras' default channels_last layout, rows run along the 96-long axis).
stage1, stage2 = 7, 5
def _bidir_convlstm(kernel_rows, return_sequences):
    """Create one Bidirectional(ConvLSTM2D) stage with the settings every branch shares.

    Args:
        kernel_rows: first dimension of the (kernel_rows, 3) convolution kernel.
        return_sequences: True to emit the whole time sequence, False to emit
            only the final step.

    Returns:
        An uncalled Keras ``Bidirectional`` layer wrapping a 2-filter ConvLSTM2D.
    """
    return Bidirectional(ConvLSTM2D(2, (kernel_rows, 3),
                                    return_sequences=return_sequences,
                                    padding='same',
                                    # NOTE(review): Keras' Bidirectional derives its
                                    # backward copy by inverting go_backwards, so True
                                    # here appears to make the "forward" half scan in
                                    # reverse as well — confirm this is intended.
                                    go_backwards=True,
                                    activation='tanh',
                                    recurrent_activation='sigmoid'))


def _build_branch(inputs, index, first_rows, second_rows):
    """Build one output head: three Bidirectional ConvLSTM2D stages plus a Dense head.

    Args:
        inputs: the model's input tensor.
        index: branch number; the final layer is named ``'output<index>'``.
        first_rows: kernel rows of the first ConvLSTM2D stage.
        second_rows: kernel rows of the second ConvLSTM2D stage.

    Returns:
        The branch output tensor produced by ``Reshape((96, 2))``.
    """
    x = _bidir_convlstm(first_rows, return_sequences=True)(inputs)
    x = _bidir_convlstm(second_rows, return_sequences=True)(x)
    # The last stage drops the time axis (return_sequences=False).
    x = _bidir_convlstm(3, return_sequences=False)(x)
    x = Dense(2)(x)
    x = Dense(1)(x)
    # padding='same' keeps the input's (96, 2) spatial extent, so the size-1
    # channel left by Dense(1) is folded away by the Reshape.
    return Reshape((96, 2), name='output' + str(index))(x)


# A dedicated graph (and a session bound to it) so everything below lives in g
# rather than in the process-wide default graph.
g = tf.Graph()
sess = tf.compat.v1.Session(graph=g)
with g.as_default():
    inptensor = Input(shape=(7, 96, 2, 2), dtype='float32', name='input')
    # Four structurally identical branches with independent weights. They are
    # built strictly in order 0..3, layer by layer, so Keras' auto-generated
    # layer names follow the same creation order (checkpoint matching in
    # load_weights may rely on this).
    output0, output1, output2, output3 = (
        _build_branch(inptensor, i, stage1, stage2) for i in range(4))
    outputs = (output0, output1, output2, output3)

    model = tf.keras.Model(inptensor, outputs, name='Xtract0r')
    model.compile(
        optimizer=tf.optimizers.SGD(learning_rate=lrate, decay=1e-6, momentum=0.9, nesterov=True),
        loss=tf.losses.MeanSquaredError(),
        # Applied to each of the four outputs.
        metrics=[tf.keras.metrics.RootMeanSquaredError()])

    # Sanity check: every output tensor was built in the private graph g.
    assert all(out.graph is g for out in outputs)

    if loadweights:
        model.load_weights(checkpoint_path).assert_existing_objects_matched()
    if trainmodel:
        # TODO: fit() arguments (data, batch_size=bs, epochs=eps, callbacks, ...)
        # are elided here; the literal Ellipsis will fail if this branch runs.
        history = model.fit(...)
    # NOTE(review): saved on every run, whether or not training happened —
    # confirm this should not be inside the trainmodel branch.
    model.save(filepath=main_path + '/model/', save_format='tf')
# Serialize graph g as a binary GraphDef to <main_path>/graph/graph.pb.
# The variable-initializer op is added to g before serialization, so it is
# part of the written graph, and it is executed once in `sess`.
# NOTE(review): Keras may keep its variables in its own backend session rather
# than `sess`, and a GraphDef carries graph structure only (no variable values),
# so weights from load_weights/fit do not end up in graph.pb — confirm intent.
with sess.as_default(), g.as_default():
    init = tf.compat.v1.global_variables_initializer()
    sess.run(init)
    graph_dir = main_path + '/graph'
    tf.io.write_graph(
        graph_or_graph_def=g.as_graph_def(),
        logdir=graph_dir,
        name='graph.pb',
        as_text=False,
    )
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement