Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove generic tokenizer and support multiple languages for the word cloud. #388

Merged
merged 7 commits into from
Jan 26, 2023
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 0 additions & 6 deletions azimuth/modules/base_classes/artifact_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
from typing import Callable, Dict, Optional

from datasets import DatasetDict
from transformers import AutoTokenizer

from azimuth.config import AzimuthConfig
from azimuth.dataset_split_manager import DatasetSplitManager
Expand Down Expand Up @@ -117,11 +116,6 @@ def get_model(self, config: AzimuthConfig, pipeline_idx: int):

return self.models_mapping[config_key][pipeline_idx]

def get_tokenizer(self):
gabegma marked this conversation as resolved.
Show resolved Hide resolved
if self.tokenizer is None:
self.tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
return self.tokenizer

def get_metric(self, config, name: str, **kwargs):
hash: Hash = md5_hash({"name": name, **kwargs})
if hash not in self.metrics:
Expand Down
2 changes: 1 addition & 1 deletion azimuth/modules/model_contracts/hf_text_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ def saliency(self, batch: Dataset) -> List[SaliencyResponse]:

"""
if self.saliency_layer is None:
return self.empty_saliency_from_batch(batch)
raise ValueError("This method should not be called when saliency_layer is not defined.")
gabegma marked this conversation as resolved.
Show resolved Hide resolved

pipeline = self.get_model()

Expand Down
26 changes: 0 additions & 26 deletions azimuth/modules/model_contracts/text_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,29 +305,3 @@ def get_postprocessed_output(
postprocessed_output = postprocessed_steps[-1].output
# Preprocessing steps are not supported at the moment for HF pipelines
return model_out_formatted, postprocessed_output, [], postprocessed_steps

def empty_saliency_from_batch(self, batch) -> List[SaliencyResponse]:
"""Return dummy output to not break API consumers.

Args:
batch: Utterances.

Returns:
Saliencies of 0.

"""
records: List[SaliencyResponse] = []
tokenizer = self.artifact_manager.get_tokenizer()
for utterance in batch[self.config.columns.text_input]:
token_ids = tokenizer(utterance)["input_ids"]
tokens = tokenizer.convert_ids_to_tokens(token_ids)
tokens = [
token for token in tokens if token not in [tokenizer.cls_token, tokenizer.sep_token]
]

json_output = SaliencyResponse(
saliency=[0.0] * len(tokens),
tokens=tokens,
)
records.append(json_output)
return records
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,4 @@

class TextClassificationNoSaliencyModule(TextClassificationModule):
def saliency(self, batch: Dataset) -> List[SaliencyResponse]:
return self.empty_saliency_from_batch(batch)
raise NotImplementedError
5 changes: 3 additions & 2 deletions azimuth/modules/utilities/validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from azimuth.types import ModuleOptions, SupportedMethod, SupportedModelContract
from azimuth.types.validation import ValidationResponse
from azimuth.utils.logs import MultipleExceptions
from azimuth.utils.project import predictions_available, saliency_available
from azimuth.utils.validation import assert_not_none


Expand Down Expand Up @@ -40,7 +41,7 @@ def compute_on_dataset_split(self) -> List[ValidationResponse]: # type: ignore

model = (
exception_gatherer.try_calling_function(self.get_model)
if self.config.pipelines is not None
if predictions_available(self.config)
else None
)
can_load_model = model is not None
Expand All @@ -61,7 +62,7 @@ def compute_on_dataset_split(self) -> List[ValidationResponse]: # type: ignore
exception_gatherer.try_calling_function(self._validate_prediction, batch=batch)
is not None
)
if can_make_prediction:
if can_make_prediction and saliency_available(self.config):
can_make_saliency = (
exception_gatherer.try_calling_function(self._validate_saliency, batch=batch)
is not None
Expand Down
88 changes: 48 additions & 40 deletions azimuth/modules/word_analysis/top_words.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
# Copyright ServiceNow, Inc. 2021 – 2022
# This source code is licensed under the Apache 2.0 license found in the LICENSE file
# in the root directory of this source tree.
import string
from collections import Counter
from typing import List, Tuple
from typing import Dict, List, Tuple

import numpy as np
import spacy

from azimuth.config import ModelContractConfig
from azimuth.config import AzimuthConfig
from azimuth.modules.base_classes import FilterableModule
from azimuth.modules.task_execution import get_task_result
from azimuth.modules.word_analysis.tokens_to_words import TokensToWordsModule
Expand All @@ -19,21 +19,25 @@
TopWordsResult,
)
from azimuth.utils.dataset_operations import get_predictions_from_ds
from azimuth.utils.ml.third_parties.stop_words import STOP_WORDS
from azimuth.utils.project import saliency_available
from azimuth.utils.utterance import clean_utterance

MIN_SALIENCY = 0.01


class TopWordsModule(FilterableModule[ModelContractConfig]):
"""Returns the most important words in terms of their saliency value or frequency."""
class TopWordsModule(FilterableModule[AzimuthConfig]):
gabegma marked this conversation as resolved.
Show resolved Hide resolved
"""Returns the most important words in terms of their saliency value or frequency.

Note: The config scope is AzimuthConfig because the module relies on both the pipeline or the
syntax config, depending if saliency is available. AzimuthConfig is a bit too broad, but it
should not be a problem since this module computes fast.
gabegma marked this conversation as resolved.
Show resolved Hide resolved
"""

allowed_mod_options = FilterableModule.allowed_mod_options | {
"top_x",
"th_importance",
"force_no_saliency",
}
stop_words_punctuation = list(string.punctuation) + STOP_WORDS

@staticmethod
def count_words(list_of_words: List[str], top_x: int) -> List[TopWordsResult]:
Expand Down Expand Up @@ -102,44 +106,48 @@ def compute_on_dataset_split(self) -> List[TopWordsResponse]: # type: ignore
)
]

# Saliencies will be 0 if saliency maps are not available.
words_saliencies = self.get_words_saliencies(self.get_indices())

important_words_all = []
important_words_right = []
important_words_errors = []
important_words_per_idx: Dict[int, List[str]] = {}
if importance_criteria == TopWordsImportanceCriteria.salient:
words_saliencies = self.get_words_saliencies(self.get_indices())
tokenizer = self.get_model().tokenizer
gabegma marked this conversation as resolved.
Show resolved Hide resolved
for idx, record in enumerate(words_saliencies):
# Put everything to lower case and remove cls/sep tokens.
words, saliencies = zip(
*[
(word.lower(), saliency_value)
for word, saliency_value in zip(record.words, record.saliency)
if word not in [tokenizer.cls_token, tokenizer.sep_token]
]
)
if words:
importance_saliency = max(
self.mod_options.th_importance * max(record.saliency), MIN_SALIENCY
)
important_words_per_idx[idx] = [
word
for word, _ in filter(
lambda s: s[1] > importance_saliency, zip(words, saliencies)
)
]
# If saliency is not available, we proxy important words as any word that is neither
# punctuation or a stop word.
gabegma marked this conversation as resolved.
Show resolved Hide resolved
else:
spacy_model = spacy.load(self.config.syntax.spacy_model)
utterances = ds[self.config.columns.text_input]
for idx, utterance in enumerate(utterances):
doc = spacy_model(clean_utterance(utterance))
important_words_per_idx[idx] = [
token.text for token in doc if not token.is_stop and not token.is_punct
]

tokenizer = self.artifact_manager.get_tokenizer()
is_error = np.array(
get_predictions_from_ds(ds, self.mod_options.without_postprocessing)
) != np.array(ds[self.config.columns.label])

for idx, record in enumerate(words_saliencies):
# Put everything to lower case and remove cls/sep tokens.
words, saliencies = zip(
*[
(word.lower(), saliency_value)
for word, saliency_value in zip(record.words, record.saliency)
if word not in [tokenizer.cls_token, tokenizer.sep_token]
]
)
if importance_criteria == TopWordsImportanceCriteria.salient and words != []:
th_importance = self.mod_options.th_importance
assert len(words) == len(saliencies)
importance_saliency = max(th_importance * max(record.saliency), MIN_SALIENCY)
important_words = [
word
for word, _ in filter(
lambda s: s[1] > importance_saliency, zip(words, saliencies)
)
]
# If saliency is not available, we proxy important words as any word that is neither
# punctuation or a stop word.
else:
important_words = [
word for word in words if word not in self.stop_words_punctuation
]

important_words_all = []
important_words_errors = []
important_words_right = []
for idx, important_words in important_words_per_idx.items():
important_words_all.extend(important_words)
if is_error[idx]:
important_words_errors.extend(important_words)
Expand Down
8 changes: 7 additions & 1 deletion azimuth/startup.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
perturbation_testing_available,
postprocessing_editable,
predictions_available,
saliency_available,
similarity_available,
)
from azimuth.utils.validation import assert_not_none
Expand Down Expand Up @@ -101,9 +102,12 @@ class Startup:
)
]

SALIENCY_TASKS = [
Startup("saliency", SupportedMethod.Saliency, run_on_all_pipelines=True),
]

BASE_PREDICTION_TASKS = [
Startup("prediction", SupportedMethod.Predictions, run_on_all_pipelines=True),
Startup("saliency", SupportedMethod.Saliency, run_on_all_pipelines=True),
Startup(
"outcome_count",
SupportedModule.Outcome,
Expand Down Expand Up @@ -236,6 +240,8 @@ def startup_tasks(
# TODO We only check pipeline_index=0, but we should check all pipelines.
if postprocessing_editable(task_manager.config, 0):
start_up_tasks += POSTPROCESSING_TASKS
if saliency_available(task_manager.config):
start_up_tasks += SALIENCY_TASKS
if similarity_available(task_manager.config):
start_up_tasks += SIMILARITY_TASKS

Expand Down
Loading