Skip to content

Commit

Permalink
Add bleu metric (#1834)
Browse files Browse the repository at this point in the history
* add bleu metric - refactor rouge and add nlp module

* Remove blank line

* autopep8 fix

* Remove metrics

* Add nlp metrics in nlp.__init__

* autopep8 fix

* update

* expose nlp - replace assert by exception

* fix F401

* fix doc

* autopep8 fix

* add test

* Resolve conflict

Co-authored-by: Desroziers <[email protected]>
Co-authored-by: sdesrozis <[email protected]>
Co-authored-by: vfdev <[email protected]>
  • Loading branch information
4 people authored Mar 23, 2021
1 parent 5d4fa3e commit f584e8a
Show file tree
Hide file tree
Showing 10 changed files with 674 additions and 127 deletions.
7 changes: 4 additions & 3 deletions docs/source/metrics.rst
Original file line number Diff line number Diff line change
Expand Up @@ -325,13 +325,14 @@ Complete list of metrics
precision.Precision
PSNR
recall.Recall
Rouge
rouge.RougeL
rouge.RougeN
RootMeanSquaredError
RunningAverage
SSIM
TopKCategoricalAccuracy
Bleu
Rouge
RougeL
RougeN

Helpers for customizing metrics
-------------------------------
Expand Down
10 changes: 6 additions & 4 deletions ignite/metrics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,12 @@
from ignite.metrics.metric import BatchFiltered, BatchWise, EpochWise, Metric, MetricUsage
from ignite.metrics.metrics_lambda import MetricsLambda
from ignite.metrics.multilabel_confusion_matrix import MultiLabelConfusionMatrix
from ignite.metrics.nlp.bleu import Bleu
from ignite.metrics.nlp.rouge import Rouge, RougeL, RougeN
from ignite.metrics.precision import Precision
from ignite.metrics.psnr import PSNR
from ignite.metrics.recall import Recall
from ignite.metrics.root_mean_squared_error import RootMeanSquaredError
from ignite.metrics.rouge import Rouge, RougeL, RougeN
from ignite.metrics.running_average import RunningAverage
from ignite.metrics.ssim import SSIM
from ignite.metrics.top_k_categorical_accuracy import TopKCategoricalAccuracy
Expand Down Expand Up @@ -43,11 +44,12 @@
"PSNR",
"Recall",
"RootMeanSquaredError",
"Rouge",
"RougeN",
"RougeL",
"RunningAverage",
"VariableAccumulation",
"Frequency",
"SSIM",
"Bleu",
"Rouge",
"RougeN",
"RougeL",
]
9 changes: 9 additions & 0 deletions ignite/metrics/nlp/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
from ignite.metrics.nlp.bleu import Bleu
from ignite.metrics.nlp.rouge import Rouge, RougeL, RougeN

__all__ = [
"Bleu",
"Rouge",
"RougeN",
"RougeL",
]
191 changes: 191 additions & 0 deletions ignite/metrics/nlp/bleu.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,191 @@
import math
from collections import Counter
from typing import Any, Callable, Sequence, Tuple, Union

import torch

from ignite.exceptions import NotComputableError
from ignite.metrics.metric import Metric, reinit__is_reduced, sync_all_reduce
from ignite.metrics.nlp.utils import modified_precision

__all__ = ["Bleu"]


def _closest_ref_length(references: Sequence[Sequence[Any]], hyp_len: int) -> int:
    """Return the length of the reference closest in length to the hypothesis.

    Args:
        references: the reference sequences for one hypothesis.
        hyp_len: length of the hypothesis.

    Returns:
        The reference length with the smallest absolute difference to ``hyp_len``;
        when several references are equally close, the shortest one wins.
    """

    def _distance_then_length(length: int) -> Tuple[int, int]:
        # Primary criterion: distance to the hypothesis length.
        # Secondary criterion (tie-break): the shorter reference.
        return abs(length - hyp_len), length

    return min(map(len, references), key=_distance_then_length)


class _Smoother:
    """
    Smoothing helper for modified n-gram precisions.

    Each smoothing method receives the per-order numerators and denominators
    and returns one precision value per n-gram order.

    http://acl2014.org/acl2014/W14-33/pdf/W14-3346.pdf
    """

    def __init__(self, method: str):
        valid = ["no_smooth", "smooth1", "nltk_smooth2", "smooth2"]
        if method not in valid:
            raise ValueError(f"Smooth is not valid (expected: {valid}, got: {method})")
        # Name of the static method that ``__call__`` dispatches to.
        self.smooth = method

    def __call__(self, numerators: Counter, denominators: Counter) -> Sequence[float]:
        # Resolve the configured smoothing function by name and apply it.
        return getattr(self, self.smooth)(numerators, denominators)

    @staticmethod
    def smooth1(numerators: Counter, denominators: Counter) -> Sequence[float]:
        """Replace a zero numerator by a small epsilon; denominators are floored at 1."""
        epsilon = 0.1
        precisions = []
        for matched, total in zip(numerators.values(), denominators.values()):
            total = max(1, total)
            precisions.append(matched / total if matched != 0 else epsilon / total)
        return precisions

    @staticmethod
    def nltk_smooth2(numerators: Counter, denominators: Counter) -> Sequence[float]:
        """Add one to numerator and denominator, after flooring the denominator at 1."""
        precisions = []
        for matched, total in zip(numerators.values(), denominators.values()):
            precisions.append((matched + 1) / (max(1, total) + 1))
        return precisions

    @staticmethod
    def smooth2(numerators: Counter, denominators: Counter) -> Sequence[float]:
        """Add one to numerator and denominator (no floor on the denominator)."""
        precisions = []
        for matched, total in zip(numerators.values(), denominators.values()):
            precisions.append((matched + 1) / (total + 1))
        return precisions

    @staticmethod
    def no_smooth(numerators: Counter, denominators: Counter) -> Sequence[float]:
        """Plain ratio; denominators are floored at 1 to avoid division by zero."""
        precisions = []
        for matched, total in zip(numerators.values(), denominators.values()):
            precisions.append(matched / max(1, total))
        return precisions


class Bleu(Metric):
    r"""Calculates the `BLEU score <https://en.wikipedia.org/wiki/BLEU>`_.

    .. math::
       \text{BLEU} = b_{p} \cdot \exp \left( \sum_{n=1}^{N} w_{n} \: \log p_{n} \right)

    where :math:`N` is the order of n-grams, :math:`b_{p}` is a sentence brevity penalty, :math:`w_{n}` are
    positive weights summing to one and :math:`p_{n}` are modified n-gram precisions.

    More details can be found in `Papineni et al. 2002`__.

    __ https://www.aclweb.org/anthology/P02-1040.pdf

    In addition, a review of smoothing techniques can be found in `Chen et al. 2014`__

    __ http://acl2014.org/acl2014/W14-33/pdf/W14-3346.pdf

    Remark :

        This implementation is inspired by nltk

    Args:
        ngram: order of n-grams.
        smooth: enable smoothing. Valid are ``no_smooth``, ``smooth1``, ``nltk_smooth2`` or ``smooth2``.
            Default: ``no_smooth``.
        output_transform: a callable that is used to transform the
            :class:`~ignite.engine.engine.Engine`'s ``process_function``'s output into the
            form expected by the metric. This can be useful if, for example, you have a multi-output model and
            you want to compute the metric with respect to one of the outputs.
            By default, metrics require the output as ``(y_pred, y)`` or ``{'y_pred': y_pred, 'y': y}``.
        device: specifies which device updates are accumulated on. Setting the
            metric's device to be the same as your ``update`` arguments ensures the ``update`` method is
            non-blocking. By default, CPU.

    Example:

    .. code-block:: python

        from ignite.metrics.nlp import Bleu

        m = Bleu(ngram=4, smooth="smooth1")

        y_pred = "the the the the the the the"
        y = ["the cat is on the mat", "there is a cat on the mat"]

        # ``y`` is a list of reference sentences: tokenize each one separately.
        m.update((y_pred.split(), [ref.split() for ref in y]))

        print(m.compute())

    .. versionadded:: 0.5.0
    """

    def __init__(
        self,
        ngram: int = 4,
        smooth: str = "no_smooth",
        output_transform: Callable = lambda x: x,
        device: Union[str, torch.device] = torch.device("cpu"),
    ):
        if ngram <= 0:
            raise ValueError(f"ngram order must be greater than zero (got: {ngram})")
        self.ngrams_order = ngram
        # Uniform weights over the n-gram orders (they sum to one).
        self.weights = [1 / self.ngrams_order] * self.ngrams_order
        # Raises ValueError if ``smooth`` is not a supported method name.
        self.smoother = _Smoother(method=smooth)
        super(Bleu, self).__init__(output_transform=output_transform, device=device)

    def _corpus_bleu(
        self, references: Sequence[Sequence[Sequence[Any]]], candidates: Sequence[Sequence[Any]],
    ) -> float:
        """Compute the BLEU score of ``candidates`` against ``references``.

        Args:
            references: one list of tokenized reference sequences per candidate.
            candidates: tokenized hypothesis sequences.

        Returns:
            The BLEU score as a float; ``0.0`` when no unigram matches or, without
            smoothing, when any n-gram order has no match.

        Raises:
            ValueError: if the number of candidates differs from the number of reference lists.
        """
        p_numerators: Counter = Counter()
        p_denominators: Counter = Counter()

        if len(references) != len(candidates):
            raise ValueError(
                f"nb of candidates should be equal to nb of reference lists ({len(candidates)} != "
                f"{len(references)})"
            )

        # Iterate through each hypothesis and their corresponding references.
        for refs, hyp in zip(references, candidates):
            # For each order of ngram, calculate the numerator and
            # denominator for the corpus-level modified precision.
            for i in range(1, self.ngrams_order + 1):
                numerator, denominator = modified_precision(refs, hyp, i)
                p_numerators[i] += numerator
                p_denominators[i] += denominator

        # Returns 0 if there are no matching n-grams.
        # We only need to check for p_numerators[1] == 0, since if there are
        # no matching unigrams, there won't be any higher order n-gram matches.
        if p_numerators[1] == 0:
            return 0.0

        # Without smoothing, returns 0 if at least one n-gram order has no match
        # (its precision would be zero and log(0) is undefined).
        if self.smoother.smooth == "no_smooth" and min(p_numerators.values()) == 0:
            return 0.0

        # Calculate the hypothesis lengths
        hyp_lengths = [len(hyp) for hyp in candidates]

        # Calculate the closest reference lengths.
        ref_lengths = [_closest_ref_length(refs, hyp_len) for refs, hyp_len in zip(references, hyp_lengths)]

        # Sum of hypothesis and references lengths
        hyp_len = sum(hyp_lengths)
        ref_len = sum(ref_lengths)

        # Calculate corpus-level brevity penalty.
        if hyp_len < ref_len:
            bp = math.exp(1 - ref_len / hyp_len) if hyp_len > 0 else 0.0
        else:
            bp = 1.0

        # Smoothing
        p_n = self.smoother(p_numerators, p_denominators)

        # Compute the weighted geometric mean of the precisions, scaled by the brevity penalty.
        s = [w_i * math.log(p_i) for w_i, p_i in zip(self.weights, p_n)]
        gm = bp * math.exp(math.fsum(s))
        return gm

    @reinit__is_reduced
    def reset(self) -> None:
        """Reset the accumulated BLEU sum and the sentence counter."""
        self._sum_of_bleu = torch.tensor(0.0, dtype=torch.double, device=self._device)
        self._num_sentences = 0

    @reinit__is_reduced
    def update(self, output: Tuple[Sequence[Any], Sequence[Sequence[Any]]]) -> None:
        """Accumulate the BLEU score of one hypothesis.

        Args:
            output: a pair ``(y_pred, y)`` where ``y_pred`` is a tokenized hypothesis
                and ``y`` is a sequence of tokenized references for it.
        """
        y_pred, y = output
        # Each update scores a single sentence: wrap it as a one-element corpus.
        self._sum_of_bleu += self._corpus_bleu(references=[y], candidates=[y_pred])
        self._num_sentences += 1

    @sync_all_reduce("_sum_of_bleu", "_num_sentences")
    def compute(self) -> torch.Tensor:
        """Return the mean of the per-update BLEU scores.

        Raises:
            NotComputableError: if ``update`` has not been called since the last reset.
        """
        if self._num_sentences == 0:
            raise NotComputableError("Bleu must have at least one example before it can be computed.")
        return self._sum_of_bleu / self._num_sentences
52 changes: 4 additions & 48 deletions ignite/metrics/rouge.py → ignite/metrics/nlp/rouge.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from abc import ABCMeta, abstractmethod
from collections import Counter, namedtuple
from collections import namedtuple
from typing import Any, Callable, List, Mapping, Optional, Sequence, Tuple, Union

import torch
Expand All @@ -9,53 +9,9 @@

# These decorators helps with distributed settings
from ignite.metrics.metric import reinit__is_reduced, sync_all_reduce
from ignite.metrics.nlp.utils import lcs, ngrams


def ngrams(sequence: Sequence[Any], n: int) -> Counter:
    """
    Count the n-grams occurring in a sequence of items

    Args:
        sequence: sequence of items
        n: ngram order

    Returns:
        A counter mapping each n-gram (as a tuple of items) to its number of occurrences

    .. versionadded:: 0.5.0
    """
    counts: Counter = Counter()
    # One window per valid start position; no window at all when n > len(sequence).
    for start in range(len(sequence) - n + 1):
        window = tuple(sequence[start : start + n])
        counts[window] += 1
    return counts


def lcs(seq_a: Sequence[Any], seq_b: Sequence[Any]) -> int:
    """
    Return the length of the longest common subsequence of two sequences of items

    https://en.wikipedia.org/wiki/Longest_common_subsequence_problem

    Args:
        seq_a: first sequence of items
        seq_b: second sequence of items

    Returns:
        The length of the longest common subsequence

    .. versionadded:: 0.5.0
    """
    rows, cols = len(seq_a), len(seq_b)

    # table[i][j] holds the LCS length of seq_a[:i] and seq_b[:j];
    # row 0 and column 0 (empty prefixes) stay at zero.
    table = [[0] * (cols + 1) for _ in range(rows + 1)]

    for i, item_a in enumerate(seq_a, start=1):
        for j, item_b in enumerate(seq_b, start=1):
            if item_a == item_b:
                table[i][j] = table[i - 1][j - 1] + 1
            else:
                table[i][j] = max(table[i - 1][j], table[i][j - 1])

    return table[rows][cols]
__all__ = ["Rouge", "RougeN", "RougeL"]


class Score(namedtuple("Score", ["match", "candidate", "reference"])):
Expand Down Expand Up @@ -286,7 +242,7 @@ def __init__(
super(RougeN, self).__init__(multiref=multiref, alpha=alpha, output_transform=output_transform, device=device)
self._ngram = ngram
if self._ngram < 1:
raise ValueError(f"ngram order must be greater than one (got : {self._ngram})")
raise ValueError(f"ngram order must be greater than zero (got : {self._ngram})")

def _compute_score(self, candidate: Sequence[Any], reference: Sequence[Any]) -> Score:
    """Score one candidate/reference pair by delegating to ``compute_ngram_scores``
    with this metric's n-gram order (``self._ngram``)."""
    order = self._ngram
    return compute_ngram_scores(candidate=candidate, reference=reference, n=order)
Expand Down
Loading

0 comments on commit f584e8a

Please sign in to comment.