Advertisement
Guest User

Untitled

a guest
Mar 19th, 2019
128
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 2.60 KB | None | 0 0
  1. from tensorflow.python.saved_model import tag_constants
  2. import tensorflow as tf
  3. import numpy as np
  4.  
  5. MAX_SEQ_LENGTH = 128
  6. EXPORT_DIR = 'export/'
  7. PREDICT_FILE = 'predict.tf_record'
  8.  
  9. # 1. Create a new input function for serving: serving_input_fn
  10. def serving_input_fn():
  11. label_ids = tf.placeholder(tf.int32, [None], name='label_ids')
  12. input_ids = tf.placeholder(tf.int32, [None, MAX_SEQ_LENGTH], name='input_ids')
  13. input_mask = tf.placeholder(tf.int32, [None, MAX_SEQ_LENGTH], name='input_mask')
  14. segment_ids = tf.placeholder(tf.int32, [None, MAX_SEQ_LENGTH], name='segment_ids')
  15. input_fn = tf.estimator.export.build_raw_serving_input_receiver_fn({
  16. 'label_ids': label_ids,
  17. 'input_ids': input_ids,
  18. 'input_mask': input_mask,
  19. 'segment_ids': segment_ids,
  20. })()
  21. return input_fn
  22.  
  23. # 2. Export the tf.Estimator to a SavedModel using the serving_input_fn as input function
  24. estimator._export_to_tpu = False # important!
  25. estimator.export_savedmodel(EXPORT_DIR, serving_input_fn)
  26.  
  27. # 3. Instantiate the Graph using the exported SavedModel
  28. graph = tf.Graph()
  29. session = tf.Session(graph=graph)
  30. tf.saved_model.loader.load(sess, [tag_constants.SERVING], FLAGS.export_dir)
  31.  
  32. # 4. Load the PREDICT_FILE containing the input examples parsed to features.
  33. # This file is obtained by parsing the tf.Example objects using the SerializeToString() method and the TFRecordWriter.
  34. record_iterator = tf.python_io.tf_record_iterator(path=PREDICT_FILE)
  35.  
  36. # 5. Run the session for every input example
  37. tensor_input_ids = graph.get_tensor_by_name('input_ids_1:0')
  38. tensor_input_mask = graph.get_tensor_by_name('input_mask_1:0')
  39. tensor_label_ids = graph.get_tensor_by_name('label_ids_1:0')
  40. tensor_segment_ids = graph.get_tensor_by_name('segment_ids_1:0')
  41. tensor_outputs = graph.get_tensor_by_name('loss/Softmax:0')
  42.  
  43. predictions = []
  44. for string_record in record_iterator:
  45. example = tf.train.Example()
  46. example.ParseFromString(string_record)
  47. input_ids = example.features.feature['input_ids'].int64_list.value
  48. input_mask = example.features.feature['input_mask'].int64_list.value
  49. label_ids = example.features.feature['label_ids'].int64_list.value
  50. segment_ids = example.features.feature['segment_ids'].int64_list.value
  51.  
  52. result = sess.run(tensor_outputs, feed_dict={
  53. tensor_input_ids: np.array(input_ids).reshape(-1, MAX_SEQ_LENGTH),
  54. tensor_input_mask: np.array(input_mask).reshape(-1, MAX_SEQ_LENGTH),
  55. tensor_label_ids: np.array(label_ids),
  56. tensor_segment_ids: np.array(segment_ids).reshape(-1, MAX_SEQ_LENGTH),
  57. })
  58. predictions.append(result)
  59.  
  60. # see the predictions
  61. print(predictions)
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement