Dataset API cleanup (#1106)
Pluralizes dataset_name to dataset_names throughout the dataset API, adds validation of the requested dataset names, and allows passing a single string as well as a list.
joeylegere authored Mar 2, 2023
1 parent fcf5b81 commit ff4af76
Showing 6 changed files with 44 additions and 25 deletions.
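The practical effect of the rename is that callers now pass `dataset_names` (plural) and may supply either a list or a single string. A minimal usage sketch, assuming the dataset names listed in the argument help further down (`_mock=True` is used only to avoid real IPFS traffic, mirroring the `mock()` classmethod in the diff):

import bittensor

# The pluralized keyword takes a list of dataset names.
dataset = bittensor.dataset(_mock=True, dataset_names=['Books3', 'ArXiv'])
dataset.close()

# With this change a bare string is also accepted; GenesisTextDataset
# wraps it into a one-element list (see dataset_impl.py below).
dataset = bittensor.dataset(dataset_names='Books3')
dataset.close()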
2 changes: 1 addition & 1 deletion bittensor/_config/config_impl.py
@@ -84,7 +84,7 @@ def to_defaults(self):
bittensor.defaults.dataset.block_size = self.dataset.block_size
bittensor.defaults.dataset.num_batches = self.dataset.num_batches
bittensor.defaults.dataset.num_workers = self.dataset.num_workers
- bittensor.defaults.dataset.dataset_name = self.dataset.dataset_name
+ bittensor.defaults.dataset.dataset_names = self.dataset.dataset_names
bittensor.defaults.dataset.data_dir = self.dataset.data_dir
bittensor.defaults.dataset.save_dataset = self.dataset.save_dataset
bittensor.defaults.dataset.max_datasets = self.dataset.max_datasets
28 changes: 18 additions & 10 deletions bittensor/_dataset/__init__.py
@@ -21,6 +21,8 @@
import argparse
import os
import copy
+ from typing import Union
+ import warnings

import bittensor
from . import dataset_impl
@@ -43,11 +45,12 @@ def __new__(
block_size: int = None,
batch_size: int = None,
num_workers: int = None,
- dataset_name: list = [],
+ dataset_names: Union[list, str] = None,
save_dataset: bool=None,
no_tokenizer: bool=None,
num_batches: int = None,
- _mock:bool=None
+ _mock:bool=None,
+ dataset_name: list = None, # For backwards compatibility
):
r""" Create and init the GenesisTextDataset class, which handles dataloading from ipfs.
Args:
@@ -59,7 +62,7 @@
Batch size.
num_workers (:obj:`int`, `optional`):
Number of workers for data loader.
- dataset_name (:obj:`list`, `optional`):
+ dataset_names (:obj:`list`,`str`, `optional`):
Which datasets to use (ArXiv, BookCorpus2, Books3, DMMathematics, EnronEmails, EuroParl,
Gutenberg_PG, HackerNews, NIHExPorter, OpenSubtitles, PhilPapers, UbuntuIRC, YoutubeSubtitles)).
save_dataset (:obj:`bool`, `optional`):
@@ -77,18 +80,23 @@
config.dataset.block_size = block_size if block_size != None else config.dataset.block_size
config.dataset.batch_size = batch_size if batch_size != None else config.dataset.batch_size
config.dataset.num_workers = num_workers if num_workers != None else config.dataset.num_workers
- config.dataset.dataset_name = dataset_name if dataset_name != [] else config.dataset.dataset_name
+ config.dataset.dataset_names = dataset_names if dataset_names != None else config.dataset.dataset_names
config.dataset.save_dataset = save_dataset if save_dataset != None else config.dataset.save_dataset
config.dataset.no_tokenizer = no_tokenizer if no_tokenizer != None else config.dataset.no_tokenizer
config.dataset.num_batches = num_batches if num_batches != None else config.dataset.num_batches
config.dataset._mock = _mock if _mock != None else config.dataset._mock
dataset.check_config( config )

+ if dataset_name is not None:
+     warnings.warn("dataset_name as a parameter is deprecated and will be removed in a future release. Use `dataset_names` instead.", DeprecationWarning)
+     config.dataset.dataset_names = dataset_name

if config.dataset._mock:
return dataset_mock.MockGenesisTextDataset(
block_size = config.dataset.block_size,
batch_size = config.dataset.batch_size,
num_workers = config.dataset.num_workers,
- dataset_name = config.dataset.dataset_name,
+ dataset_names = config.dataset.dataset_names,
data_dir = config.dataset.data_dir,
save_dataset = config.dataset.save_dataset,
max_datasets = config.dataset.max_datasets,
@@ -100,7 +108,7 @@
block_size = config.dataset.block_size,
batch_size = config.dataset.batch_size,
num_workers = config.dataset.num_workers,
- dataset_name = config.dataset.dataset_name,
+ dataset_names = config.dataset.dataset_names,
data_dir = config.dataset.data_dir,
save_dataset = config.dataset.save_dataset,
max_datasets = config.dataset.max_datasets,
Expand All @@ -110,7 +118,7 @@ def __new__(

@classmethod
def mock(cls):
- return dataset( _mock = True, dataset_name = ['Books3'])
+ return dataset( _mock = True, dataset_names = ['Books3'])

@classmethod
def config(cls) -> 'bittensor.Config':
@@ -130,8 +138,8 @@ def add_args(cls, parser: argparse.ArgumentParser, prefix: str = None ):
parser.add_argument('--' + prefix_str + 'dataset.batch_size', type=int, help='Batch size.', default = bittensor.defaults.dataset.batch_size)
parser.add_argument('--' + prefix_str + 'dataset.block_size', type=int, help='Number of text items to pull for each example..', default = bittensor.defaults.dataset.block_size)
parser.add_argument('--' + prefix_str + 'dataset.num_workers', type=int, help='Number of workers for data loader.', default = bittensor.defaults.dataset.num_workers)
- parser.add_argument('--' + prefix_str + 'dataset.dataset_name', type=str, required=False, nargs='*', action='store', help='Which datasets to use (ArXiv, BookCorpus2, Books3, DMMathematics, EnronEmails, EuroParl, Gutenberg_PG, HackerNews, NIHExPorter, OpenSubtitles, PhilPapers, UbuntuIRC, YoutubeSubtitles)).',
-     default = bittensor.defaults.dataset.dataset_name)
+ parser.add_argument('--' + prefix_str + 'dataset.dataset_names', type=str, required=False, nargs='*', action='store', help='Which datasets to use (ArXiv, BookCorpus2, Books3, DMMathematics, EnronEmails, EuroParl, Gutenberg_PG, HackerNews, NIHExPorter, OpenSubtitles, PhilPapers, UbuntuIRC, YoutubeSubtitles)).',
+     default = bittensor.defaults.dataset.dataset_names)
parser.add_argument('--' + prefix_str + 'dataset.data_dir', type=str, help='Where to save and load the data.', default = bittensor.defaults.dataset.data_dir)
parser.add_argument('--' + prefix_str + 'dataset.save_dataset', action='store_true', help='Save the downloaded dataset or not.', default = bittensor.defaults.dataset.save_dataset)
parser.add_argument('--' + prefix_str + 'dataset.max_datasets', type=int, help='Number of datasets to load', default = bittensor.defaults.dataset.max_datasets)
@@ -160,7 +168,7 @@ def add_defaults(cls, defaults):
defaults.dataset.batch_size = os.getenv('BT_DATASET_BATCH_SIZE') if os.getenv('BT_DATASET_BATCH_SIZE') != None else 10
defaults.dataset.block_size = os.getenv('BT_DATASET_BLOCK_SIZE') if os.getenv('BT_DATASET_BLOCK_SIZE') != None else 20
defaults.dataset.num_workers = os.getenv('BT_DATASET_NUM_WORKERS') if os.getenv('BT_DATASET_NUM_WORKERS') != None else 0
- defaults.dataset.dataset_name = os.getenv('BT_DATASET_DATASET_NAME') if os.getenv('BT_DATASET_DATASET_NAME') != None else 'default'
+ defaults.dataset.dataset_names = os.getenv('BT_DATASET_DATASET_NAME') if os.getenv('BT_DATASET_DATASET_NAME') != None else 'default'
defaults.dataset.data_dir = os.getenv('BT_DATASET_DATADIR') if os.getenv('BT_DATASET_DATADIR') != None else '~/.bittensor/data/'
defaults.dataset.save_dataset = os.getenv('BT_DATASET_SAVE_DATASET') if os.getenv('BT_DATASET_SAVE_DATASET') != None else False
defaults.dataset.max_datasets = os.getenv('BT_DATASET_MAX_DATASETS') if os.getenv('BT_DATASET_MAX_DATASETS') != None else 3
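The backwards-compatibility shim above can be relied on while migrating; a short sketch of the expected behaviour (assumed from the hunk above, not part of the commit):

import warnings
import bittensor

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    # The old singular keyword still works: it is copied into
    # config.dataset.dataset_names and a DeprecationWarning is emitted.
    dataset = bittensor.dataset(_mock=True, dataset_name=['Books3'])

assert any(issubclass(w.category, DeprecationWarning) for w in caught)
dataset.close()

On the command line the flag is likewise pluralized, e.g. --dataset.dataset_names Books3 ArXiv (the argument keeps nargs='*' in add_args above).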
23 changes: 17 additions & 6 deletions bittensor/_dataset/dataset_impl.py
@@ -22,6 +22,7 @@
import os
import random
import time
+ import warnings
from multiprocessing import cpu_count
from typing import Union

@@ -129,7 +130,7 @@ def __init__(
block_size,
batch_size,
num_workers,
- dataset_name,
+ dataset_names,
data_dir,
save_dataset,
max_datasets,
@@ -141,7 +142,7 @@
self.batch_size = batch_size
self.num_workers = num_workers
self.tokenizer = bittensor.tokenizer( version = bittensor.__version__ )
- self.dataset_name = dataset_name
+ self.dataset_names = dataset_names
self.data_dir = data_dir
self.save_dataset = save_dataset
self.datafile_size_bound = 262158
@@ -153,6 +154,16 @@ def __init__(
self.IPFS_fails_max = 10
self.num_batches = num_batches

+ # Ensure dataset_names is formatted correctly
+ if isinstance(self.dataset_names, str):
+     self.dataset_names = [self.dataset_names]
+
+ allowed_datasets = bittensor.__datasets__ + ["default"]
+ for dataset_name in self.dataset_names:
+     if dataset_name not in allowed_datasets:
+         self.dataset_names.remove(dataset_name)
+         warnings.warn(f"Requested dataset {dataset_name} not in allowed datasets: {allowed_datasets}")

# Retrieve a random slice of the genesis dataset
self.data = []
self.data_reserved = []
@@ -342,7 +353,7 @@ def get_hashes(dataset_meta):
directories = []
self.IPFS_fails = 0

- if self.dataset_name == 'default':
+ if self.dataset_names == ['default']:
i = 0
dataset_hashes = list(self.dataset_hashes.values())
random.shuffle(dataset_hashes)
@@ -355,7 +366,7 @@
break

else:
- for key in self.dataset_name:
+ for key in self.dataset_names:
if key in self.dataset_hashes.keys():
dataset_meta = {'Folder': 'mountain','Name': key, 'Hash': self.dataset_hashes[key]['Hash'] }
directories += get_hashes(dataset_meta)
@@ -413,13 +424,13 @@ def get_root_text_hash(self, file_meta):
def get_text_from_local(self, min_data_len):

folders = os.listdir( os.path.expanduser (self.data_dir))
- if self.dataset_name == 'default':
+ if self.dataset_names == ['default']:
folders_avail = folders
random.shuffle(folders_avail)
folders_avail = folders_avail[:self.max_datasets]
else:
folders_avail = []
- for dataset_name in self.dataset_name:
+ for dataset_name in self.dataset_names:
if dataset_name in folders:
folders_avail.append(dataset_name)
random.shuffle(folders_avail)
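One caveat in the validation block added to __init__ above: it calls self.dataset_names.remove() while iterating over the same list, so when two unrecognised names sit next to each other the second one can be skipped. A sketch of an equivalent filter that avoids mutating the list during iteration (an illustration only, reusing the names from the hunk; it is not the committed code):

import warnings
import bittensor

def filter_dataset_names(dataset_names):
    """Return only the names bittensor recognises, warning about the rest."""
    if isinstance(dataset_names, str):
        dataset_names = [dataset_names]
    allowed_datasets = bittensor.__datasets__ + ["default"]
    kept = []
    for name in dataset_names:
        if name in allowed_datasets:
            kept.append(name)
        else:
            warnings.warn(f"Requested dataset {name} not in allowed datasets: {allowed_datasets}")
    return kept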
4 changes: 2 additions & 2 deletions bittensor/_dataset/dataset_mock.py
@@ -33,7 +33,7 @@ def __init__(
block_size,
batch_size,
num_workers,
- dataset_name,
+ dataset_names,
data_dir,
save_dataset,
max_datasets,
@@ -45,7 +45,7 @@
self.batch_size = batch_size
self.num_workers = num_workers
self.tokenizer = bittensor.tokenizer( version = bittensor.__version__ )
- self.dataset_name = dataset_name
+ self.dataset_names = dataset_names
self.data_dir = data_dir
self.save_dataset = save_dataset
self.datafile_size_bound = 262158
2 changes: 1 addition & 1 deletion tests/integration_tests/constant.py
@@ -2,7 +2,7 @@

dataset = Munch().fromDict(
{
'dataset_name': ["Books3"],
'dataset_names': ["Books3"],
'num_batches': 10
}
)
10 changes: 5 additions & 5 deletions tests/integration_tests/test_dataset.py
@@ -21,19 +21,19 @@
logging = bittensor.logging()

def test_construct_text_corpus():
- dataset = bittensor.dataset(num_batches = constant.dataset.num_batches, save_dataset = True, dataset_name = constant.dataset.dataset_name)
+ dataset = bittensor.dataset(num_batches = constant.dataset.num_batches, save_dataset = True, dataset_names = constant.dataset.dataset_names)
dataset.construct_text_corpus()
dataset.close()

def test_next():
- dataset = bittensor.dataset(num_batches = constant.dataset.num_batches, dataset_name = constant.dataset.dataset_name)
+ dataset = bittensor.dataset(num_batches = constant.dataset.num_batches, dataset_names = constant.dataset.dataset_names)
next(dataset)
next(dataset)
next(dataset)
dataset.close()

def test_mock():
- dataset = bittensor.dataset(_mock=True, dataset_name = constant.dataset.dataset_name)
+ dataset = bittensor.dataset(_mock=True, dataset_names = constant.dataset.dataset_names)
next(dataset)
next(dataset)
next(dataset)
@@ -47,7 +47,7 @@ def test_mock_function():
dataset.close()

def test_fail_IPFS_server():
- dataset = bittensor.dataset(num_batches = constant.dataset.num_batches, dataset_name = constant.dataset.dataset_name)
+ dataset = bittensor.dataset(num_batches = constant.dataset.num_batches, dataset_names = constant.dataset.dataset_names)
dataset.requests_retry_session = MagicMock(return_value = None)
next(dataset)
next(dataset)
@@ -57,7 +57,7 @@ def test_fail_IPFS_server():
def test_change_data_size():
data_sizes = [(10,20), (15.5, 20.5),(30, 40), (25,35)]
result_data_sizes = [(10,20), (10,20),(30, 40), (25,35)]
- dataset = bittensor.dataset(num_batches = constant.dataset.num_batches, dataset_name = constant.dataset.dataset_name)
+ dataset = bittensor.dataset(num_batches = constant.dataset.num_batches, dataset_names = constant.dataset.dataset_names)
for data_size, result_data_size in zip(data_sizes, result_data_sizes):
dataset.set_data_size(*data_size)
assert next(dataset).size() == result_data_size
