Advertisement
Guest User

evaluator.py

a guest
Apr 4th, 2018
112
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 6.28 KB | None | 0 0
  1. """Evaluation related classes and functions."""
  2.  
  3. import subprocess
  4.  
  5. import abc
  6. import re
  7. import os
  8. import six
  9.  
  10. import tensorflow as tf
  11.  
  12. from tensorflow.python.summary.writer.writer_cache import FileWriterCache as SummaryWriterCache
  13.  
  14. from opennmt.utils.misc import get_third_party_dir
  15.  
  16.  
  17. @six.add_metaclass(abc.ABCMeta)
  18. class ExternalEvaluator(object):
  19.   """Base class for external evaluators."""
  20.  
  21.   def __init__(self, labels_file=None, output_dir=None, postprocess_script=None):
  22.     self._labels_file = labels_file
  23.     self._summary_writer = None
  24.     self._postprocess_script = postprocess_script
  25.  
  26.     if output_dir is not None:
  27.       self._summary_writer = SummaryWriterCache.get(output_dir)
  28.  
  29.   def __call__(self, step, predictions_path):
  30.     """Scores the predictions and logs the result.
  31.  
  32.    Args:
  33.      step: The step at which this evaluation occurs.
  34.      predictions_path: The path to the saved predictions.
  35.    """
  36.     if self._postprocess_script:
  37.       predictions_path_postprocessed = predictions_path+'.postprocessed'
  38.       with open(predictions_path, 'r') as predictions_file, \
  39.               open(predictions_path_postprocessed, 'w') as predictions_file_postprocessed:
  40.         p = subprocess.Popen(
  41.               self._postprocess_script,
  42.               shell=True,
  43.               stdin=predictions_file,
  44.               stdout=predictions_file_postprocessed)
  45.         p.wait()
  46.         predictions_file_postprocessed.flush()
  47.  
  48.       predictions_path = predictions_path_postprocessed
  49.  
  50.     score = self.score(self._labels_file, predictions_path)
  51.     if score is None:
  52.       return
  53.     if self._summary_writer is not None:
  54.       self._summarize_score(step, score)
  55.     self._log_score(score)
  56.     return score, self.name()
  57.  
  58.  
  59.   # Some evaluators may return several scores so let them the ability to
  60.   # define how to log the score result.
  61.  
  62.   def _summarize_score(self, step, score):
  63.     summary = tf.Summary(value=[tf.Summary.Value(
  64.         tag="external_evaluation/{}".format(self.name()), simple_value=score)])
  65.     self._summary_writer.add_summary(summary, step)
  66.  
  67.   def _log_score(self, score):
  68.     tf.logging.info("%s evaluation score: %f", self.name(), score)
  69.  
  70.   @abc.abstractproperty
  71.   def name(self):
  72.     """Returns the name of this evaluator."""
  73.     raise NotImplementedError()
  74.  
  75.   @abc.abstractmethod
  76.   def score(self, labels_file, predictions_path):
  77.     """Scores the predictions against the true output labels."""
  78.     raise NotImplementedError()
  79.  
  80.  
  81. class BLEUEvaluator(ExternalEvaluator):
  82.   """Evaluator calling multi-bleu.perl."""
  83.  
  84.   def _get_bleu_script(self):
  85.     return "multi-bleu.perl"
  86.  
  87.   def name(self):
  88.     return "BLEU"
  89.  
  90.   def score(self, labels_file, predictions_path):
  91.     bleu_script = self._get_bleu_script()
  92.     try:
  93.       third_party_dir = get_third_party_dir()
  94.     except RuntimeError as e:
  95.       tf.logging.warning("%s", str(e))
  96.       return None
  97.     try:
  98.       with open(predictions_path, "r") as predictions_file:
  99.         bleu_out = subprocess.check_output(
  100.             [os.path.join(third_party_dir, bleu_script), labels_file],
  101.             stdin=predictions_file,
  102.             stderr=subprocess.STDOUT)
  103.         bleu_out = bleu_out.decode("utf-8")
  104.         bleu_score = re.search(r"BLEU = (.+?),", bleu_out).group(1)
  105.         return float(bleu_score)
  106.     except subprocess.CalledProcessError as error:
  107.       if error.output is not None:
  108.         msg = error.output.strip()
  109.         tf.logging.warning(
  110.             "{} script returned non-zero exit code: {}".format(bleu_script, msg))
  111.       return None
  112.  
  113.  
  114. class BLEUDetokEvaluator(BLEUEvaluator):
  115.   """Evaluator calling multi-bleu-detok.perl."""
  116.  
  117.   def _get_bleu_script(self):
  118.     return "multi-bleu-detok.perl"
  119.  
  120.   def name(self):
  121.     return "BLEU-detok"
  122.  
  123.   def score(self, labels_file, predictions_path):
  124.     bleu_script = self._get_bleu_script()
  125.     try:
  126.       third_party_dir = get_third_party_dir()
  127.     except RuntimeError as e:
  128.       tf.logging.warning("%s", str(e))
  129.       return None
  130.     try:
  131.       with open(predictions_path, "r") as predictions_file:
  132.         bleu_out = subprocess.check_output(
  133.             [os.path.join(third_party_dir, bleu_script), labels_file],
  134.             stdin=predictions_file,
  135.             stderr=subprocess.STDOUT)
  136.         bleu_out = bleu_out.decode("utf-8")
  137.         bleu_score = re.search(r"BLEU = (.+?),", bleu_out).group(1)
  138.         return float(bleu_score)
  139.     except subprocess.CalledProcessError as error:
  140.       if error.output is not None:
  141.         msg = error.output.strip()
  142.         tf.logging.warning(
  143.             "{} script returned non-zero exit code: {}".format(bleu_script, msg))
  144.       return None
  145.  
  146.  
  147. def external_evaluation_fn(evaluators_name, labels_file, output_dir=None, postprocess_script=None):
  148.   """Returns a callable to be used in
  149.  :class:`opennmt.utils.hooks.SaveEvaluationPredictionHook` that calls one or
  150.  more external evaluators.
  151.  
  152.  Args:
  153.    evaluators_name: An evaluator name or a list of evaluators name.
  154.    labels_file: The true output labels.
  155.    output_dir: The run directory.
  156.  
  157.  Returns:
  158.    A callable or ``None`` if :obj:`evaluators_name` is ``None`` or empty.
  159.  
  160.  Raises:
  161.    ValueError: if an evaluator name is invalid.
  162.  """
  163.   if evaluators_name is None:
  164.     return None
  165.   if not isinstance(evaluators_name, list):
  166.     evaluators_name = [evaluators_name]
  167.   if not evaluators_name:
  168.     return None
  169.   if not os.path.isfile(postprocess_script):
  170.     raise IOError("Post process script {} not found.".format(postprocess_script))
  171.  
  172.   evaluators = []
  173.   for name in evaluators_name:
  174.     name = name.lower()
  175.     if name == "bleu":
  176.       evaluator = BLEUEvaluator(labels_file=labels_file, output_dir=output_dir, postprocess_script=postprocess_script)
  177.     elif name == "bleu-detok":
  178.       evaluator = BLEUDetokEvaluator(labels_file=labels_file, output_dir=output_dir, postprocess_script=postprocess_script)
  179.     else:
  180.       raise ValueError("No evaluator associated with the name: {}".format(name))
  181.     evaluators.append(evaluator)
  182.  
  183.   def _post_evaluation_fn(step, predictions_path):
  184.     scores = []
  185.     names = []
  186.     for evaluator in evaluators:
  187.       score, name = evaluator(step, predictions_path)
  188.       scores.append(score)
  189.       names.append(name)
  190.     return scores, names
  191.  
  192.   return _post_evaluation_fn
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement