Skip to content

Commit

Permalink
Remove generic tokenizer and support multiple languages for the word …
Browse files Browse the repository at this point in the history
…cloud. (#388)

* Refactor top words

* Update documentation

* Remove tokenizer and dummy saliency

* Adapt based on comments

* Move reg ex

* Create TopWords config scope

* Apply suggestions from code review

Co-authored-by: Lindsay Brin <[email protected]>

Co-authored-by: Lindsay Brin <[email protected]>
  • Loading branch information
gabegma and lindsaydbrin authored Jan 26, 2023
1 parent e8849f3 commit 1e0c33e
Show file tree
Hide file tree
Showing 19 changed files with 111 additions and 351 deletions.
6 changes: 5 additions & 1 deletion azimuth/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -386,13 +386,17 @@ class SyntaxConfig(CommonFieldsConfig):
syntax: SyntaxOptions = SyntaxOptions()


class TopWordsConfig(SyntaxConfig, ModelContractConfig):
pass


class AzimuthConfig(
MetricConfig,
PerturbationTestingConfig,
SimilarityConfig,
DatasetWarningConfig,
SyntaxConfig,
LanguageConfig,
TopWordsConfig,
extra=Extra.forbid,
):
# Reminder: If a module depends on an attribute in AzimuthConfig, the module will be forced to
Expand Down
6 changes: 0 additions & 6 deletions azimuth/modules/base_classes/artifact_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
from typing import Callable, Dict, Optional

from datasets import DatasetDict
from transformers import AutoTokenizer

from azimuth.config import AzimuthConfig
from azimuth.dataset_split_manager import DatasetSplitManager
Expand Down Expand Up @@ -117,11 +116,6 @@ def get_model(self, config: AzimuthConfig, pipeline_idx: int):

return self.models_mapping[config_key][pipeline_idx]

def get_tokenizer(self):
if self.tokenizer is None:
self.tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
return self.tokenizer

def get_metric(self, config, name: str, **kwargs):
hash: Hash = md5_hash({"name": name, **kwargs})
if hash not in self.metrics:
Expand Down
2 changes: 1 addition & 1 deletion azimuth/modules/model_contracts/hf_text_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ def saliency(self, batch: Dataset) -> List[SaliencyResponse]:
"""
if self.saliency_layer is None:
return self.empty_saliency_from_batch(batch)
raise ValueError("This method should not be called when saliency_layer is not defined.")

pipeline = self.get_model()

Expand Down
26 changes: 0 additions & 26 deletions azimuth/modules/model_contracts/text_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,29 +305,3 @@ def get_postprocessed_output(
postprocessed_output = postprocessed_steps[-1].output
# Preprocessing steps are not supported at the moment for HF pipelines
return model_out_formatted, postprocessed_output, [], postprocessed_steps

def empty_saliency_from_batch(self, batch) -> List[SaliencyResponse]:
"""Return dummy output to not break API consumers.
Args:
batch: Utterances.
Returns:
Saliencies of 0.
"""
records: List[SaliencyResponse] = []
tokenizer = self.artifact_manager.get_tokenizer()
for utterance in batch[self.config.columns.text_input]:
token_ids = tokenizer(utterance)["input_ids"]
tokens = tokenizer.convert_ids_to_tokens(token_ids)
tokens = [
token for token in tokens if token not in [tokenizer.cls_token, tokenizer.sep_token]
]

json_output = SaliencyResponse(
saliency=[0.0] * len(tokens),
tokens=tokens,
)
records.append(json_output)
return records
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,4 @@

class TextClassificationNoSaliencyModule(TextClassificationModule):
def saliency(self, batch: Dataset) -> List[SaliencyResponse]:
return self.empty_saliency_from_batch(batch)
raise NotImplementedError
5 changes: 3 additions & 2 deletions azimuth/modules/utilities/validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from azimuth.types import ModuleOptions, SupportedMethod, SupportedModelContract
from azimuth.types.validation import ValidationResponse
from azimuth.utils.logs import MultipleExceptions
from azimuth.utils.project import predictions_available, saliency_available
from azimuth.utils.validation import assert_not_none


Expand Down Expand Up @@ -40,7 +41,7 @@ def compute_on_dataset_split(self) -> List[ValidationResponse]: # type: ignore

model = (
exception_gatherer.try_calling_function(self.get_model)
if self.config.pipelines is not None
if predictions_available(self.config)
else None
)
can_load_model = model is not None
Expand All @@ -61,7 +62,7 @@ def compute_on_dataset_split(self) -> List[ValidationResponse]: # type: ignore
exception_gatherer.try_calling_function(self._validate_prediction, batch=batch)
is not None
)
if can_make_prediction:
if can_make_prediction and saliency_available(self.config):
can_make_saliency = (
exception_gatherer.try_calling_function(self._validate_saliency, batch=batch)
is not None
Expand Down
81 changes: 42 additions & 39 deletions azimuth/modules/word_analysis/top_words.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
# Copyright ServiceNow, Inc. 2021 – 2022
# This source code is licensed under the Apache 2.0 license found in the LICENSE file
# in the root directory of this source tree.
import string
from collections import Counter
from typing import List, Tuple
from typing import Dict, List, Tuple

import numpy as np
import spacy

from azimuth.config import ModelContractConfig
from azimuth.config import TopWordsConfig
from azimuth.modules.base_classes import FilterableModule
from azimuth.modules.task_execution import get_task_result
from azimuth.modules.word_analysis.tokens_to_words import TokensToWordsModule
Expand All @@ -19,21 +19,20 @@
TopWordsResult,
)
from azimuth.utils.dataset_operations import get_predictions_from_ds
from azimuth.utils.ml.third_parties.stop_words import STOP_WORDS
from azimuth.utils.project import saliency_available
from azimuth.utils.utterance import clean_utterance

MIN_SALIENCY = 0.01


class TopWordsModule(FilterableModule[ModelContractConfig]):
class TopWordsModule(FilterableModule[TopWordsConfig]):
"""Returns the most important words in terms of their saliency value or frequency."""

allowed_mod_options = FilterableModule.allowed_mod_options | {
"top_x",
"th_importance",
"force_no_saliency",
}
stop_words_punctuation = list(string.punctuation) + STOP_WORDS

@staticmethod
def count_words(list_of_words: List[str], top_x: int) -> List[TopWordsResult]:
Expand Down Expand Up @@ -102,44 +101,48 @@ def compute_on_dataset_split(self) -> List[TopWordsResponse]: # type: ignore
)
]

# Saliencies will be 0 if saliency maps are not available.
words_saliencies = self.get_words_saliencies(self.get_indices())

important_words_all = []
important_words_right = []
important_words_errors = []
important_words_per_idx: Dict[int, List[str]] = {}
if importance_criteria == TopWordsImportanceCriteria.salient:
words_saliencies = self.get_words_saliencies(self.get_indices())
tokenizer = self.get_model().tokenizer
for idx, record in enumerate(words_saliencies):
# Put everything to lower case and remove cls/sep tokens.
words, saliencies = zip(
*[
(word.lower(), saliency_value)
for word, saliency_value in zip(record.words, record.saliency)
if word not in [tokenizer.cls_token, tokenizer.sep_token]
]
)
if words:
importance_saliency = max(
self.mod_options.th_importance * max(record.saliency), MIN_SALIENCY
)
important_words_per_idx[idx] = [
word
for word, _ in filter(
lambda s: s[1] > importance_saliency, zip(words, saliencies)
)
]
# If saliency is not available, we proxy important words as any word that is neither
# punctuation nor a stop word.
else:
spacy_model = spacy.load(self.config.syntax.spacy_model)
utterances = ds[self.config.columns.text_input]
for idx, utterance in enumerate(utterances):
doc = spacy_model(clean_utterance(utterance))
important_words_per_idx[idx] = [
token.text for token in doc if not token.is_stop and not token.is_punct
]

tokenizer = self.artifact_manager.get_tokenizer()
is_error = np.array(
get_predictions_from_ds(ds, self.mod_options.without_postprocessing)
) != np.array(ds[self.config.columns.label])

for idx, record in enumerate(words_saliencies):
# Put everything to lower case and remove cls/sep tokens.
words, saliencies = zip(
*[
(word.lower(), saliency_value)
for word, saliency_value in zip(record.words, record.saliency)
if word not in [tokenizer.cls_token, tokenizer.sep_token]
]
)
if importance_criteria == TopWordsImportanceCriteria.salient and words != []:
th_importance = self.mod_options.th_importance
assert len(words) == len(saliencies)
importance_saliency = max(th_importance * max(record.saliency), MIN_SALIENCY)
important_words = [
word
for word, _ in filter(
lambda s: s[1] > importance_saliency, zip(words, saliencies)
)
]
# If saliency is not available, we proxy important words as any word that is neither
# punctuation or a stop word.
else:
important_words = [
word for word in words if word not in self.stop_words_punctuation
]

important_words_all = []
important_words_errors = []
important_words_right = []
for idx, important_words in important_words_per_idx.items():
important_words_all.extend(important_words)
if is_error[idx]:
important_words_errors.extend(important_words)
Expand Down
8 changes: 7 additions & 1 deletion azimuth/startup.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
perturbation_testing_available,
postprocessing_editable,
predictions_available,
saliency_available,
similarity_available,
)
from azimuth.utils.validation import assert_not_none
Expand Down Expand Up @@ -101,9 +102,12 @@ class Startup:
)
]

SALIENCY_TASKS = [
Startup("saliency", SupportedMethod.Saliency, run_on_all_pipelines=True),
]

BASE_PREDICTION_TASKS = [
Startup("prediction", SupportedMethod.Predictions, run_on_all_pipelines=True),
Startup("saliency", SupportedMethod.Saliency, run_on_all_pipelines=True),
Startup(
"outcome_count",
SupportedModule.Outcome,
Expand Down Expand Up @@ -236,6 +240,8 @@ def startup_tasks(
# TODO We only check pipeline_index=0, but we should check all pipelines.
if postprocessing_editable(task_manager.config, 0):
start_up_tasks += POSTPROCESSING_TASKS
if saliency_available(task_manager.config):
start_up_tasks += SALIENCY_TASKS
if similarity_available(task_manager.config):
start_up_tasks += SIMILARITY_TASKS

Expand Down
Loading

0 comments on commit 1e0c33e

Please sign in to comment.