Using version uuid in datachain pull #621

Merged · 6 commits · Nov 29, 2024

src/datachain/catalog/catalog.py (99 additions, 77 deletions)
@@ -127,17 +127,21 @@ def __init__(
         self,
         metastore: "AbstractMetastore",
         warehouse: "AbstractWarehouse",
-        dataset_name: str,
-        dataset_version: int,
+        remote_ds_name: str,
+        remote_ds_version: int,
+        local_ds_name: str,
+        local_ds_version: int,
         schema: dict[str, Union[SQLType, type[SQLType]]],
         max_threads: int = PULL_DATASET_MAX_THREADS,
     ):
         super().__init__(max_threads)
         self._check_dependencies()
         self.metastore = metastore
         self.warehouse = warehouse
-        self.dataset_name = dataset_name
-        self.dataset_version = dataset_version
+        self.remote_ds_name = remote_ds_name
+        self.remote_ds_version = remote_ds_version
+        self.local_ds_name = local_ds_name
+        self.local_ds_version = local_ds_version
         self.schema = schema
         self.last_status_check: Optional[float] = None
         self.studio_client = StudioClient()

@@ -171,7 +175,7 @@ def check_for_status(self) -> None:
         Checks are done every PULL_DATASET_CHECK_STATUS_INTERVAL seconds
         """
         export_status_response = self.studio_client.dataset_export_status(
-            self.dataset_name, self.dataset_version
+            self.remote_ds_name, self.remote_ds_version

Member commented on the changed lines:

[Q] Why not use uuid instead of remote_ds_name + remote_ds_version & local_ds_name + local_ds_version?

@ilongin (Contributor, Author) commented on Nov 25, 2024:

Yea, there are other APIs that use name + version as well, but I would leave that to a follow-up issue so as not to bloat this one too much.
In general, the user does use an explicit name and version when initiating a pull, e.g. `datachain pull ds://dogs@v3`, but using the uuid in the process after the initial "get dataset info" step would fix the potential issue of someone renaming the remote dataset right in the middle of the pull process (not sure how likely that is to happen, but it's still a potential risk).
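
A minimal sketch of that idea: resolve the name/version to a uuid once, then key every later step off the uuid so a mid-pull rename cannot switch datasets. The helper names here are made up for illustration, not the real StudioClient API:

```python
# Sketch only: pin the dataset identity by uuid at the start of a pull so
# later steps are immune to a concurrent rename on the remote.
# `resolve_remote_version` and `export_rows_by_uuid` are hypothetical names,
# not actual StudioClient methods.
def pull_pinned(studio, warehouse, name: str, version: int) -> None:
    remote = studio.resolve_remote_version(name, version)  # one name-based lookup
    uuid = remote.uuid                                     # identity pinned here
    for part_url in studio.export_rows_by_uuid(uuid):      # uuid-keyed from now on
        warehouse.insert_rows_from_url(part_url)
```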

Member replied:

None of this needs to go into this PR but I wanted to mention it and this seems like an ok place.

I would advocate for hiding local/remote name/version from users as much as possible. (From experience) I think the concept will only confuse users. Only exposing it when we absolutely need to makes sense to me.

For example:

If the user tries to pull cat@v1 under the following circumstances we should throw an error stating there is a dataset with that name already and the uuid doesn't match. Ask them to rename the local dataset (as SaaS is the source of truth):

| name | version | uuid                                 | local | remote |
|------|---------|--------------------------------------|-------|--------|
| cat  | 1       | cf308a6b-0e66-477f-ba76-9c3f8ffd27fe | X     |        |
| cat  | 1       | d1c1ac59-5763-4069-a328-8292ec6a6205 |       | X      |

I think we should update the output of `ls datasets` to have the same format as the above, or at least give it the ability to show these discrepancies somehow.

I am now second-guessing myself and wondering whether the local/remote name and version are that bad.

@ilongin (Contributor, Author) replied on Nov 28, 2024:

That's how it works in this PR:

  1. If a local dataset with the same uuid exists -> skip pulling and print "Local copy of dataset cats@v1 already present". If it has a different name locally, it says "Local copy of dataset cats@v1 already present as dataset other_cats@v5".
  2. If we don't find a local dataset with the same uuid and there is already a local dataset with that name + version holding completely different data, we throw an error with the message "Local dataset cats@v1 already exists, please choose different local dataset name or version".

I didn't want to mention uuid anywhere as I'm not sure yet if that's something we should bother the user with, but maybe I could enhance the message from 2) by mentioning it.

EDIT: Changed it to "Local dataset cats@v1 already exists with different uuid, please choose different local dataset name or version"

         )
         if not export_status_response.ok:
             raise_remote_error(export_status_response.message)

@@ -203,7 +207,7 @@ def do_task(self, urls):

         # metastore and warehouse are not thread safe
         with self.metastore.clone() as metastore, self.warehouse.clone() as warehouse:
-            dataset = metastore.get_dataset(self.dataset_name)
+            local_ds = metastore.get_dataset(self.local_ds_name)

             urls = list(urls)
             while urls:
@@ -227,7 +231,7 @@ def do_task(self, urls):
                 df = df.drop("sys__id", axis=1)

                 inserted = warehouse.insert_dataset_rows(
-                    df, dataset, self.dataset_version
+                    df, local_ds, self.local_ds_version
                 )
                 self.increase_counter(inserted)  # type: ignore [arg-type]
                 urls.remove(url)

@@ -1101,6 +1105,13 @@ def register_dataset(
     def get_dataset(self, name: str) -> DatasetRecord:
         return self.metastore.get_dataset(name)

+    def get_dataset_with_version_uuid(self, uuid: str) -> DatasetRecord:
+        """Returns dataset that contains version with specific uuid"""
+        for dataset in self.ls_datasets():
+            if dataset.has_version_with_uuid(uuid):
+                return self.get_dataset(dataset.name)
+        raise DatasetNotFoundError(f"Dataset with version uuid {uuid} not found.")
+
     def get_remote_dataset(self, name: str) -> DatasetRecord:
         studio_client = StudioClient()

@@ -1316,116 +1327,125 @@ def ls(
         for source in data_sources:  # type: ignore [union-attr]
             yield source, source.ls(fields)

-    def pull_dataset(
+    def pull_dataset(  # noqa: PLR0915
         self,
-        dataset_uri: str,
+        remote_ds_uri: str,
         output: Optional[str] = None,
+        local_ds_name: Optional[str] = None,
+        local_ds_version: Optional[int] = None,
         no_cp: bool = False,
         force: bool = False,
         edatachain: bool = False,
         edatachain_file: Optional[str] = None,
         *,
         client_config=None,
     ) -> None:
         # TODO add progress bar https://github.com/iterative/dvcx/issues/750
         # TODO copy correct remote dates https://github.com/iterative/dvcx/issues/new
         # TODO compare dataset stats on remote vs local pull to assert it's ok
-        def _instantiate_dataset():
+        def _instantiate(ds_uri: str) -> None:
             if no_cp:
                 return
             assert output
             self.cp(
-                [dataset_uri],
+                [ds_uri],
                 output,
                 force=force,
                 no_edatachain_file=not edatachain,
                 edatachain_file=edatachain_file,
                 client_config=client_config,
             )
-            print(f"Dataset {dataset_uri} instantiated locally to {output}")
+            print(f"Dataset {ds_uri} instantiated locally to {output}")

         if not output and not no_cp:
             raise ValueError("Please provide output directory for instantiation")

         client_config = client_config or self.client_config

         studio_client = StudioClient()

         try:
-            remote_dataset_name, version = parse_dataset_uri(dataset_uri)
+            remote_ds_name, version = parse_dataset_uri(remote_ds_uri)
         except Exception as e:
             raise DataChainError("Error when parsing dataset uri") from e

-        dataset = None
-        try:
-            dataset = self.get_dataset(remote_dataset_name)
-        except DatasetNotFoundError:
-            # we will create new one if it doesn't exist
-            pass
-
-        if dataset and version and dataset.has_version(version):
-            """No need to communicate with Studio at all"""
-            dataset_uri = create_dataset_uri(remote_dataset_name, version)
-            print(f"Local copy of dataset {dataset_uri} already present")
-            _instantiate_dataset()
-            return
+        remote_ds = self.get_remote_dataset(remote_ds_name)

-        remote_dataset = self.get_remote_dataset(remote_dataset_name)
-        # if version is not specified in uri, take the latest one
-        if not version:
-            version = remote_dataset.latest_version
-            print(f"Version not specified, pulling the latest one (v{version})")
-        # updating dataset uri with latest version
-        dataset_uri = create_dataset_uri(remote_dataset_name, version)
+        try:
+            # if version is not specified in uri, take the latest one
+            if not version:
+                version = remote_ds.latest_version
+                print(f"Version not specified, pulling the latest one (v{version})")
+            # updating dataset uri with latest version
+            remote_ds_uri = create_dataset_uri(remote_ds_name, version)
+            remote_ds_version = remote_ds.get_version(version)
+        except (DatasetVersionNotFoundError, StopIteration) as exc:
+            raise DataChainError(
+                f"Dataset {remote_ds_name} doesn't have version {version} on server"
+            ) from exc

-        assert version
+        local_ds_name = local_ds_name or remote_ds.name
+        local_ds_version = local_ds_version or remote_ds_version.version
+        local_ds_uri = create_dataset_uri(local_ds_name, local_ds_version)

-        if dataset and dataset.has_version(version):
-            print(f"Local copy of dataset {dataset_uri} already present")
-            _instantiate_dataset()
+        try:
+            # try to find existing dataset with the same uuid to avoid pulling again
+            existing_ds = self.get_dataset_with_version_uuid(remote_ds_version.uuid)
+            existing_ds_version = existing_ds.get_version_by_uuid(
+                remote_ds_version.uuid
+            )
+            existing_ds_uri = create_dataset_uri(
+                existing_ds.name, existing_ds_version.version
+            )
+            if existing_ds_uri == remote_ds_uri:
+                print(f"Local copy of dataset {remote_ds_uri} already present")
+            else:
+                print(
+                    f"Local copy of dataset {remote_ds_uri} already present as"
+                    f" dataset {existing_ds_uri}"
+                )
+            _instantiate(existing_ds_uri)
             return
+        except DatasetNotFoundError:
+            pass

         try:
-            remote_dataset_version = remote_dataset.get_version(version)
-        except (DatasetVersionNotFoundError, StopIteration) as exc:
-            raise DataChainError(
-                f"Dataset {remote_dataset_name} doesn't have version {version}"
-                " on server"
-            ) from exc
+            local_dataset = self.get_dataset(local_ds_name)
+            if local_dataset and local_dataset.has_version(local_ds_version):
+                raise DataChainError(
+                    f"Local dataset {local_ds_uri} already exists with different uuid,"
+                    " please choose different local dataset name or version"
+                )
+        except DatasetNotFoundError:
+            pass

-        stats_response = studio_client.dataset_stats(remote_dataset_name, version)
+        stats_response = studio_client.dataset_stats(
+            remote_ds_name, remote_ds_version.version
+        )
         if not stats_response.ok:
             raise_remote_error(stats_response.message)
-        dataset_stats = stats_response.data
+        ds_stats = stats_response.data

         dataset_save_progress_bar = tqdm(
-            desc=f"Saving dataset {dataset_uri} locally: ",
+            desc=f"Saving dataset {remote_ds_uri} locally: ",
             unit=" rows",
             unit_scale=True,
             unit_divisor=1000,
-            total=dataset_stats.num_objects,  # type: ignore [union-attr]
+            total=ds_stats.num_objects,  # type: ignore [union-attr]
         )

-        schema = DatasetRecord.parse_schema(remote_dataset_version.schema)
+        schema = DatasetRecord.parse_schema(remote_ds_version.schema)

-        columns = tuple(
-            sa.Column(name, typ) for name, typ in schema.items() if name != "sys__id"
-        )
         # creating new dataset (version) locally
-        dataset = self.create_dataset(
-            remote_dataset_name,
-            version,
-            query_script=remote_dataset_version.query_script,
+        local_ds = self.create_dataset(
+            local_ds_name,
+            local_ds_version,
+            query_script=remote_ds_version.query_script,
             create_rows=True,
-            columns=columns,
-            feature_schema=remote_dataset_version.feature_schema,
+            columns=tuple(sa.Column(n, t) for n, t in schema.items() if n != "sys__id"),
+            feature_schema=remote_ds_version.feature_schema,
             validate_version=False,
-            uuid=remote_dataset_version.uuid,
+            uuid=remote_ds_version.uuid,
         )

         # asking remote to export dataset rows table to s3 and to return signed
         # urls of exported parts, which are in parquet format
         export_response = studio_client.export_dataset_table(
-            remote_dataset_name, version
+            remote_ds_name, remote_ds_version.version
         )
         if not export_response.ok:
             raise_remote_error(export_response.message)

@@ -1442,8 +1462,10 @@ def _instantiate_dataset():
         rows_fetcher = DatasetRowsFetcher(
             metastore,
             warehouse,
-            dataset.name,
-            version,
+            remote_ds_name,
+            remote_ds_version.version,
+            local_ds_name,
+            local_ds_version,
             schema,
         )
         try:
@@ -1455,23 +1477,23 @@
                 dataset_save_progress_bar,
             )
         except:
-            self.remove_dataset(dataset.name, version)
+            self.remove_dataset(local_ds_name, local_ds_version)
             raise

-        dataset = self.metastore.update_dataset_status(
-            dataset,
+        local_ds = self.metastore.update_dataset_status(
+            local_ds,
             DatasetStatus.COMPLETE,
-            version=version,
-            error_message=remote_dataset.error_message,
-            error_stack=remote_dataset.error_stack,
-            script_output=remote_dataset.error_stack,
+            version=local_ds_version,
+            error_message=remote_ds.error_message,
+            error_stack=remote_ds.error_stack,
+            script_output=remote_ds.error_stack,
         )
-        self.update_dataset_version_with_warehouse_info(dataset, version)
+        self.update_dataset_version_with_warehouse_info(local_ds, local_ds_version)

         dataset_save_progress_bar.close()
-        print(f"Dataset {dataset_uri} saved locally")
+        print(f"Dataset {remote_ds_uri} saved locally")

-        _instantiate_dataset()
+        _instantiate(local_ds_uri)

     def clone(
         self,
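
To make the reuse-by-uuid behavior above concrete, here is a self-contained sketch of the check `pull_dataset` now performs. The classes are simplified stand-ins, not the real DatasetRecord/DatasetVersion, and the uuid value is made up:

```python
from typing import Optional

# Simplified stand-ins for DatasetRecord/DatasetVersion, just to illustrate
# the uuid-keyed reuse check that pull_dataset now performs.
class Version:
    def __init__(self, version: int, uuid: str):
        self.version = version
        self.uuid = uuid

class Dataset:
    def __init__(self, name: str, versions: list[Version]):
        self.name = name
        self.versions = versions

def find_by_uuid(datasets: list[Dataset], uuid: str) -> Optional[str]:
    """Return 'name@vN' for the local version holding `uuid`, if any."""
    for ds in datasets:
        for v in ds.versions:
            if v.uuid == uuid:
                return f"{ds.name}@v{v.version}"
    return None

# Mirrors the message from the discussion: a version pulled earlier under
# another name is reused instead of being pulled again.
local = [Dataset("other_cats", [Version(5, "cf308a6b")])]
hit = find_by_uuid(local, "cf308a6b")
print(f"already present as dataset {hit}" if hit else "pulling from remote")
```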

src/datachain/cli.py (14 additions, 0 deletions)

@@ -400,6 +400,18 @@ def get_parser() -> ArgumentParser:  # noqa: PLR0915
         "--edatachain-file",
         help="Use a different filename for the resulting .edatachain file",
     )
+    parse_pull.add_argument(

Member commented on this line:

[C] I still don't know if it is better to expose these or the uuid to the user.

@ilongin (Contributor, Author) replied:

I think we should let the client work with name / version. uuid is not really practical, and the user also cannot really set it.

"--local-name",
action="store",
default=None,
help="Name of the local dataset",
)
parse_pull.add_argument(
"--local-version",
action="store",
default=None,
help="Version of the local dataset",
)

     parse_edit_dataset = subp.add_parser(
         "edit-dataset", parents=[parent_parser], description="Edit dataset metadata"

@@ -1207,6 +1219,8 @@ def main(argv: Optional[list[str]] = None) -> int:  # noqa: C901, PLR0912, PLR09
         catalog.pull_dataset(
             args.dataset,
             args.output,
+            local_ds_name=args.local_name,
+            local_ds_version=args.local_version,
             no_cp=args.no_cp,
             force=bool(args.force),
             edatachain=args.edatachain,
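
Taken together, these flags let a pull land under a different local name and version than the remote ones. A hypothetical invocation (assuming the `pull` subcommand takes the dataset URI and output directory positionally, as the `pull_dataset(args.dataset, args.output, ...)` call above suggests): `datachain pull ds://dogs@v3 ./output --local-name my_dogs --local-version 1`.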

src/datachain/dataset.py (15 additions, 0 deletions)

@@ -488,6 +488,18 @@
             if v.version == version
         )

+    def get_version_by_uuid(self, uuid: str) -> DatasetVersion:
+        try:
+            return next(
+                v
+                for v in self.versions  # type: ignore [union-attr]
+                if v.uuid == uuid
+            )
+        except StopIteration:
+            raise DatasetVersionNotFoundError(
+                f"Dataset {self.name} does not have version with uuid {uuid}"
+            ) from None
+
     def remove_version(self, version: int) -> None:
         if not self.versions or not self.has_version(version):
             return

@@ -635,6 +647,9 @@
             LISTING_PREFIX
         )

+    def has_version_with_uuid(self, uuid: str) -> bool:
+        return any(v.uuid == uuid for v in self.versions)
+

 class RowDict(dict):
     pass
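
A small usage sketch of the two helpers added to this file (hypothetical: `record` is assumed to be a DatasetRecord obtained from the catalog, not constructed here):

```python
# Hypothetical helper built on the methods added above; `record` is assumed
# to be a DatasetRecord, e.g. the result of Catalog.get_dataset(name).
def version_number_for_uuid(record, uuid: str):
    """Return the version number that holds `uuid`, or None if absent."""
    if record.has_version_with_uuid(uuid):
        return record.get_version_by_uuid(uuid).version
    return None
```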