diff --git a/src/datachain/catalog/catalog.py b/src/datachain/catalog/catalog.py index 7a46e6aa7..2c6c55f82 100644 --- a/src/datachain/catalog/catalog.py +++ b/src/datachain/catalog/catalog.py @@ -127,8 +127,10 @@ def __init__( self, metastore: "AbstractMetastore", warehouse: "AbstractWarehouse", - dataset_name: str, - dataset_version: int, + remote_ds_name: str, + remote_ds_version: int, + local_ds_name: str, + local_ds_version: int, schema: dict[str, Union[SQLType, type[SQLType]]], max_threads: int = PULL_DATASET_MAX_THREADS, ): @@ -136,8 +138,10 @@ def __init__( self._check_dependencies() self.metastore = metastore self.warehouse = warehouse - self.dataset_name = dataset_name - self.dataset_version = dataset_version + self.remote_ds_name = remote_ds_name + self.remote_ds_version = remote_ds_version + self.local_ds_name = local_ds_name + self.local_ds_version = local_ds_version self.schema = schema self.last_status_check: Optional[float] = None self.studio_client = StudioClient() @@ -171,7 +175,7 @@ def check_for_status(self) -> None: Checks are done every PULL_DATASET_CHECK_STATUS_INTERVAL seconds """ export_status_response = self.studio_client.dataset_export_status( - self.dataset_name, self.dataset_version + self.remote_ds_name, self.remote_ds_version ) if not export_status_response.ok: raise_remote_error(export_status_response.message) @@ -203,7 +207,7 @@ def do_task(self, urls): # metastore and warehouse are not thread safe with self.metastore.clone() as metastore, self.warehouse.clone() as warehouse: - dataset = metastore.get_dataset(self.dataset_name) + local_ds = metastore.get_dataset(self.local_ds_name) urls = list(urls) while urls: @@ -227,7 +231,7 @@ def do_task(self, urls): df = df.drop("sys__id", axis=1) inserted = warehouse.insert_dataset_rows( - df, dataset, self.dataset_version + df, local_ds, self.local_ds_version ) self.increase_counter(inserted) # type: ignore [arg-type] urls.remove(url) @@ -1101,6 +1105,13 @@ def register_dataset( def get_dataset(self, name: str) -> DatasetRecord: return self.metastore.get_dataset(name) + def get_dataset_with_version_uuid(self, uuid: str) -> DatasetRecord: + """Returns dataset that contains version with specific uuid""" + for dataset in self.ls_datasets(): + if dataset.has_version_with_uuid(uuid): + return self.get_dataset(dataset.name) + raise DatasetNotFoundError(f"Dataset with version uuid {uuid} not found.") + def get_remote_dataset(self, name: str) -> DatasetRecord: studio_client = StudioClient() @@ -1316,10 +1327,12 @@ def ls( for source in data_sources: # type: ignore [union-attr] yield source, source.ls(fields) - def pull_dataset( + def pull_dataset( # noqa: PLR0915 self, - dataset_uri: str, + remote_ds_uri: str, output: Optional[str] = None, + local_ds_name: Optional[str] = None, + local_ds_version: Optional[int] = None, no_cp: bool = False, force: bool = False, edatachain: bool = False, @@ -1327,105 +1340,112 @@ def pull_dataset( *, client_config=None, ) -> None: - # TODO add progress bar https://github.com/iterative/dvcx/issues/750 - # TODO copy correct remote dates https://github.com/iterative/dvcx/issues/new - # TODO compare dataset stats on remote vs local pull to assert it's ok - def _instantiate_dataset(): + def _instantiate(ds_uri: str) -> None: if no_cp: return + assert output self.cp( - [dataset_uri], + [ds_uri], output, force=force, no_edatachain_file=not edatachain, edatachain_file=edatachain_file, client_config=client_config, ) - print(f"Dataset {dataset_uri} instantiated locally to {output}") + print(f"Dataset 
{ds_uri} instantiated locally to {output}") if not output and not no_cp: raise ValueError("Please provide output directory for instantiation") - client_config = client_config or self.client_config - studio_client = StudioClient() try: - remote_dataset_name, version = parse_dataset_uri(dataset_uri) + remote_ds_name, version = parse_dataset_uri(remote_ds_uri) except Exception as e: raise DataChainError("Error when parsing dataset uri") from e - dataset = None - try: - dataset = self.get_dataset(remote_dataset_name) - except DatasetNotFoundError: - # we will create new one if it doesn't exist - pass - - if dataset and version and dataset.has_version(version): - """No need to communicate with Studio at all""" - dataset_uri = create_dataset_uri(remote_dataset_name, version) - print(f"Local copy of dataset {dataset_uri} already present") - _instantiate_dataset() - return + remote_ds = self.get_remote_dataset(remote_ds_name) - remote_dataset = self.get_remote_dataset(remote_dataset_name) - # if version is not specified in uri, take the latest one - if not version: - version = remote_dataset.latest_version - print(f"Version not specified, pulling the latest one (v{version})") - # updating dataset uri with latest version - dataset_uri = create_dataset_uri(remote_dataset_name, version) + try: + # if version is not specified in uri, take the latest one + if not version: + version = remote_ds.latest_version + print(f"Version not specified, pulling the latest one (v{version})") + # updating dataset uri with latest version + remote_ds_uri = create_dataset_uri(remote_ds_name, version) + remote_ds_version = remote_ds.get_version(version) + except (DatasetVersionNotFoundError, StopIteration) as exc: + raise DataChainError( + f"Dataset {remote_ds_name} doesn't have version {version} on server" + ) from exc - assert version + local_ds_name = local_ds_name or remote_ds.name + local_ds_version = local_ds_version or remote_ds_version.version + local_ds_uri = create_dataset_uri(local_ds_name, local_ds_version) - if dataset and dataset.has_version(version): - print(f"Local copy of dataset {dataset_uri} already present") - _instantiate_dataset() + try: + # try to find existing dataset with the same uuid to avoid pulling again + existing_ds = self.get_dataset_with_version_uuid(remote_ds_version.uuid) + existing_ds_version = existing_ds.get_version_by_uuid( + remote_ds_version.uuid + ) + existing_ds_uri = create_dataset_uri( + existing_ds.name, existing_ds_version.version + ) + if existing_ds_uri == remote_ds_uri: + print(f"Local copy of dataset {remote_ds_uri} already present") + else: + print( + f"Local copy of dataset {remote_ds_uri} already present as" + f" dataset {existing_ds_uri}" + ) + _instantiate(existing_ds_uri) return + except DatasetNotFoundError: + pass try: - remote_dataset_version = remote_dataset.get_version(version) - except (DatasetVersionNotFoundError, StopIteration) as exc: - raise DataChainError( - f"Dataset {remote_dataset_name} doesn't have version {version}" - " on server" - ) from exc + local_dataset = self.get_dataset(local_ds_name) + if local_dataset and local_dataset.has_version(local_ds_version): + raise DataChainError( + f"Local dataset {local_ds_uri} already exists with different uuid," + " please choose different local dataset name or version" + ) + except DatasetNotFoundError: + pass - stats_response = studio_client.dataset_stats(remote_dataset_name, version) + stats_response = studio_client.dataset_stats( + remote_ds_name, remote_ds_version.version + ) if not stats_response.ok: 
raise_remote_error(stats_response.message) - dataset_stats = stats_response.data + ds_stats = stats_response.data dataset_save_progress_bar = tqdm( - desc=f"Saving dataset {dataset_uri} locally: ", + desc=f"Saving dataset {remote_ds_uri} locally: ", unit=" rows", unit_scale=True, unit_divisor=1000, - total=dataset_stats.num_objects, # type: ignore [union-attr] + total=ds_stats.num_objects, # type: ignore [union-attr] ) - schema = DatasetRecord.parse_schema(remote_dataset_version.schema) + schema = DatasetRecord.parse_schema(remote_ds_version.schema) - columns = tuple( - sa.Column(name, typ) for name, typ in schema.items() if name != "sys__id" - ) - # creating new dataset (version) locally - dataset = self.create_dataset( - remote_dataset_name, - version, - query_script=remote_dataset_version.query_script, + local_ds = self.create_dataset( + local_ds_name, + local_ds_version, + query_script=remote_ds_version.query_script, create_rows=True, - columns=columns, - feature_schema=remote_dataset_version.feature_schema, + columns=tuple(sa.Column(n, t) for n, t in schema.items() if n != "sys__id"), + feature_schema=remote_ds_version.feature_schema, validate_version=False, - uuid=remote_dataset_version.uuid, + uuid=remote_ds_version.uuid, ) # asking remote to export dataset rows table to s3 and to return signed # urls of exported parts, which are in parquet format export_response = studio_client.export_dataset_table( - remote_dataset_name, version + remote_ds_name, remote_ds_version.version ) if not export_response.ok: raise_remote_error(export_response.message) @@ -1442,8 +1462,10 @@ def _instantiate_dataset(): rows_fetcher = DatasetRowsFetcher( metastore, warehouse, - dataset.name, - version, + remote_ds_name, + remote_ds_version.version, + local_ds_name, + local_ds_version, schema, ) try: @@ -1455,23 +1477,23 @@ def _instantiate_dataset(): dataset_save_progress_bar, ) except: - self.remove_dataset(dataset.name, version) + self.remove_dataset(local_ds_name, local_ds_version) raise - dataset = self.metastore.update_dataset_status( - dataset, + local_ds = self.metastore.update_dataset_status( + local_ds, DatasetStatus.COMPLETE, - version=version, - error_message=remote_dataset.error_message, - error_stack=remote_dataset.error_stack, - script_output=remote_dataset.error_stack, + version=local_ds_version, + error_message=remote_ds.error_message, + error_stack=remote_ds.error_stack, + script_output=remote_ds.error_stack, ) - self.update_dataset_version_with_warehouse_info(dataset, version) + self.update_dataset_version_with_warehouse_info(local_ds, local_ds_version) dataset_save_progress_bar.close() - print(f"Dataset {dataset_uri} saved locally") + print(f"Dataset {remote_ds_uri} saved locally") - _instantiate_dataset() + _instantiate(local_ds_uri) def clone( self, diff --git a/src/datachain/cli.py b/src/datachain/cli.py index 298b500fe..efb1d332b 100644 --- a/src/datachain/cli.py +++ b/src/datachain/cli.py @@ -400,6 +400,18 @@ def get_parser() -> ArgumentParser: # noqa: PLR0915 "--edatachain-file", help="Use a different filename for the resulting .edatachain file", ) + parse_pull.add_argument( + "--local-name", + action="store", + default=None, + help="Name of the local dataset", + ) + parse_pull.add_argument( + "--local-version", + action="store", + default=None, + help="Version of the local dataset", + ) parse_edit_dataset = subp.add_parser( "edit-dataset", parents=[parent_parser], description="Edit dataset metadata" @@ -1207,6 +1219,8 @@ def main(argv: Optional[list[str]] = None) -> int: # noqa: 
C901, PLR0912, PLR09 catalog.pull_dataset( args.dataset, args.output, + local_ds_name=args.local_name, + local_ds_version=args.local_version, no_cp=args.no_cp, force=bool(args.force), edatachain=args.edatachain, diff --git a/src/datachain/dataset.py b/src/datachain/dataset.py index 8d2a5e428..b2d12b611 100644 --- a/src/datachain/dataset.py +++ b/src/datachain/dataset.py @@ -488,6 +488,18 @@ def get_version(self, version: int) -> DatasetVersion: if v.version == version ) + def get_version_by_uuid(self, uuid: str) -> DatasetVersion: + try: + return next( + v + for v in self.versions # type: ignore [union-attr] + if v.uuid == uuid + ) + except StopIteration: + raise DatasetVersionNotFoundError( + f"Dataset {self.name} does not have version with uuid {uuid}" + ) from None + def remove_version(self, version: int) -> None: if not self.versions or not self.has_version(version): return @@ -635,6 +647,9 @@ def is_bucket_listing(self) -> bool: LISTING_PREFIX ) + def has_version_with_uuid(self, uuid: str) -> bool: + return any(v.uuid == uuid for v in self.versions) + class RowDict(dict): pass diff --git a/tests/func/test_pull.py b/tests/func/test_pull.py index 9717aa772..5cb1c5889 100644 --- a/tests/func/test_pull.py +++ b/tests/func/test_pull.py @@ -9,7 +9,7 @@ from datachain.client.fsspec import Client from datachain.config import Config, ConfigLevel from datachain.dataset import DatasetStatus -from datachain.error import DataChainError +from datachain.error import DataChainError, DatasetNotFoundError from datachain.utils import STUDIO_URL, JSONSerialize from tests.data import ENTRIES from tests.utils import assert_row_names, skip_if_not_sqlite, tree_from_path @@ -120,7 +120,7 @@ def remote_dataset_version(schema, dataset_rows): def remote_dataset(remote_dataset_version, schema): return { "id": 1, - "name": "remote", + "name": "dogs", "description": "", "labels": [], "schema": schema, @@ -141,34 +141,68 @@ def remote_dataset(remote_dataset_version, schema): } -@pytest.mark.parametrize("cloud_type, version_aware", [("s3", False)], indirect=True) -@pytest.mark.parametrize("dataset_uri", ["ds://dogs@v1", "ds://dogs"]) -@pytest.mark.parametrize("instantiate", [True, False]) -@skip_if_not_sqlite -def test_pull_dataset_success( - requests_mock, - cloud_test_catalog, - remote_dataset, - dog_entries_parquet_lz4, - dataset_uri, - instantiate, -): - src_uri = cloud_test_catalog.src_uri - working_dir = cloud_test_catalog.working_dir - data_url = ( +@pytest.fixture +def remote_dataset_chunk_url(): + return ( "https://studio-blobvault.s3.amazonaws.com/datachain_ds_export_1_0.parquet.lz4" ) + + +@pytest.fixture +def remote_dataset_info(requests_mock, remote_dataset): requests_mock.post(f"{STUDIO_URL}/api/datachain/dataset-info", json=remote_dataset) + + +@pytest.fixture +def remote_dataset_stats(requests_mock): requests_mock.post( f"{STUDIO_URL}/api/datachain/dataset-stats", json={"num_objects": 5, "size": 1000}, ) - requests_mock.post(f"{STUDIO_URL}/api/datachain/dataset-export", json=[data_url]) + + +@pytest.fixture +def dataset_export(requests_mock, remote_dataset_chunk_url): + requests_mock.post( + f"{STUDIO_URL}/api/datachain/dataset-export", json=[remote_dataset_chunk_url] + ) + + +@pytest.fixture +def dataset_export_status(requests_mock): requests_mock.post( f"{STUDIO_URL}/api/datachain/dataset-export-status", json={"status": "completed"}, ) - requests_mock.get(data_url, content=dog_entries_parquet_lz4) + + +@pytest.fixture +def dataset_export_data_chunk( + requests_mock, remote_dataset_chunk_url, 
dog_entries_parquet_lz4 +): + requests_mock.get(remote_dataset_chunk_url, content=dog_entries_parquet_lz4) + + +@pytest.mark.parametrize("cloud_type, version_aware", [("s3", False)], indirect=True) +@pytest.mark.parametrize("dataset_uri", ["ds://dogs@v1", "ds://dogs"]) +@pytest.mark.parametrize("local_ds_name", [None, "other"]) +@pytest.mark.parametrize("local_ds_version", [None, 2]) +@pytest.mark.parametrize("instantiate", [True, False]) +@skip_if_not_sqlite +def test_pull_dataset_success( + cloud_test_catalog, + remote_dataset_info, + remote_dataset_stats, + dataset_export, + dataset_export_status, + dataset_export_data_chunk, + dataset_uri, + local_ds_name, + local_ds_version, + instantiate, +): + src_uri = cloud_test_catalog.src_uri + working_dir = cloud_test_catalog.working_dir catalog = cloud_test_catalog.catalog dest = None @@ -176,19 +210,30 @@ def test_pull_dataset_success( if instantiate: dest = working_dir / "data" dest.mkdir() - catalog.pull_dataset(dataset_uri, output=str(dest), no_cp=False) + catalog.pull_dataset( + dataset_uri, + output=str(dest), + local_ds_name=local_ds_name, + local_ds_version=local_ds_version, + no_cp=False, + ) else: # trying to pull multiple times since that should work as well - catalog.pull_dataset(dataset_uri, no_cp=True) - catalog.pull_dataset(dataset_uri, no_cp=True) - - dataset = catalog.get_dataset("dogs") - assert dataset.versions_values == [1] + for _ in range(2): + catalog.pull_dataset( + dataset_uri, + local_ds_name=local_ds_name, + local_ds_version=local_ds_version, + no_cp=True, + ) + + dataset = catalog.get_dataset(local_ds_name or "dogs") + assert dataset.versions_values == [local_ds_version or 1] assert dataset.status == DatasetStatus.COMPLETE assert dataset.created_at assert dataset.finished_at assert dataset.schema - dataset_version = dataset.get_version(1) + dataset_version = dataset.get_version(local_ds_version or 1) assert dataset_version.status == DatasetStatus.COMPLETE assert dataset_version.created_at assert dataset_version.finished_at @@ -200,7 +245,7 @@ def test_pull_dataset_success( assert_row_names( catalog, dataset, - 1, + local_ds_version or 1, { "dog1", "dog2", @@ -230,7 +275,6 @@ def test_pull_dataset_wrong_dataset_uri_format( requests_mock, cloud_test_catalog, remote_dataset, - dog_entries_parquet_lz4, ): catalog = cloud_test_catalog.catalog @@ -244,12 +288,8 @@ def test_pull_dataset_wrong_dataset_uri_format( def test_pull_dataset_wrong_version( requests_mock, cloud_test_catalog, - remote_dataset, + remote_dataset_info, ): - requests_mock.post( - f"{STUDIO_URL}/api/datachain/dataset-info", - json=remote_dataset, - ) catalog = cloud_test_catalog.catalog with pytest.raises(DataChainError) as exc_info: @@ -262,7 +302,6 @@ def test_pull_dataset_wrong_version( def test_pull_dataset_not_found_in_remote( requests_mock, cloud_test_catalog, - remote_dataset, ): requests_mock.post( f"{STUDIO_URL}/api/datachain/dataset-info", @@ -281,12 +320,8 @@ def test_pull_dataset_not_found_in_remote( def test_pull_dataset_error_on_fetching_stats( requests_mock, cloud_test_catalog, - remote_dataset, + remote_dataset_info, ): - requests_mock.post( - f"{STUDIO_URL}/api/datachain/dataset-info", - json=remote_dataset, - ) requests_mock.post( f"{STUDIO_URL}/api/datachain/dataset-stats", status_code=400, @@ -305,18 +340,11 @@ def test_pull_dataset_error_on_fetching_stats( def test_pull_dataset_exporting_dataset_failed_in_remote( requests_mock, cloud_test_catalog, - remote_dataset, + remote_dataset_info, + remote_dataset_stats, + dataset_export, 
export_status, ): - data_url = ( - "https://studio-blobvault.s3.amazonaws.com/datachain_ds_export_1_0.parquet.lz4" - ) - requests_mock.post(f"{STUDIO_URL}/api/datachain/dataset-info", json=remote_dataset) - requests_mock.post( - f"{STUDIO_URL}/api/datachain/dataset-stats", - json={"num_objects": 5, "size": 1000}, - ) - requests_mock.post(f"{STUDIO_URL}/api/datachain/dataset-export", json=[data_url]) requests_mock.post( f"{STUDIO_URL}/api/datachain/dataset-export-status", json={"status": export_status}, @@ -336,24 +364,72 @@ def test_pull_dataset_exporting_dataset_failed_in_remote( def test_pull_dataset_empty_parquet( requests_mock, cloud_test_catalog, - remote_dataset, - dog_entries_parquet_lz4, + remote_dataset_info, + remote_dataset_stats, + dataset_export, + dataset_export_status, + remote_dataset_chunk_url, ): - data_url = ( - "https://studio-blobvault.s3.amazonaws.com/datachain_ds_export_1_0.parquet.lz4" - ) - requests_mock.post(f"{STUDIO_URL}/api/datachain/dataset-info", json=remote_dataset) - requests_mock.post( - f"{STUDIO_URL}/api/datachain/dataset-stats", - json={"num_objects": 5, "size": 1000}, - ) - requests_mock.post(f"{STUDIO_URL}/api/datachain/dataset-export", json=[data_url]) - requests_mock.post( - f"{STUDIO_URL}/api/datachain/dataset-export-status", - json={"status": "completed"}, - ) - requests_mock.get(data_url, content=b"") + requests_mock.get(remote_dataset_chunk_url, content=b"") catalog = cloud_test_catalog.catalog with pytest.raises(RuntimeError): catalog.pull_dataset("ds://dogs@v1", no_cp=True) + + +@pytest.mark.parametrize("cloud_type, version_aware", [("s3", False)], indirect=True) +@skip_if_not_sqlite +def test_pull_dataset_already_exists_locally( + cloud_test_catalog, + remote_dataset_info, + remote_dataset_stats, + dataset_export, + dataset_export_status, + dataset_export_data_chunk, +): + catalog = cloud_test_catalog.catalog + + catalog.pull_dataset("ds://dogs@v1", local_ds_name="other", no_cp=True) + catalog.pull_dataset("ds://dogs@v1", no_cp=True) + + other = catalog.get_dataset("other") + other_version = other.get_version(1) + assert other_version.uuid == DATASET_UUID + assert other_version.num_objects == 4 + assert other_version.size == 15 + + # dataset with same uuid created only once, on first pull with local name "other" + with pytest.raises(DatasetNotFoundError): + catalog.get_dataset("dogs") + + +@pytest.mark.parametrize("cloud_type, version_aware", [("s3", False)], indirect=True) +@pytest.mark.parametrize("local_ds_name", [None, "other"]) +@skip_if_not_sqlite +def test_pull_dataset_local_name_already_exists( + cloud_test_catalog, + remote_dataset_info, + remote_dataset_stats, + dataset_export, + dataset_export_status, + dataset_export_data_chunk, + local_ds_name, +): + catalog = cloud_test_catalog.catalog + src_uri = cloud_test_catalog.src_uri + + catalog.create_dataset_from_sources( + local_ds_name or "dogs", [f"{src_uri}/dogs/*"], recursive=True + ) + with pytest.raises(DataChainError) as exc_info: + catalog.pull_dataset("ds://dogs@v1", local_ds_name=local_ds_name, no_cp=True) + + assert str(exc_info.value) == ( + f'Local dataset ds://{local_ds_name or "dogs"}@v1 already exists with different' + ' uuid, please choose different local dataset name or version' + ) + + # able to save it as version 2 of local dataset name + catalog.pull_dataset( + "ds://dogs@v1", local_ds_name=local_ds_name, local_ds_version=2, no_cp=True + )
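
Usage sketch (illustrative, not part of the diff). A minimal example of the new
pull parameters, mirroring test_pull_dataset_already_exists_locally above; the
`catalog` object, dataset names, and output path are assumptions made for the
sake of the example:

    # Pull a remote dataset but register it locally under a different
    # name and version (parameters added in this diff).
    catalog.pull_dataset(
        "ds://dogs@v1",
        output="./data",        # required unless no_cp=True
        local_ds_name="other",
        local_ds_version=2,
    )

    # Pulling the same remote version again locates the existing local copy
    # by the remote version's uuid and skips the download, regardless of the
    # local name it was saved under.
    catalog.pull_dataset("ds://dogs@v1", no_cp=True)

On the command line the same knobs are exposed through the new --local-name and
--local-version options added to parse_pull, e.g. (assuming the existing pull
subcommand shape): datachain pull ds://dogs@v1 ./data --local-name other
--local-version 2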