Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add bleu metric #1834

Merged
merged 22 commits into from
Mar 23, 2021
Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 0 additions & 4 deletions ignite/metrics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
from ignite.metrics.psnr import PSNR
from ignite.metrics.recall import Recall
from ignite.metrics.root_mean_squared_error import RootMeanSquaredError
from ignite.metrics.rouge import Rouge, RougeL, RougeN
from ignite.metrics.running_average import RunningAverage
from ignite.metrics.ssim import SSIM
from ignite.metrics.top_k_categorical_accuracy import TopKCategoricalAccuracy
Expand Down Expand Up @@ -43,9 +42,6 @@
"PSNR",
"Recall",
"RootMeanSquaredError",
"Rouge",
vfdev-5 marked this conversation as resolved.
Show resolved Hide resolved
"RougeN",
"RougeL",
"RunningAverage",
"VariableAccumulation",
"Frequency",
Expand Down
9 changes: 9 additions & 0 deletions ignite/metrics/nlp/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
from ignite.metrics.nlp.bleu import Bleu
from ignite.metrics.nlp.rouge import Rouge, RougeL, RougeN

__all__ = [
"Bleu",
"Rouge",
"RougeN",
"RougeL",
]
182 changes: 182 additions & 0 deletions ignite/metrics/nlp/bleu.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,182 @@
import math
from collections import Counter
from typing import Any, Callable, Optional, Sequence, Tuple, Union

import torch

from ignite.exceptions import NotComputableError
from ignite.metrics.metric import Metric, reinit__is_reduced, sync_all_reduce
from ignite.metrics.nlp.utils import modified_precision

__all__ = ["Bleu"]


def _closest_ref_length(references: Sequence[Sequence[Any]], hyp_len: int) -> int:
ref_lens = (len(reference) for reference in references)
closest_ref_len = min(ref_lens, key=lambda ref_len: (abs(ref_len - hyp_len), ref_len))
return closest_ref_len


class _Smoother:
"""
Smoothing helper
http://acl2014.org/acl2014/W14-33/pdf/W14-3346.pdf
"""

def __init__(self, method: str):
valid = ["no_smooth", "smooth1", "nltk_smooth2", "smooth2"]
if method not in valid:
raise ValueError(f"Smooth is not valid (expected: {valid}, got: {method})")
self.smooth = method

def __call__(self, numerators: Counter, denominators: Counter) -> Sequence[float]:
method = getattr(self, self.smooth)
return method(numerators, denominators)

@staticmethod
def smooth1(numerators: Counter, denominators: Counter) -> Sequence[float]:
epsilon = 0.1
denominators_ = [max(1, d) for d in denominators.values()]
return [n / d if n != 0 else epsilon / d for n, d in zip(numerators.values(), denominators_)]

@staticmethod
def nltk_smooth2(numerators: Counter, denominators: Counter) -> Sequence[float]:
denominators_ = [max(1, d) for d in denominators.values()]
return [(n + 1) / (d + 1) for n, d in zip(numerators.values(), denominators_)]

@staticmethod
def smooth2(numerators: Counter, denominators: Counter) -> Sequence[float]:
sdesrozis marked this conversation as resolved.
Show resolved Hide resolved
return [(n + 1) / (d + 1) for n, d in zip(numerators.values(), denominators.values())]

@staticmethod
def no_smooth(numerators: Counter, denominators: Counter) -> Sequence[float]:
denominators_ = [max(1, d) for d in denominators.values()]
return [n / d for n, d in zip(numerators.values(), denominators_)]


class Bleu(Metric):
r"""Calculates the `BLEU score <https://en.wikipedia.org/wiki/BLEU>`_.

.. math::
\text{BLEU} = \text{BP} \dot exp \left( \sum_{n=1}^{N} w_{n} log p_{n} \right)

where :math:`N` is the order of n-grams, :math:`\text{BP}` is a sentence brevety penalty, :math:`w_{n}` are
positive weights summing to one and :math:`p_{n}` are modified n-gram precisions.

More details can be found in `Papineni et al. 2002`__.

__ https://www.aclweb.org/anthology/P02-1040.pdf

In addition, a review of smoothing techniques can be found in `Chen et al. 2014`__

__ http://acl2014.org/acl2014/W14-33/pdf/W14-3346.pdf

Remark :

This implementation is inspired by nltk

Args:
ngram: order of n-grams.
smooth: enable smoothing. Valid are "no_smooth", "smooth1", "nltk_smooth2" or "smooth2". (Default: "no_smooth")
output_transform: a callable that is used to transform the
:class:`~ignite.engine.engine.Engine`'s ``process_function``'s output into the
form expected by the metric. This can be useful if, for example, you have a multi-output model and
you want to compute the metric with respect to one of the outputs.
By default, metrics require the output as ``(y_pred, y)`` or ``{'y_pred': y_pred, 'y': y}``.
device: specifies which device updates are accumulated on. Setting the
metric's device to be the same as your ``update`` arguments ensures the ``update`` method is
non-blocking. By default, CPU.

Example:

.. code-block:: python

from ignite.metrics import Bleu
m = Bleu(ngram=4, smooth="smooth1")
y_pred = "the the the the the the the"
y = ["the cat is on the mat", "there is a cat on the mat"]
m.update((y_pred.split(), [y.split()]))
print(m.compute())

.. versionadded:: 0.5.0
"""

def __init__(
self,
ngram: int = 4,
smooth: str = "no_smooth",
output_transform: Callable = lambda x: x,
device: Union[str, torch.device] = torch.device("cpu"),
):
if ngram <= 0:
raise ValueError(f"ngram order must be greater than zero (got: {ngram})")
self.ngrams_order = ngram
self.weights = [1 / self.ngrams_order] * self.ngrams_order
self.smoother = _Smoother(method=smooth)
super(Bleu, self).__init__(output_transform=output_transform, device=device)

def corpus_bleu(self, references: Sequence[Sequence[Any]], candidates: Sequence[Sequence[Any]],) -> float:
p_numerators: Counter = Counter() # Key = ngram order, and value = no. of ngram matches.
p_denominators: Counter = Counter() # Key = ngram order, and value = no. of ngram in ref.

assert len(references) == len(candidates), "The number of hypotheses and their reference(s) should be the same "
sdesrozis marked this conversation as resolved.
Show resolved Hide resolved

# Iterate through each hypothesis and their corresponding references.
for refs, hyp in zip(references, candidates):
# For each order of ngram, calculate the numerator and
# denominator for the corpus-level modified precision.
for i in range(1, self.ngrams_order + 1):
numerator, denominator = modified_precision(refs, hyp, i)
p_numerators[i] += numerator
p_denominators[i] += denominator

# Returns 0 if there's no matching n-grams
# We only need to check for p_numerators[1] == 0, since if there's
# no unigrams, there won't be any higher order ngrams.
if p_numerators[1] == 0:
return 0
sdesrozis marked this conversation as resolved.
Show resolved Hide resolved

# If no smoother, returns 0 if there's at least one a not matching n-grams
if self.smoother.smooth == "no_smooth" and min(p_numerators.values()) == 0:
return 0

# Calculate the hypothesis lengths
hyp_lengths = [len(hyp) for hyp in candidates]

# Calculate the closest reference lengths.
ref_lengths = [_closest_ref_length(refs, hyp_len) for refs, hyp_len in zip(references, hyp_lengths)]

# Sum of hypothesis and references lengths
hyp_len = sum(hyp_lengths)
ref_len = sum(ref_lengths)

# Calculate corpus-level brevity penalty.
if hyp_len < ref_len:
bp = math.exp(1 - ref_len / hyp_len) if hyp_len > 0 else 0.0
else:
bp = 1.0

# Smoothing
p_n = self.smoother(p_numerators, p_denominators)

# Compute the geometric mean
s = [w_i * math.log(p_i) for w_i, p_i in zip(self.weights, p_n)]
gm = bp * math.exp(math.fsum(s))
return gm

@reinit__is_reduced
def reset(self) -> None:
self._sum_of_bleu = torch.tensor(0.0, dtype=torch.double, device=self._device)
self._num_sentences = 0

@reinit__is_reduced
def update(self, output: Tuple[Sequence[Any], Sequence[Sequence[Any]]]) -> None:
y_pred, y = output
self._sum_of_bleu += self.corpus_bleu(references=[y], candidates=[y_pred])
self._num_sentences += 1

@sync_all_reduce("_sum_of_bleu", "_num_sentences")
def compute(self) -> torch.Tensor:
if self._num_sentences == 0:
raise NotComputableError("Bleu must have at least one example before it can be computed.")
return self._sum_of_bleu / self._num_sentences
52 changes: 4 additions & 48 deletions ignite/metrics/rouge.py → ignite/metrics/nlp/rouge.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from abc import ABCMeta, abstractmethod
from collections import Counter, namedtuple
from collections import namedtuple
from typing import Any, Callable, List, Mapping, Optional, Sequence, Tuple, Union

import torch
Expand All @@ -9,53 +9,9 @@

# These decorators helps with distributed settings
from ignite.metrics.metric import reinit__is_reduced, sync_all_reduce
from ignite.metrics.nlp.utils import lcs, ngrams


def ngrams(sequence: Sequence[Any], n: int) -> Counter:
"""
Generate the ngrams from a sequence of items

Args:
sequence: sequence of items
n: ngram order

Returns:
A counter of ngram objects

.. versionadded:: 0.5.0
"""
return Counter([tuple(sequence[i : i + n]) for i in range(len(sequence) - n + 1)])


def lcs(seq_a: Sequence[Any], seq_b: Sequence[Any]) -> int:
"""
Compute the length of the longest common subsequence in two sequence of items
https://en.wikipedia.org/wiki/Longest_common_subsequence_problem

Args:
seq_a: first sequence of items
seq_b: second sequence of items

Returns:
The length of the longest common subsequence

.. versionadded:: 0.5.0
"""
m = len(seq_a)
n = len(seq_b)

dp = [[0] * (n + 1) for _ in range(m + 1)]

for i in range(m + 1):
for j in range(n + 1):
if i == 0 or j == 0:
dp[i][j] = 0
elif seq_a[i - 1] == seq_b[j - 1]:
dp[i][j] = dp[i - 1][j - 1] + 1
else:
dp[i][j] = max(dp[i - 1][j], dp[i][j - 1])

return dp[m][n]
__all__ = ["Rouge", "RougeN", "RougeL"]


class Score(namedtuple("Score", ["match", "candidate", "reference"])):
Expand Down Expand Up @@ -289,7 +245,7 @@ def __init__(
super(RougeN, self).__init__(multiref=multiref, alpha=alpha, output_transform=output_transform, device=device)
self._ngram = ngram
if self._ngram < 1:
raise ValueError(f"ngram order must be greater than one (got : {self._ngram})")
raise ValueError(f"ngram order must be greater than zero (got : {self._ngram})")

def _compute_score(self, candidate: Sequence[Any], reference: Sequence[Any]) -> Score:
return compute_ngram_scores(candidate=candidate, reference=reference, n=self._ngram)
Expand Down
89 changes: 89 additions & 0 deletions ignite/metrics/nlp/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
from collections import Counter
from typing import Any, Sequence, Tuple

__all__ = ["ngrams", "lcs", "modified_precision"]


def ngrams(sequence: Sequence[Any], n: int) -> Counter:
"""
Generate the ngrams from a sequence of items

Args:
sequence: sequence of items
n: n-gram order

Returns:
A counter of ngram objects

.. versionadded:: 0.5.0
"""
return Counter([tuple(sequence[i : i + n]) for i in range(len(sequence) - n + 1)])


def lcs(seq_a: Sequence[Any], seq_b: Sequence[Any]) -> int:
"""
Compute the length of the longest common subsequence in two sequence of items
https://en.wikipedia.org/wiki/Longest_common_subsequence_problem

Args:
seq_a: first sequence of items
seq_b: second sequence of items

Returns:
The length of the longest common subsequence

.. versionadded:: 0.5.0
"""
m = len(seq_a)
n = len(seq_b)

dp = [[0] * (n + 1) for _ in range(m + 1)]

for i in range(m + 1):
for j in range(n + 1):
if i == 0 or j == 0:
dp[i][j] = 0
elif seq_a[i - 1] == seq_b[j - 1]:
dp[i][j] = dp[i - 1][j - 1] + 1
else:
dp[i][j] = max(dp[i - 1][j], dp[i][j - 1])

return dp[m][n]


def modified_precision(references: Sequence[Sequence[Any]], candidate: Any, n: int) -> Tuple[int, int]:
"""
Compute the modified precision

.. math::
p_{n} = \frac{m_{n}}{l_{n}}

where m_{n} is the number of matched n-grams between translation T and its reference R, and l_{n} is the
total number of n-grams in the translation T.

More details can be found in `Papineni et al. 2002`__.

__ https://www.aclweb.org/anthology/P02-1040.pdf

Args:
references: list of references R
candidate: translation T
n: n-gram order

Returns:
The length of the longest common subsequence

.. versionadded:: 0.5.0
"""
# ngrams of the candidate
counts = ngrams(candidate, n)

# union of ngrams of references
max_counts: Counter = Counter()
for reference in references:
max_counts |= ngrams(reference, n)

# clipped count of the candidate and references
clipped_counts = counts & max_counts

return sum(clipped_counts.values()), sum(counts.values())
Loading