- """Evaluation related classes and functions."""
- import subprocess
- import abc
- import re
- import os
- import six
- import tensorflow as tf
- from tensorflow.python.summary.writer.writer_cache import FileWriterCache as SummaryWriterCache
- from opennmt.utils.misc import get_third_party_dir


@six.add_metaclass(abc.ABCMeta)
class ExternalEvaluator(object):
  """Base class for external evaluators."""

  def __init__(self, labels_file=None, output_dir=None, postprocess_script=None):
    self._labels_file = labels_file
    self._summary_writer = None
    self._postprocess_script = postprocess_script
    if output_dir is not None:
      self._summary_writer = SummaryWriterCache.get(output_dir)

  def __call__(self, step, predictions_path):
    """Scores the predictions and logs the result.

    Args:
      step: The step at which this evaluation occurs.
      predictions_path: The path to the saved predictions.
    """
    if self._postprocess_script:
      predictions_path_postprocessed = predictions_path + ".postprocessed"
      with open(predictions_path, "r") as predictions_file, \
          open(predictions_path_postprocessed, "w") as predictions_file_postprocessed:
        p = subprocess.Popen(
            self._postprocess_script,
            shell=True,
            stdin=predictions_file,
            stdout=predictions_file_postprocessed)
        p.wait()
        predictions_file_postprocessed.flush()
      predictions_path = predictions_path_postprocessed
    score = self.score(self._labels_file, predictions_path)
    if score is None:
      # Preserve the (score, name) return contract even when scoring failed,
      # since callers unpack the result into two values.
      return None, self.name()
    if self._summary_writer is not None:
      self._summarize_score(step, score)
    self._log_score(score)
    return score, self.name()

  # Some evaluators may return several scores, so give them the ability to
  # define how to log the score result.
  def _summarize_score(self, step, score):
    summary = tf.Summary(value=[tf.Summary.Value(
        tag="external_evaluation/{}".format(self.name()), simple_value=score)])
    self._summary_writer.add_summary(summary, step)

  def _log_score(self, score):
    tf.logging.info("%s evaluation score: %f", self.name(), score)

  @abc.abstractmethod
  def name(self):
    """Returns the name of this evaluator."""
    raise NotImplementedError()

  @abc.abstractmethod
  def score(self, labels_file, predictions_path):
    """Scores the predictions against the true output labels."""
    raise NotImplementedError()


class BLEUEvaluator(ExternalEvaluator):
  """Evaluator calling multi-bleu.perl."""

  def _get_bleu_script(self):
    return "multi-bleu.perl"

  def name(self):
    return "BLEU"

  def score(self, labels_file, predictions_path):
    bleu_script = self._get_bleu_script()
    try:
      third_party_dir = get_third_party_dir()
    except RuntimeError as e:
      tf.logging.warning("%s", str(e))
      return None
    try:
      with open(predictions_path, "r") as predictions_file:
        bleu_out = subprocess.check_output(
            [os.path.join(third_party_dir, bleu_script), labels_file],
            stdin=predictions_file,
            stderr=subprocess.STDOUT)
        bleu_out = bleu_out.decode("utf-8")
        bleu_score = re.search(r"BLEU = (.+?),", bleu_out).group(1)
        return float(bleu_score)
    except subprocess.CalledProcessError as error:
      if error.output is not None:
        msg = error.output.strip()
        tf.logging.warning(
            "{} script returned non-zero exit code: {}".format(bleu_script, msg))
      return None


class BLEUDetokEvaluator(BLEUEvaluator):
  """Evaluator calling multi-bleu-detok.perl."""

  # score() is inherited from BLEUEvaluator; only the script and the
  # reported name differ.
  def _get_bleu_script(self):
    return "multi-bleu-detok.perl"

  def name(self):
    return "BLEU-detok"
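

# The following class is an illustrative sketch, not part of the original
# module: a minimal custom evaluator showing how the ExternalEvaluator
# interface is meant to be extended. Instead of calling an external script,
# it computes a simple sentence-level exact-match accuracy in pure Python.
# The class name and metric are assumptions for the example only;
# external_evaluation_fn below does not know about it.
class AccuracyEvaluator(ExternalEvaluator):
  """Example evaluator computing sentence-level exact-match accuracy."""

  def name(self):
    return "accuracy"

  def score(self, labels_file, predictions_path):
    total = 0
    correct = 0
    # Compare the reference and prediction files line by line.
    with open(labels_file, "r") as labels, open(predictions_path, "r") as predictions:
      for label, prediction in zip(labels, predictions):
        total += 1
        if label.strip() == prediction.strip():
          correct += 1
    if total == 0:
      return None
    return float(correct) / total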


def external_evaluation_fn(evaluators_name, labels_file, output_dir=None, postprocess_script=None):
  """Returns a callable to be used in
  :class:`opennmt.utils.hooks.SaveEvaluationPredictionHook` that calls one or
  more external evaluators.

  Args:
    evaluators_name: An evaluator name or a list of evaluator names.
    labels_file: The true output labels.
    output_dir: The run directory.
    postprocess_script: Optional path to a script applied to the predictions
      before scoring.

  Returns:
    A callable or ``None`` if :obj:`evaluators_name` is ``None`` or empty.

  Raises:
    ValueError: if an evaluator name is invalid.
    IOError: if :obj:`postprocess_script` is set but does not exist.
  """
  if evaluators_name is None:
    return None
  if not isinstance(evaluators_name, list):
    evaluators_name = [evaluators_name]
  if not evaluators_name:
    return None
  if postprocess_script is not None and not os.path.isfile(postprocess_script):
    raise IOError("Post process script {} not found.".format(postprocess_script))
  evaluators = []
  for name in evaluators_name:
    name = name.lower()
    if name == "bleu":
      evaluator = BLEUEvaluator(
          labels_file=labels_file, output_dir=output_dir, postprocess_script=postprocess_script)
    elif name == "bleu-detok":
      evaluator = BLEUDetokEvaluator(
          labels_file=labels_file, output_dir=output_dir, postprocess_script=postprocess_script)
    else:
      raise ValueError("No evaluator associated with the name: {}".format(name))
    evaluators.append(evaluator)

  def _post_evaluation_fn(step, predictions_path):
    scores = []
    names = []
    for evaluator in evaluators:
      score, name = evaluator(step, predictions_path)
      scores.append(score)
      names.append(name)
    return scores, names

  return _post_evaluation_fn
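

# Example usage (an illustrative sketch, not part of the original module).
# The file paths and the step value below are placeholder assumptions; in
# OpenNMT-tf this callable is normally invoked by
# opennmt.utils.hooks.SaveEvaluationPredictionHook after predictions are saved.
if __name__ == "__main__":
  evaluation_fn = external_evaluation_fn(
      ["bleu", "bleu-detok"],
      labels_file="data/test.tgt",  # placeholder path to the reference file
      output_dir="run/eval")        # placeholder run directory for summaries
  if evaluation_fn is not None:
    # Placeholder step and predictions path.
    scores, names = evaluation_fn(1000, "run/eval/predictions.txt")
    for metric, value in zip(names, scores):
      print("{}: {}".format(metric, value))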