Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fast loading for cross lingual tasks #572

Merged
merged 15 commits into from
Apr 30, 2024
3 changes: 3 additions & 0 deletions docs/mmteb/points/572.jsonl
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
{"GitHub": "loicmagne", "Bug fixes": 8}
{"GitHub": "KennethEnevoldsen", "Review PR": 2}
{"GitHub": "imenelydiaker", "Review PR": 2}
16 changes: 2 additions & 14 deletions mteb/abstasks/CrosslingualTask.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
from __future__ import annotations

import datasets

from .AbsTask import AbsTask
from .MultiSubsetLoader import MultiSubsetLoader


class CrosslingualTask(AbsTask):
class CrosslingualTask(MultiSubsetLoader, AbsTask):
def __init__(self, langs=None, **kwargs):
super().__init__(**kwargs)
if isinstance(langs, list):
Expand All @@ -15,14 +14,3 @@ def __init__(self, langs=None, **kwargs):
else:
self.langs = self.metadata_dict["eval_langs"]
self.is_crosslingual = True

def load_data(self, **kwargs):
"""Load dataset from HuggingFace hub"""
if self.data_loaded:
return
self.dataset = {}
for lang in self.langs:
self.dataset[lang] = datasets.load_dataset(
name=lang, **self.metadata_dict["dataset"]
)
self.data_loaded = True
46 changes: 46 additions & 0 deletions mteb/abstasks/MultiSubsetLoader.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
from __future__ import annotations

import datasets


class MultiSubsetLoader:
    """Mixin providing ``load_data`` for tasks whose HuggingFace dataset is
    organized as one subset (config) per language / language pair.

    Expects the consuming class to define ``self.langs``, ``self.metadata_dict``
    (with a ``"dataset"`` dict of ``load_dataset`` kwargs), ``self.metadata``,
    ``self.data_loaded`` and ``self.dataset_transform()`` — supplied by the
    task base classes that mix this in.
    """

    def load_data(self, **kwargs):
        """Load dataset containing multiple subsets from HuggingFace hub.

        No-op if data is already loaded. Dispatches to :meth:`fast_load` when
        the subclass opts in via a ``fast_loading = True`` class attribute,
        otherwise falls back to the per-subset :meth:`slow_load`.
        """
        if self.data_loaded:
            return

        # getattr with a default replaces the hasattr/attribute dance;
        # subclasses opt in by setting `fast_loading = True`.
        if getattr(self, "fast_loading", False):
            self.fast_load()
        else:
            self.slow_load()

        self.dataset_transform()
        self.data_loaded = True

    def fast_load(self, **kwargs):
        """Load all subsets at once, then group by language with Polars.

        Using fast loading has two requirements:
        - Each row in the dataset should have a 'lang' feature giving the
          corresponding language/language pair
        - The dataset must have a 'default' config that loads all the subsets
          of the dataset (see
          https://huggingface.co/docs/datasets/en/repository_structure#configurations)
        """
        self.dataset = {}
        merged_dataset = datasets.load_dataset(
            **self.metadata_dict["dataset"]
        )  # load "default" subset
        for split in self.metadata.eval_splits:
            # One pass over the merged split, bucketed by the 'lang' column.
            grouped_by_lang = dict(merged_dataset[split].to_polars().group_by("lang"))
            for lang in self.langs:
                if lang not in self.dataset:
                    self.dataset[lang] = {}
                # Drop the now-redundant lang column and convert back to an HF
                # Dataset — not strictly necessary but better for compatibility.
                self.dataset[lang][split] = datasets.Dataset.from_polars(
                    grouped_by_lang[lang].drop("lang")
                )

    def slow_load(self, **kwargs):
        """Load each subset iteratively (one ``load_dataset`` call per language)."""
        self.dataset = {}
        for lang in self.langs:
            self.dataset[lang] = datasets.load_dataset(
                name=lang,
                # Index directly rather than `.get("dataset", None)`: a missing
                # key now raises a clear KeyError instead of the obscure
                # TypeError that `**None` would produce, and this matches
                # fast_load's access pattern.
                **self.metadata_dict["dataset"],
            )
18 changes: 2 additions & 16 deletions mteb/abstasks/MultilingualTask.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
from __future__ import annotations

import datasets

from .AbsTask import AbsTask
from .MultiSubsetLoader import MultiSubsetLoader


class MultilingualTask(AbsTask):
class MultilingualTask(MultiSubsetLoader, AbsTask):
def __init__(self, langs=None, **kwargs):
super().__init__(**kwargs)
if isinstance(langs, list):
Expand All @@ -17,16 +16,3 @@ def __init__(self, langs=None, **kwargs):
else:
self.langs = self.metadata_dict["eval_langs"]
self.is_multilingual = True

def load_data(self, **kwargs):
"""Load dataset from HuggingFace hub"""
if self.data_loaded:
return
self.dataset = {}
for lang in self.langs:
self.dataset[lang] = datasets.load_dataset(
name=lang,
**self.metadata_dict.get("dataset", None),
)
self.dataset_transform()
self.data_loaded = True
3 changes: 2 additions & 1 deletion mteb/tasks/BitextMining/multilingual/TatoebaBitextMining.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,11 +121,12 @@


class TatoebaBitextMining(AbsTaskBitextMining, CrosslingualTask):
fast_loading = True
metadata = TaskMetadata(
name="Tatoeba",
dataset={
"path": "mteb/tatoeba-bitext-mining",
"revision": "9080400076fbadbb4c4dcb136ff4eddc40b42553",
"revision": "69e8f12da6e31d59addadda9a9c8a2e601a0e282",
},
description="1,000 English-aligned sentence pairs for each language based on the Tatoeba corpus",
reference="https://github.com/facebookresearch/LASER/tree/main/data/tatoeba/v1",
Expand Down
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ classifiers = [
]
requires-python = ">=3.8"
dependencies = [
"datasets>=2.2.0",
"datasets>=2.19.0",
"jsonlines",
"numpy",
"requests>=2.26.0",
Expand All @@ -39,6 +39,7 @@ dependencies = [
"pydantic>=2.0.0",
"typing_extensions",
"eval_type_backport",
"polars>=0.20.22",
]


Expand Down
Loading
Loading