-
Notifications
You must be signed in to change notification settings - Fork 309
/
Copy pathtest_all_abstasks.py
72 lines (57 loc) · 2.35 KB
/
test_all_abstasks.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
from __future__ import annotations
import asyncio
import logging
from unittest.mock import Mock, patch
import aiohttp
import pytest
from mteb import MTEB
from mteb.abstasks import AbsTask
from mteb.abstasks.AbsTaskInstructionRetrieval import AbsTaskInstructionRetrieval
from mteb.abstasks.AbsTaskRetrieval import AbsTaskRetrieval
# Configure the root logger at INFO level as soon as this test module is imported.
# NOTE(review): logging.basicConfig is a no-op when the root logger already has
# handlers (e.g. if a test runner installed its own) — confirm this still takes effect.
logging.basicConfig(level=logging.INFO)
@pytest.mark.parametrize("task", MTEB().tasks_cls)
@patch("datasets.load_dataset")
def test_load_data(mock_load_dataset: Mock, task: AbsTask):
    """Check that ``task.load_data()`` goes through ``datasets.load_dataset``.

    ``datasets.load_dataset`` is replaced by ``mock_load_dataset`` (injected by
    ``@patch``), presumably so no real download happens as long as tasks call it
    through the ``datasets`` module. ``task.dataset_transform`` is patched on the
    instance so its call count can be inspected.

    Args:
        mock_load_dataset: Mock standing in for ``datasets.load_dataset``.
        task: One task instance from ``MTEB().tasks_cls``.
    """
    # TODO: Retrieval-style tasks implement load_data completely differently,
    # so the assertions below do not apply to them.
    if isinstance(task, (AbsTaskRetrieval, AbsTaskInstructionRetrieval)):
        pytest.skip("Retrieval tasks use a different load_data implementation")

    with patch.object(task, "dataset_transform") as mock_dataset_transform:
        task.load_data()
        mock_load_dataset.assert_called()

        # Cross-lingual and multilingual tasks do not call dataset_transform
        # (yet). Should they, so they can be extended more easily?
        if not task.is_crosslingual and not task.is_multilingual:
            mock_dataset_transform.assert_called_once()
async def check_dataset_on_hf(
    session: aiohttp.ClientSession, dataset: str, revision: str
) -> bool:
    """Report whether a Hugging Face dataset tree exists at a given revision.

    Sends a HEAD request for
    ``https://huggingface.co/datasets/<dataset>/tree/<revision>`` through
    ``session``. Only an HTTP 200 response counts as available; any other status
    yields False. Errors raised by ``session`` are not caught and propagate.

    Args:
        session: Open client session used to send the request.
        dataset: Dataset repository id, e.g. ``"org/name"``.
        revision: Commit hash or ref to look up.

    Returns:
        True if the response status is 200, otherwise False.
    """
    tree_url = "https://huggingface.co/datasets/{}/tree/{}".format(dataset, revision)
    # NOTE(review): aiohttp's ``head`` does not follow redirects by default, so a
    # moved/renamed repo (3xx) is reported as unavailable — confirm that is intended.
    async with session.head(tree_url) as resp:
        status_code = resp.status
    return status_code == 200
async def check_datasets_are_available_on_hf(tasks):
    """Fail if any task's Hugging Face dataset is missing at its pinned revision.

    Each distinct ``(path, revision)`` pair read from ``task.metadata.dataset``
    is checked with ``check_dataset_on_hf``. All checks share one
    ``aiohttp.ClientSession`` and are awaited concurrently with ``asyncio.gather``.

    Args:
        tasks: Iterable of task objects whose ``metadata.dataset`` mapping has
            ``"path"`` and ``"revision"`` keys. It is consumed exactly once, so
            generators are supported.

    Raises:
        AssertionError: Lists every ``<path> - revision <revision>`` pair whose
            check did not return True.
    """
    # Iterate ``tasks`` a single time (the original zipped it a second time, which
    # silently checked nothing for generators). dict.fromkeys also drops
    # duplicate (path, revision) pairs while preserving first-seen order.
    dataset_refs = list(
        dict.fromkeys(
            (task.metadata.dataset["path"], task.metadata.dataset["revision"])
            for task in tasks
        )
    )

    async with aiohttp.ClientSession() as session:
        datasets_exist = await asyncio.gather(
            *(
                check_dataset_on_hf(session, path, revision)
                for path, revision in dataset_refs
            )
        )

    does_not_exist = [
        ref for ref, exists in zip(dataset_refs, datasets_exist) if not exists
    ]
    if does_not_exist:
        pretty_print = "\n".join(
            f"{path} - revision {revision}" for path, revision in does_not_exist
        )
        # Raise explicitly rather than `assert False`, which `python -O` strips.
        raise AssertionError(
            f"Datasets not available on Hugging Face:\n{pretty_print}"
        )
def test_dataset_availability():
    """Verify every registered task's dataset exists on Hugging Face at its revision.

    Builds the task list from ``MTEB().tasks_cls`` and runs the async
    availability check to completion with ``asyncio.run``. This issues real
    HTTP requests, so it needs network access.
    """
    asyncio.run(check_datasets_are_available_on_hf(MTEB().tasks_cls))