Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- from tensorflow.python.saved_model import tag_constants
- import tensorflow as tf
- import numpy as np
# Width of the input_ids / input_mask / segment_ids placeholders and of the
# arrays fed to them (every example is reshaped to (-1, MAX_SEQ_LENGTH)).
MAX_SEQ_LENGTH = 128
# Base directory handed to estimator.export_savedmodel().
EXPORT_DIR = 'export/'
# TFRecord file holding the serialized tf.Example inputs to predict on.
PREDICT_FILE = 'predict.tf_record'
# 1. Create a new input function for serving: serving_input_fn
def serving_input_fn():
    """Build the raw-tensor serving input receiver used when exporting.

    Declares one int32 placeholder per feature ('label_ids', 'input_ids',
    'input_mask', 'segment_ids') and passes them to
    ``tf.estimator.export.build_raw_serving_input_receiver_fn``.

    Returns:
        The object produced by *calling* the receiver function built from
        those placeholders (not the function itself), which is what
        ``export_savedmodel`` expects a serving input function to return.
    """
    sequence_shape = [None, MAX_SEQ_LENGTH]

    # The placeholder names matter: the prediction code looks tensors up by
    # name in the exported graph, so they must not be changed here.
    features = {
        'label_ids': tf.placeholder(tf.int32, [None], name='label_ids'),
    }
    for feature_name in ('input_ids', 'input_mask', 'segment_ids'):
        features[feature_name] = tf.placeholder(
            tf.int32, sequence_shape, name=feature_name)

    receiver_fn = tf.estimator.export.build_raw_serving_input_receiver_fn(
        features)
    return receiver_fn()
# 2. Export the tf.Estimator to a SavedModel using serving_input_fn as the
#    input function.
# NOTE(review): `estimator` is not defined in this snippet; it must already
# exist (presumably a TPUEstimator, given the `_export_to_tpu` flag - confirm).
estimator._export_to_tpu = False  # important!

# export_savedmodel() writes the model into a timestamped subdirectory of
# EXPORT_DIR and returns that path, so keep it and load from there rather
# than from EXPORT_DIR itself.
export_path = estimator.export_savedmodel(EXPORT_DIR, serving_input_fn)
if isinstance(export_path, bytes):
    # The path may come back as bytes; normalise it to str for the loader.
    export_path = export_path.decode('utf-8')


def _int64_values(example, key):
    """Return the int64 values stored under `key` in a tf.train.Example."""
    return example.features.feature[key].int64_list.value


# 3. Instantiate the Graph using the exported SavedModel.
graph = tf.Graph()
predictions = []

# The context manager guarantees the session is closed once prediction ends.
with tf.Session(graph=graph) as sess:
    tf.saved_model.loader.load(sess, [tag_constants.SERVING], export_path)

    # The '_1' suffix is presumably there because the raw serving receiver
    # creates its own placeholders that reuse the names declared in
    # serving_input_fn, so TensorFlow uniquifies them - verify against the
    # exported graph if a lookup fails.
    tensor_input_ids = graph.get_tensor_by_name('input_ids_1:0')
    tensor_input_mask = graph.get_tensor_by_name('input_mask_1:0')
    tensor_label_ids = graph.get_tensor_by_name('label_ids_1:0')
    tensor_segment_ids = graph.get_tensor_by_name('segment_ids_1:0')
    # NOTE(review): the output tensor name depends on how the model graph
    # was built - confirm 'loss/Softmax:0' exists in your export.
    tensor_outputs = graph.get_tensor_by_name('loss/Softmax:0')

    # 4. Load the PREDICT_FILE containing the input examples converted to
    #    features. This file is obtained by serializing the tf.Example
    #    objects with SerializeToString() and writing them with a
    #    TFRecordWriter.
    record_iterator = tf.python_io.tf_record_iterator(path=PREDICT_FILE)

    # 5. Run the session for every input example.
    for string_record in record_iterator:
        example = tf.train.Example()
        example.ParseFromString(string_record)

        input_ids = np.array(_int64_values(example, 'input_ids'))
        input_mask = np.array(_int64_values(example, 'input_mask'))
        label_ids = np.array(_int64_values(example, 'label_ids'))
        segment_ids = np.array(_int64_values(example, 'segment_ids'))

        result = sess.run(tensor_outputs, feed_dict={
            tensor_input_ids: input_ids.reshape(-1, MAX_SEQ_LENGTH),
            tensor_input_mask: input_mask.reshape(-1, MAX_SEQ_LENGTH),
            tensor_label_ids: label_ids,
            tensor_segment_ids: segment_ids.reshape(-1, MAX_SEQ_LENGTH),
        })
        predictions.append(result)

# see the predictions
print(predictions)
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement