daily pastebin goal
57%
SHARE
TWEET

Untitled

a guest Mar 19th, 2019 67 Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
import numpy as np
import tensorflow as tf
from tensorflow.python.saved_model import tag_constants

  5. MAX_SEQ_LENGTH = 128
  6. EXPORT_DIR = 'export/'
  7. PREDICT_FILE = 'predict.tf_record'
  8.  
  9. # 1. Create a new input function for serving: serving_input_fn
  10. def serving_input_fn():
  11.     label_ids = tf.placeholder(tf.int32, [None], name='label_ids')
  12.     input_ids = tf.placeholder(tf.int32, [None, MAX_SEQ_LENGTH], name='input_ids')
  13.     input_mask = tf.placeholder(tf.int32, [None, MAX_SEQ_LENGTH], name='input_mask')
  14.     segment_ids = tf.placeholder(tf.int32, [None, MAX_SEQ_LENGTH], name='segment_ids')
  15.     input_fn = tf.estimator.export.build_raw_serving_input_receiver_fn({
  16.         'label_ids': label_ids,
  17.         'input_ids': input_ids,
  18.         'input_mask': input_mask,
  19.         'segment_ids': segment_ids,
  20.     })()
  21.     return input_fn
  22.  
  23. # 2. Export the tf.Estimator to a SavedModel using the serving_input_fn as input function
  24. estimator._export_to_tpu = False # important!
  25. estimator.export_savedmodel(EXPORT_DIR, serving_input_fn)
  26.  
  27. # 3. Instantiate the Graph using the exported SavedModel
  28. graph = tf.Graph()
  29. session = tf.Session(graph=graph)
  30. tf.saved_model.loader.load(sess, [tag_constants.SERVING], FLAGS.export_dir)
  31.  
  32. # 4. Load the PREDICT_FILE containing the input examples parsed to features.
  33. # This file is obtained by parsing the tf.Example objects using the SerializeToString() method and the TFRecordWriter.
  34. record_iterator = tf.python_io.tf_record_iterator(path=PREDICT_FILE)
  35.  
  36. # 5. Run the session for every input example
  37. tensor_input_ids = graph.get_tensor_by_name('input_ids_1:0')
  38. tensor_input_mask = graph.get_tensor_by_name('input_mask_1:0')
  39. tensor_label_ids = graph.get_tensor_by_name('label_ids_1:0')
  40. tensor_segment_ids = graph.get_tensor_by_name('segment_ids_1:0')
  41. tensor_outputs = graph.get_tensor_by_name('loss/Softmax:0')
  42.  
  43. predictions = []
  44. for string_record in record_iterator:
  45.     example = tf.train.Example()
  46.     example.ParseFromString(string_record)
  47.     input_ids = example.features.feature['input_ids'].int64_list.value
  48.     input_mask = example.features.feature['input_mask'].int64_list.value
  49.     label_ids = example.features.feature['label_ids'].int64_list.value
  50.     segment_ids = example.features.feature['segment_ids'].int64_list.value
  51.    
  52.     result = sess.run(tensor_outputs, feed_dict={
  53.         tensor_input_ids: np.array(input_ids).reshape(-1, MAX_SEQ_LENGTH),
  54.         tensor_input_mask: np.array(input_mask).reshape(-1, MAX_SEQ_LENGTH),
  55.         tensor_label_ids: np.array(label_ids),
  56.         tensor_segment_ids: np.array(segment_ids).reshape(-1, MAX_SEQ_LENGTH),
  57.     })
  58.     predictions.append(result)
  59.  
  60. # see the predictions
  61. print(predictions)
RAW Paste Data
We use cookies for various purposes including analytics. By continuing to use Pastebin, you agree to our use of cookies as described in the Cookies Policy. OK, I Understand
 
Top