Using version uuid in datachain pull (#621)
Use the dataset version uuid in `datachain pull`: before downloading, the remote version's uuid is matched against local datasets, so an already-pulled copy is reused regardless of its local name or version, and the new `--local-name` / `--local-version` options allow saving the pulled dataset under a different local name and version.
ilongin authored Nov 29, 2024
1 parent be41789 commit 045f3b0
Showing 4 changed files with 269 additions and 142 deletions.
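In short: `datachain pull` now identifies a remote dataset version by its uuid instead of only by name and version number. If any local dataset already contains a version with that uuid, the pull is skipped and the existing copy is reused; otherwise the rows are downloaded and saved, optionally under a different local name and version via the new `--local-name` and `--local-version` options. A hypothetical invocation (the dataset URI, output path, and values are placeholders; the positional syntax is assumed from the parser changes below):

    datachain pull ds://dogs@v3 ./dogs --local-name dogs-local --local-version 1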
176 changes: 99 additions & 77 deletions src/datachain/catalog/catalog.py
@@ -127,17 +127,21 @@ def __init__(
         self,
         metastore: "AbstractMetastore",
         warehouse: "AbstractWarehouse",
-        dataset_name: str,
-        dataset_version: int,
+        remote_ds_name: str,
+        remote_ds_version: int,
+        local_ds_name: str,
+        local_ds_version: int,
         schema: dict[str, Union[SQLType, type[SQLType]]],
         max_threads: int = PULL_DATASET_MAX_THREADS,
     ):
         super().__init__(max_threads)
         self._check_dependencies()
         self.metastore = metastore
         self.warehouse = warehouse
-        self.dataset_name = dataset_name
-        self.dataset_version = dataset_version
+        self.remote_ds_name = remote_ds_name
+        self.remote_ds_version = remote_ds_version
+        self.local_ds_name = local_ds_name
+        self.local_ds_version = local_ds_version
         self.schema = schema
         self.last_status_check: Optional[float] = None
         self.studio_client = StudioClient()
@@ -171,7 +175,7 @@ def check_for_status(self) -> None:
         Checks are done every PULL_DATASET_CHECK_STATUS_INTERVAL seconds
         """
         export_status_response = self.studio_client.dataset_export_status(
-            self.dataset_name, self.dataset_version
+            self.remote_ds_name, self.remote_ds_version
         )
         if not export_status_response.ok:
             raise_remote_error(export_status_response.message)
@@ -203,7 +207,7 @@ def do_task(self, urls):

         # metastore and warehouse are not thread safe
         with self.metastore.clone() as metastore, self.warehouse.clone() as warehouse:
-            dataset = metastore.get_dataset(self.dataset_name)
+            local_ds = metastore.get_dataset(self.local_ds_name)

             urls = list(urls)
             while urls:
@@ -227,7 +231,7 @@ def do_task(self, urls):
                     df = df.drop("sys__id", axis=1)

                     inserted = warehouse.insert_dataset_rows(
-                        df, dataset, self.dataset_version
+                        df, local_ds, self.local_ds_version
                     )
                     self.increase_counter(inserted)  # type: ignore [arg-type]
                     urls.remove(url)
Expand Down Expand Up @@ -1101,6 +1105,13 @@ def register_dataset(
def get_dataset(self, name: str) -> DatasetRecord:
return self.metastore.get_dataset(name)

def get_dataset_with_version_uuid(self, uuid: str) -> DatasetRecord:
"""Returns dataset that contains version with specific uuid"""
for dataset in self.ls_datasets():
if dataset.has_version_with_uuid(uuid):
return self.get_dataset(dataset.name)
raise DatasetNotFoundError(f"Dataset with version uuid {uuid} not found.")

def get_remote_dataset(self, name: str) -> DatasetRecord:
studio_client = StudioClient()

@@ -1316,116 +1327,125 @@ def ls(
         for source in data_sources:  # type: ignore [union-attr]
             yield source, source.ls(fields)

-    def pull_dataset(
+    def pull_dataset(  # noqa: PLR0915
         self,
-        dataset_uri: str,
+        remote_ds_uri: str,
         output: Optional[str] = None,
+        local_ds_name: Optional[str] = None,
+        local_ds_version: Optional[int] = None,
         no_cp: bool = False,
         force: bool = False,
         edatachain: bool = False,
         edatachain_file: Optional[str] = None,
         *,
         client_config=None,
     ) -> None:
         # TODO add progress bar https://github.com/iterative/dvcx/issues/750
         # TODO copy correct remote dates https://github.com/iterative/dvcx/issues/new
         # TODO compare dataset stats on remote vs local pull to assert it's ok
-        def _instantiate_dataset():
+        def _instantiate(ds_uri: str) -> None:
             if no_cp:
                 return
             assert output
             self.cp(
-                [dataset_uri],
+                [ds_uri],
                 output,
                 force=force,
                 no_edatachain_file=not edatachain,
                 edatachain_file=edatachain_file,
                 client_config=client_config,
             )
-            print(f"Dataset {dataset_uri} instantiated locally to {output}")
+            print(f"Dataset {ds_uri} instantiated locally to {output}")

         if not output and not no_cp:
             raise ValueError("Please provide output directory for instantiation")

         client_config = client_config or self.client_config

         studio_client = StudioClient()

         try:
-            remote_dataset_name, version = parse_dataset_uri(dataset_uri)
+            remote_ds_name, version = parse_dataset_uri(remote_ds_uri)
         except Exception as e:
             raise DataChainError("Error when parsing dataset uri") from e

-        dataset = None
-        try:
-            dataset = self.get_dataset(remote_dataset_name)
-        except DatasetNotFoundError:
-            # we will create new one if it doesn't exist
-            pass
-
-        if dataset and version and dataset.has_version(version):
-            """No need to communicate with Studio at all"""
-            dataset_uri = create_dataset_uri(remote_dataset_name, version)
-            print(f"Local copy of dataset {dataset_uri} already present")
-            _instantiate_dataset()
-            return
+        remote_ds = self.get_remote_dataset(remote_ds_name)

-        remote_dataset = self.get_remote_dataset(remote_dataset_name)
-        # if version is not specified in uri, take the latest one
-        if not version:
-            version = remote_dataset.latest_version
-            print(f"Version not specified, pulling the latest one (v{version})")
-            # updating dataset uri with latest version
-            dataset_uri = create_dataset_uri(remote_dataset_name, version)
+        try:
+            # if version is not specified in uri, take the latest one
+            if not version:
+                version = remote_ds.latest_version
+                print(f"Version not specified, pulling the latest one (v{version})")
+                # updating dataset uri with latest version
+                remote_ds_uri = create_dataset_uri(remote_ds_name, version)
+            remote_ds_version = remote_ds.get_version(version)
+        except (DatasetVersionNotFoundError, StopIteration) as exc:
+            raise DataChainError(
+                f"Dataset {remote_ds_name} doesn't have version {version} on server"
+            ) from exc

-        assert version
+        local_ds_name = local_ds_name or remote_ds.name
+        local_ds_version = local_ds_version or remote_ds_version.version
+        local_ds_uri = create_dataset_uri(local_ds_name, local_ds_version)

-        if dataset and dataset.has_version(version):
-            print(f"Local copy of dataset {dataset_uri} already present")
-            _instantiate_dataset()
+        try:
+            # try to find existing dataset with the same uuid to avoid pulling again
+            existing_ds = self.get_dataset_with_version_uuid(remote_ds_version.uuid)
+            existing_ds_version = existing_ds.get_version_by_uuid(
+                remote_ds_version.uuid
+            )
+            existing_ds_uri = create_dataset_uri(
+                existing_ds.name, existing_ds_version.version
+            )
+            if existing_ds_uri == remote_ds_uri:
+                print(f"Local copy of dataset {remote_ds_uri} already present")
+            else:
+                print(
+                    f"Local copy of dataset {remote_ds_uri} already present as"
+                    f" dataset {existing_ds_uri}"
+                )
+            _instantiate(existing_ds_uri)
             return
+        except DatasetNotFoundError:
+            pass

         try:
-            remote_dataset_version = remote_dataset.get_version(version)
-        except (DatasetVersionNotFoundError, StopIteration) as exc:
-            raise DataChainError(
-                f"Dataset {remote_dataset_name} doesn't have version {version}"
-                " on server"
-            ) from exc
+            local_dataset = self.get_dataset(local_ds_name)
+            if local_dataset and local_dataset.has_version(local_ds_version):
+                raise DataChainError(
+                    f"Local dataset {local_ds_uri} already exists with different uuid,"
+                    " please choose different local dataset name or version"
+                )
+        except DatasetNotFoundError:
+            pass

-        stats_response = studio_client.dataset_stats(remote_dataset_name, version)
+        stats_response = studio_client.dataset_stats(
+            remote_ds_name, remote_ds_version.version
+        )
         if not stats_response.ok:
             raise_remote_error(stats_response.message)
-        dataset_stats = stats_response.data
+        ds_stats = stats_response.data

         dataset_save_progress_bar = tqdm(
-            desc=f"Saving dataset {dataset_uri} locally: ",
+            desc=f"Saving dataset {remote_ds_uri} locally: ",
             unit=" rows",
             unit_scale=True,
             unit_divisor=1000,
-            total=dataset_stats.num_objects,  # type: ignore [union-attr]
+            total=ds_stats.num_objects,  # type: ignore [union-attr]
         )

-        schema = DatasetRecord.parse_schema(remote_dataset_version.schema)
+        schema = DatasetRecord.parse_schema(remote_ds_version.schema)

-        columns = tuple(
-            sa.Column(name, typ) for name, typ in schema.items() if name != "sys__id"
-        )
         # creating new dataset (version) locally
-        dataset = self.create_dataset(
-            remote_dataset_name,
-            version,
-            query_script=remote_dataset_version.query_script,
+        local_ds = self.create_dataset(
+            local_ds_name,
+            local_ds_version,
+            query_script=remote_ds_version.query_script,
             create_rows=True,
-            columns=columns,
-            feature_schema=remote_dataset_version.feature_schema,
+            columns=tuple(sa.Column(n, t) for n, t in schema.items() if n != "sys__id"),
+            feature_schema=remote_ds_version.feature_schema,
             validate_version=False,
-            uuid=remote_dataset_version.uuid,
+            uuid=remote_ds_version.uuid,
         )

         # asking remote to export dataset rows table to s3 and to return signed
         # urls of exported parts, which are in parquet format
         export_response = studio_client.export_dataset_table(
-            remote_dataset_name, version
+            remote_ds_name, remote_ds_version.version
         )
         if not export_response.ok:
             raise_remote_error(export_response.message)
@@ -1442,8 +1462,10 @@ def _instantiate_dataset():
             rows_fetcher = DatasetRowsFetcher(
                 metastore,
                 warehouse,
-                dataset.name,
-                version,
+                remote_ds_name,
+                remote_ds_version.version,
+                local_ds_name,
+                local_ds_version,
                 schema,
             )
             try:
@@ -1455,23 +1477,23 @@ def _instantiate_dataset():
                     dataset_save_progress_bar,
                 )
             except:
-                self.remove_dataset(dataset.name, version)
+                self.remove_dataset(local_ds_name, local_ds_version)
                 raise

-        dataset = self.metastore.update_dataset_status(
-            dataset,
+        local_ds = self.metastore.update_dataset_status(
+            local_ds,
             DatasetStatus.COMPLETE,
-            version=version,
-            error_message=remote_dataset.error_message,
-            error_stack=remote_dataset.error_stack,
-            script_output=remote_dataset.error_stack,
+            version=local_ds_version,
+            error_message=remote_ds.error_message,
+            error_stack=remote_ds.error_stack,
+            script_output=remote_ds.error_stack,
         )
-        self.update_dataset_version_with_warehouse_info(dataset, version)
+        self.update_dataset_version_with_warehouse_info(local_ds, local_ds_version)

         dataset_save_progress_bar.close()
-        print(f"Dataset {dataset_uri} saved locally")
+        print(f"Dataset {remote_ds_uri} saved locally")

-        _instantiate_dataset()
+        _instantiate(local_ds_uri)

     def clone(
         self,
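To summarize the control flow above, here is a condensed sketch of the resolution order `pull_dataset` now follows. This is a simplification for illustration only, not the actual method: Studio communication, schema handling, and row fetching are elided, and `resolve_pull_target` is a hypothetical name.

    # Sketch: uuid-first resolution, mirroring the diff above.
    def resolve_pull_target(self, remote_ds, version, local_ds_name, local_ds_version):
        remote_ds_version = remote_ds.get_version(version)
        # defaults: keep the remote name and version
        local_ds_name = local_ds_name or remote_ds.name
        local_ds_version = local_ds_version or remote_ds_version.version
        try:
            # a uuid match against any local dataset means nothing to download
            existing_ds = self.get_dataset_with_version_uuid(remote_ds_version.uuid)
            existing_version = existing_ds.get_version_by_uuid(remote_ds_version.uuid)
            return existing_ds.name, existing_version.version, True  # reuse local copy
        except DatasetNotFoundError:
            pass
        # a name/version collision with a different uuid is an error, not a reuse
        try:
            if self.get_dataset(local_ds_name).has_version(local_ds_version):
                raise DataChainError(
                    f"{local_ds_name} v{local_ds_version} exists with different uuid"
                )
        except DatasetNotFoundError:
            pass
        return local_ds_name, local_ds_version, False  # pull required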
14 changes: 14 additions & 0 deletions src/datachain/cli.py
@@ -400,6 +400,18 @@ def get_parser() -> ArgumentParser:  # noqa: PLR0915
         "--edatachain-file",
         help="Use a different filename for the resulting .edatachain file",
     )
+    parse_pull.add_argument(
+        "--local-name",
+        action="store",
+        default=None,
+        help="Name of the local dataset",
+    )
+    parse_pull.add_argument(
+        "--local-version",
+        action="store",
+        default=None,
+        help="Version of the local dataset",
+    )

     parse_edit_dataset = subp.add_parser(
         "edit-dataset", parents=[parent_parser], description="Edit dataset metadata"
@@ -1207,6 +1219,8 @@ def main(argv: Optional[list[str]] = None) -> int:  # noqa: C901, PLR0912, PLR09
         catalog.pull_dataset(
             args.dataset,
             args.output,
+            local_ds_name=args.local_name,
+            local_ds_version=args.local_version,
             no_cp=args.no_cp,
             force=bool(args.force),
             edatachain=args.edatachain,
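The same pull can be driven from Python through the catalog API these arguments feed into. A sketch, assuming `get_catalog` is the usual entry point for obtaining a `Catalog` instance, with placeholder URI and paths:

    from datachain.catalog import get_catalog  # assumed entry point

    catalog = get_catalog()
    catalog.pull_dataset(
        "ds://dogs@v3",              # remote dataset URI (placeholder)
        output="./dogs",
        local_ds_name="dogs-local",  # optional: store under a different local name
        local_ds_version=1,          # optional: explicit local version
    )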
15 changes: 15 additions & 0 deletions src/datachain/dataset.py
@@ -488,6 +488,18 @@ def get_version(self, version: int) -> DatasetVersion:
             if v.version == version
         )

+    def get_version_by_uuid(self, uuid: str) -> DatasetVersion:
+        try:
+            return next(
+                v
+                for v in self.versions  # type: ignore [union-attr]
+                if v.uuid == uuid
+            )
+        except StopIteration:
+            raise DatasetVersionNotFoundError(
+                f"Dataset {self.name} does not have version with uuid {uuid}"
+            ) from None
+
     def remove_version(self, version: int) -> None:
         if not self.versions or not self.has_version(version):
             return
@@ -635,6 +647,9 @@ def is_bucket_listing(self) -> bool:
             LISTING_PREFIX
         )

+    def has_version_with_uuid(self, uuid: str) -> bool:
+        return any(v.uuid == uuid for v in self.versions)
+

 class RowDict(dict):
     pass
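These two helpers make uuid lookup on a `DatasetRecord` symmetrical with the existing name/version accessors: `has_version_with_uuid` is a cheap linear membership test, while `get_version_by_uuid` returns the full `DatasetVersion` and raises if it is absent. An illustrative use, assuming `ds` is a `DatasetRecord` already loaded from the metastore and `uuid` is a version uuid string:

    if ds.has_version_with_uuid(uuid):
        v = ds.get_version_by_uuid(uuid)  # DatasetVersion; raises DatasetVersionNotFoundError if missing
        print(v.version, v.uuid)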
1 more changed file not shown (141 additions & 65 deletions).
