Source code for

#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

from typing import Any, Dict, List

from pytext.config.contextual_intent_slot import (
from pytext.config.field_config import DocLabelConfig, WordLabelConfig
from import InputRecord
from pytext.fields import (
from pytext.utils import data_utils

from .joint_data_handler import JointModelDataHandler

class RawData:
    """Names of the raw input columns read from the data source.

    Each attribute holds the string used as the key for that column in a
    raw data row.
    """

    # Label columns.
    DOC_LABEL = "doc_label"
    WORD_LABEL = "word_label"

    # Input columns.
    TEXT = "text"
    DICT_FEAT = "dict_feat"

    # Weight columns.
    DOC_WEIGHT = "doc_weight"
    WORD_WEIGHT = "word_weight"
class ContextualIntentSlotModelDataHandler(JointModelDataHandler):
    """
    Data Handler to build pipeline to process data and generate tensors to be
    consumed by ContextualIntentSlotModel. Columns of Input data includes:

        1. doc label for intent classification
        2. word label for slot tagging of the last utterance
        3. a sequence of utterances (e.g., a dialog)
        4. Optional dictionary feature contained in the last utterance
        5. Optional doc weight that stands for the weight of intent task in
           joint loss.
        6. Optional word weight that stands for the weight of slot task in
           joint loss.

    Attributes:
        raw_columns: columns to read from data source. In case of files, the
            order should match the data stored in that file. Raw columns
            include
            ::

                [
                    RawData.DOC_LABEL,
                    RawData.WORD_LABEL,
                    RawData.TEXT,
                    RawData.DICT_FEAT (Optional),
                    RawData.DOC_WEIGHT (Optional),
                    RawData.WORD_WEIGHT (Optional),
                ]

        labels: doc labels and word labels
        features: embeddings generated from sequences of utterances and
            dictionary features of the last utterance
        extra_fields: doc weights, word weights, and etc.
    """

    class Config(JointModelDataHandler.Config):
        # Raw columns to read, in the order they appear in the data source.
        columns_to_read: List[str] = [
            RawData.DOC_LABEL,
            RawData.WORD_LABEL,
            RawData.TEXT,
            RawData.DICT_FEAT,
            RawData.DOC_WEIGHT,
            RawData.WORD_WEIGHT,
        ]

    @classmethod
    def from_config(
        cls,
        config: Config,
        feature_config: ModelInputConfig,
        target_config: TargetConfig,
        **kwargs,
    ):
        """Factory method to construct an instance of
        ContextualIntentSlotModelDataHandler object from the module's config,
        model input config and target config.

        Args:
            config (Config): Configuration object specifying all the
                parameters of ContextualIntentSlotModelDataHandler.
            feature_config (ModelInputConfig): Configuration object specifying
                model input.
            target_config (TargetConfig): Configuration object specifying
                target.
            **kwargs: Extra keyword arguments forwarded to the constructor;
                entries from ``config.items()`` are merged in and override
                same-named keys.

        Returns:
            type: An instance of ContextualIntentSlotModelDataHandler.
        """
        # Feature fields.
        features: Dict[str, Field] = create_fields(
            feature_config,
            {
                ModelInput.TEXT: TextFeatureField,
                ModelInput.DICT: DictFeatureField,
                ModelInput.CHAR: CharFeatureField,
                ModelInput.PRETRAINED: PretrainedModelEmbeddingField,
                ModelInput.SEQ: SeqFeatureField,
            },
        )

        # Label fields.
        labels: Dict[str, Field] = create_label_fields(
            target_config,
            {
                DocLabelConfig._name: DocLabelField,
                WordLabelConfig._name: WordLabelField,
            },
        )

        # Extra fields carried along with each example.
        extra_fields: Dict[str, Field] = {
            ExtraField.DOC_WEIGHT: FloatField(),
            ExtraField.WORD_WEIGHT: FloatField(),
            ExtraField.RAW_WORD_LABEL: RawField(),
            ExtraField.TOKEN_RANGE: RawField(),
            ExtraField.UTTERANCE: RawField(),
        }

        kwargs.update(config.items())
        return cls(
            raw_columns=config.columns_to_read,
            labels=labels,
            features=features,
            extra_fields=extra_fields,
            **kwargs,
        )

    @staticmethod
    def _weight_or_default(value: Any, default: float = 1.0) -> Any:
        """Return ``value`` unless it is missing (``None`` or ``""``).

        A plain ``value or default`` would also discard an explicit weight of
        zero, so only the "no value supplied" cases fall back to ``default``.
        """
        if value is None or value == "":
            return default
        return value

    def preprocess_row(self, row_data: Dict[str, Any]) -> Dict[str, Any]:
        """Preprocess steps for a single input row:

        1. apply tokenization to a sequence of utterances;
        2. process dictionary features to align with the last utterance;
        3. align word labels with the last utterance.

        Args:
            row_data (Dict[str, Any]): Dict of one row data with column names
                as keys. Keys includes "doc_label", "word_label", "text",
                "dict_feat", "word_weight" and "doc_weight".

        Returns:
            Dict[str, Any]: Preprocessed dict of one row data includes:

                "seq_word_feat" (list of list of string)
                    tokenized words of sequence of utterances
                "word_feat" (list of string)
                    tokenized words of last utterance
                "raw_word_label" (string)
                    raw word label
                "token_range" (list of tuple)
                    token ranges of word labels, each tuple contains the start
                    position index and the end position index
                "utterance"
                    the raw value of the "text" column, exactly as read (i.e.
                    before it is parsed into a list of utterances)
                "word_label" (list of string)
                    list of labels of words in last utterance; only present
                    when a word label field is configured
                "doc_label" (string)
                    doc label for intent classification
                "word_weight" (float)
                    weight of word label, 1.0 when not provided
                "doc_weight" (float)
                    weight of document label, 1.0 when not provided
                "dict_feat" (tuple)
                    tuple of three lists, the first is the label of each words,
                    the second is the weight of the feature, the third is the
                    length of the feature.

            The entries keyed by ``ModelInput.CHAR`` and
            ``ModelInput.PRETRAINED`` hold the featurizer's ``characters`` and
            ``pretrained_token_embedding`` output for the last utterance.
        """
        sequence = data_utils.parse_json_array(row_data[RawData.TEXT])

        # ignore dictionary feature for context sentences other than the last
        # one
        features_list = [
            self.featurizer.featurize(InputRecord(raw_text=utterance))
            for utterance in sequence[:-1]
        ]

        # adding dictionary feature for the last (current) message; the raw
        # row is keyed by RawData column names, like every other lookup here
        features_list.append(
            self.featurizer.featurize(
                InputRecord(
                    raw_text=sequence[-1],
                    raw_gazetteer_feats=row_data.get(RawData.DICT_FEAT, ""),
                )
            )
        )
        last_utterance = features_list[-1]

        res = {
            # features
            ModelInput.SEQ: [utterance.tokens for utterance in features_list],
            ModelInput.TEXT: last_utterance.tokens,
            ModelInput.DICT: (
                last_utterance.gazetteer_feats,
                last_utterance.gazetteer_feat_weights,
                last_utterance.gazetteer_feat_lengths,
            ),
            ModelInput.CHAR: last_utterance.characters,
            ModelInput.PRETRAINED: last_utterance.pretrained_token_embedding,
            # labels
            DocLabelConfig._name: row_data[RawData.DOC_LABEL],
            # extra data
            # TODO move the logic to FloatField
            ExtraField.DOC_WEIGHT: self._weight_or_default(
                row_data.get(RawData.DOC_WEIGHT)
            ),
            ExtraField.WORD_WEIGHT: self._weight_or_default(
                row_data.get(RawData.WORD_WEIGHT)
            ),
            ExtraField.RAW_WORD_LABEL: row_data[RawData.WORD_LABEL],
            ExtraField.UTTERANCE: row_data[RawData.TEXT],
            ExtraField.TOKEN_RANGE: last_utterance.token_ranges,
        }

        if WordLabelConfig._name in self.labels:
            # TODO move it into word label field
            res[WordLabelConfig._name] = data_utils.align_slot_labels(
                last_utterance.token_ranges,
                row_data[RawData.WORD_LABEL],
                self.labels[WordLabelConfig._name].use_bio_labels,
            )
        return res

    def _train_input_from_batch(self, batch):
        """Assemble the tuple of model inputs from a batch.

        The returned tuple is ordered as: first element of the text input,
        every other configured feature (in ``self.features`` order, excluding
        the text and sequence features), first element of the sequence input,
        second element of the text input, second element of the sequence
        input.
        """
        text_input = getattr(batch, ModelInput.TEXT)
        seq_input = getattr(batch, ModelInput.SEQ)
        return (
            # text_input[0] contains the word embeddings,
            # text_input[1] contains the lengths of each word
            text_input[0],
            *(
                getattr(batch, key)
                for key in self.features
                if key not in [ModelInput.TEXT, ModelInput.SEQ]
            ),
            seq_input[0],
            text_input[1],
            seq_input[1],
        )