Source code for pytext.metric_reporters.channel

#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import csv
import json
from typing import Tuple

from pytext.common.constants import Stage


class Channel:
    """
    Abstract base for reporting the result of a PyText job to some output
    stream. Subclasses decide how each report is formatted and where it goes.

    Attributes:
        stages: the stages (train / eval / test) during which this channel
            is triggered; by default all three.
    """

    def __init__(
        self, stages: Tuple[Stage, ...] = (Stage.TRAIN, Stage.EVAL, Stage.TEST)
    ) -> None:
        self.stages = stages

    def report(
        self,
        stage,
        epoch,
        metrics,
        model_select_metric,
        loss,
        preds,
        targets,
        scores,
        context,
        meta,
        *args,
    ):
        """
        Format and emit one report. Must be implemented by subclasses.

        Args:
            stage (Stage): train, eval or test
            epoch (int): current epoch
            metrics (Any): all metrics
            model_select_metric (double): a single numeric metric used to
                pick the best model
            loss (double): average loss
            preds (List[Any]): list of predictions
            targets (List[Any]): list of targets
            scores (List[Any]): list of scores
            context (Dict[str, List[Any]]): any additional context data;
                each context is a list with one entry per example
            meta (Dict[str, Any]): global metadata, such as target names
        """
        raise NotImplementedError()
class ConsoleChannel(Channel):
    """
    Channel that dumps reports straight to stdout.
    """

    def report(
        self,
        stage,
        epoch,
        metrics,
        model_select_metric,
        loss,
        preds,
        targets,
        scores,
        context,
        meta,
        *args,
    ):
        """Print the stage, the average loss and the metrics to the console."""
        print(f"\n\n{stage}")
        print(f"loss: {loss:.6f}")
        # TODO change print_metrics function to __str__ T33522209
        if not hasattr(metrics, "print_metrics"):
            print(metrics)
        else:
            metrics.print_metrics()
class FileChannel(Channel):
    """
    Simple Channel that writes results to a TSV file.

    The output starts with one ``#``-prefixed JSON line per entry in ``meta``,
    followed by a tab-separated title row and one row per example.
    """

    def __init__(self, stages, file_path) -> None:
        super().__init__(stages)
        # Destination path; overwritten on every report (see report()).
        self.file_path = file_path

    def report(
        self,
        stage,
        epoch,
        metrics,
        model_select_metric,
        loss,
        preds,
        targets,
        scores,
        context,
        meta,
        *args,
    ):
        """
        Write metadata, the title row and per-example rows to the TSV file.

        NOTE: the file is opened in "w" mode, so each call truncates the file
        and only the most recent report is kept.
        """
        print(f"saving result to file {self.file_path}")
        with open(self.file_path, "w", encoding="utf-8") as of:
            for metadata in meta.values():
                # TODO the # prefix is quite ad-hoc, we should think of a better
                # way to handle it
                of.write("#")
                of.write(json.dumps(metadata))
                of.write("\n")

            tsv_writer = csv.writer(
                of,
                delimiter="\t",
                quotechar='"',
                doublequote=True,
                lineterminator="\n",
                quoting=csv.QUOTE_MINIMAL,
            )

            tsv_writer.writerow(self.get_title())
            # *args are forwarded so subclasses can thread extra data through
            # an overridden gen_content().
            for row in self.gen_content(
                metrics, loss, preds, targets, scores, context, *args
            ):
                tsv_writer.writerow(row)

    def get_title(self):
        """Column names for the TSV header row; subclasses may override."""
        return ("prediction", "target", "score")

    def gen_content(self, metrics, loss, preds, targets, scores, contexts):
        """
        Yield one [prediction, target, score] row per example.

        Iterates the three lists in lockstep with zip (replacing the previous
        range(len(...)) index loop); subclasses may override to emit more
        columns.
        """
        for pred, target, score in zip(preds, targets, scores):
            yield [pred, target, score]
class TensorBoardChannel(Channel):
    """
    Channel that formats and reports the result of a PyText job to
    TensorBoard.

    Attributes:
        summary_writer: an instance of the TensorBoardX SummaryWriter class,
            or any object implementing the same interface.
            https://tensorboardx.readthedocs.io/en/latest/tensorboard.html
        metric_name: name of the default metric shown on the TensorBoard
            dashboard; defaults to "accuracy".
    """

    def __init__(self, summary_writer, metric_name="accuracy"):
        super().__init__()
        self.summary_writer = summary_writer
        self.metric_name = metric_name

    def report(
        self,
        stage,
        epoch,
        metrics,
        model_select_metric,
        loss,
        preds,
        targets,
        scores,
        context,
        meta,
        *args,
    ):
        """
        Report data to TensorBoard via the summary writer: during train/eval
        each numeric metric field is recursively reported as a scalar keyed by
        epoch, while during test the final metrics are reported as text.

        Args:
            stage (Stage): train, eval or test
            epoch (int): current epoch
            metrics (Any): all metrics
            model_select_metric (double): a single numeric metric used to
                pick the best model
            loss (double): average loss
            preds (List[Any]): list of predictions
            targets (List[Any]): list of targets
            scores (List[Any]): list of scores
            context (Dict[str, List[Any]]): any additional context data;
                each context is a list with one entry per example
            meta (Dict[str, Any]): global metadata, such as target names
        """
        if stage != Stage.TEST:
            # Train/eval: scalars keyed by epoch.
            prefix = "train" if stage == Stage.TRAIN else "eval"
            self.summary_writer.add_scalar(f"{prefix}/loss", loss, epoch)
            if isinstance(metrics, (int, float)):
                self.summary_writer.add_scalar(
                    f"{prefix}/{self.metric_name}", metrics, epoch
                )
            else:
                self.add_scalars(prefix, metrics, epoch)
            return

        # Test: final metrics rendered as text.
        tag = "test"
        self.summary_writer.add_text(tag, f"loss={loss}")
        if isinstance(metrics, (int, float)):
            self.summary_writer.add_text(tag, f"{self.metric_name}={metrics}")
        else:
            self.add_texts(tag, metrics)

    def add_texts(self, tag, metrics):
        """
        Recursively flatten the metrics object, adding each numeric field as
        a "name=value" text entry under ``tag``. Nested fields (anything with
        an ``_asdict``, e.g. NamedTuples) recurse with the field name appended
        to the tag, so ``{accuracy: 0.7, scores: {precision: 0.8}}`` yields
        "accuracy=0.7" under ``tag`` and "precision=0.8" under
        ``tag/scores``.

        Args:
            tag (str): tag name for the metric; nested field names are
                appended as ``tag/field``.
            metrics (Any): the metrics object.
        """
        for name, value in metrics._asdict().items():
            if isinstance(value, (int, float)):
                self.summary_writer.add_text(tag, f"{name}={value}")
            elif hasattr(value, "_asdict"):
                self.add_texts(f"{tag}/{name}", value)

    def add_scalars(self, prefix, metrics, epoch):
        """
        Recursively flatten the metrics object, adding each numeric field as
        a scalar for the given epoch. Nested fields (anything with an
        ``_asdict``) recurse with the field name appended to the prefix.

        Args:
            prefix (str): tag prefix; each field name is reported as
                ``prefix/field``.
            metrics (Any): the metrics object.
        """
        for name, value in metrics._asdict().items():
            if isinstance(value, (int, float)):
                self.summary_writer.add_scalar(f"{prefix}/{name}", value, epoch)
            elif hasattr(value, "_asdict"):
                self.add_scalars(f"{prefix}/{name}", value, epoch)