-
Notifications
You must be signed in to change notification settings - Fork 100
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Using version uuid
in datachain pull
#621
Changes from 2 commits
b5afbb2
dd6e125
12ae692
7ef4352
21b62b5
80b64ce
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -126,17 +126,21 @@ def __init__( | |
self, | ||
metastore: "AbstractMetastore", | ||
warehouse: "AbstractWarehouse", | ||
dataset_name: str, | ||
dataset_version: int, | ||
remote_ds_name: str, | ||
remote_ds_version: int, | ||
local_ds_name: str, | ||
local_ds_version: int, | ||
schema: dict[str, Union[SQLType, type[SQLType]]], | ||
max_threads: int = PULL_DATASET_MAX_THREADS, | ||
): | ||
super().__init__(max_threads) | ||
self._check_dependencies() | ||
self.metastore = metastore | ||
self.warehouse = warehouse | ||
self.dataset_name = dataset_name | ||
self.dataset_version = dataset_version | ||
self.remote_ds_name = remote_ds_name | ||
self.remote_ds_version = remote_ds_version | ||
self.local_ds_name = local_ds_name | ||
self.local_ds_version = local_ds_version | ||
self.schema = schema | ||
self.last_status_check: Optional[float] = None | ||
self.studio_client = StudioClient() | ||
|
@@ -170,7 +174,7 @@ def check_for_status(self) -> None: | |
Checks are done every PULL_DATASET_CHECK_STATUS_INTERVAL seconds | ||
""" | ||
export_status_response = self.studio_client.dataset_export_status( | ||
self.dataset_name, self.dataset_version | ||
self.remote_ds_name, self.remote_ds_version | ||
) | ||
if not export_status_response.ok: | ||
raise_remote_error(export_status_response.message) | ||
|
@@ -202,7 +206,7 @@ def do_task(self, urls): | |
|
||
# metastore and warehouse are not thread safe | ||
with self.metastore.clone() as metastore, self.warehouse.clone() as warehouse: | ||
dataset = metastore.get_dataset(self.dataset_name) | ||
local_ds = metastore.get_dataset(self.local_ds_name) | ||
|
||
urls = list(urls) | ||
while urls: | ||
|
@@ -226,7 +230,7 @@ def do_task(self, urls): | |
df = df.drop("sys__id", axis=1) | ||
|
||
inserted = warehouse.insert_dataset_rows( | ||
df, dataset, self.dataset_version | ||
df, local_ds, self.local_ds_version | ||
) | ||
self.increase_counter(inserted) # type: ignore [arg-type] | ||
urls.remove(url) | ||
|
@@ -1098,6 +1102,13 @@ def register_dataset( | |
def get_dataset(self, name: str) -> DatasetRecord: | ||
return self.metastore.get_dataset(name) | ||
|
||
def get_dataset_with_uuid(self, uuid: str) -> DatasetRecord: | ||
"""Returns dataset that contains version with specific uuid""" | ||
for dataset in self.ls_datasets(): | ||
if dataset.has_version_with_uuid(uuid): | ||
return dataset | ||
raise DatasetNotFoundError(f"Dataset with version uuid {uuid} not found.") | ||
|
||
def get_remote_dataset(self, name: str) -> DatasetRecord: | ||
studio_client = StudioClient() | ||
|
||
|
@@ -1313,116 +1324,125 @@ def ls( | |
for source in data_sources: # type: ignore [union-attr] | ||
yield source, source.ls(fields) | ||
|
||
def pull_dataset( | ||
def pull_dataset( # noqa: PLR0915 | ||
self, | ||
dataset_uri: str, | ||
remote_ds_uri: str, | ||
output: Optional[str] = None, | ||
local_ds_name: Optional[str] = None, | ||
local_ds_version: Optional[int] = None, | ||
no_cp: bool = False, | ||
force: bool = False, | ||
edatachain: bool = False, | ||
edatachain_file: Optional[str] = None, | ||
*, | ||
client_config=None, | ||
) -> None: | ||
# TODO add progress bar https://github.com/iterative/dvcx/issues/750 | ||
# TODO copy correct remote dates https://github.com/iterative/dvcx/issues/new | ||
# TODO compare dataset stats on remote vs local pull to assert it's ok | ||
def _instantiate_dataset(): | ||
def _instantiate(ds_uri: str) -> None: | ||
if no_cp: | ||
return | ||
assert output | ||
self.cp( | ||
[dataset_uri], | ||
[ds_uri], | ||
output, | ||
force=force, | ||
no_edatachain_file=not edatachain, | ||
edatachain_file=edatachain_file, | ||
client_config=client_config, | ||
) | ||
print(f"Dataset {dataset_uri} instantiated locally to {output}") | ||
print(f"Dataset {ds_uri} instantiated locally to {output}") | ||
|
||
if not output and not no_cp: | ||
raise ValueError("Please provide output directory for instantiation") | ||
|
||
client_config = client_config or self.client_config | ||
|
||
studio_client = StudioClient() | ||
|
||
try: | ||
remote_dataset_name, version = parse_dataset_uri(dataset_uri) | ||
remote_ds_name, version = parse_dataset_uri(remote_ds_uri) | ||
except Exception as e: | ||
raise DataChainError("Error when parsing dataset uri") from e | ||
|
||
dataset = None | ||
try: | ||
dataset = self.get_dataset(remote_dataset_name) | ||
except DatasetNotFoundError: | ||
# we will create new one if it doesn't exist | ||
pass | ||
|
||
if dataset and version and dataset.has_version(version): | ||
"""No need to communicate with Studio at all""" | ||
dataset_uri = create_dataset_uri(remote_dataset_name, version) | ||
print(f"Local copy of dataset {dataset_uri} already present") | ||
_instantiate_dataset() | ||
return | ||
remote_ds = self.get_remote_dataset(remote_ds_name) | ||
|
||
remote_dataset = self.get_remote_dataset(remote_dataset_name) | ||
# if version is not specified in uri, take the latest one | ||
if not version: | ||
version = remote_dataset.latest_version | ||
print(f"Version not specified, pulling the latest one (v{version})") | ||
# updating dataset uri with latest version | ||
dataset_uri = create_dataset_uri(remote_dataset_name, version) | ||
try: | ||
# if version is not specified in uri, take the latest one | ||
if not version: | ||
version = remote_ds.latest_version | ||
print(f"Version not specified, pulling the latest one (v{version})") | ||
# updating dataset uri with latest version | ||
remote_ds_uri = create_dataset_uri(remote_ds_name, version) | ||
remote_ds_version = remote_ds.get_version(version) | ||
except (DatasetVersionNotFoundError, StopIteration) as exc: | ||
raise DataChainError( | ||
f"Dataset {remote_ds_name} doesn't have version {version}" " on server" | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. [C] Should be a single string |
||
) from exc | ||
|
||
assert version | ||
local_ds_name = local_ds_name or remote_ds.name | ||
local_ds_version = local_ds_version or remote_ds_version.version | ||
local_ds_uri = create_dataset_uri(local_ds_name, local_ds_version) | ||
|
||
if dataset and dataset.has_version(version): | ||
print(f"Local copy of dataset {dataset_uri} already present") | ||
_instantiate_dataset() | ||
try: | ||
# try to find existing dataset with the same uuid to avoid pulling again | ||
existing_ds = self.get_dataset_with_uuid(remote_ds_version.uuid) | ||
existing_ds_version = existing_ds.get_version_with_uuid( | ||
remote_ds_version.uuid | ||
) | ||
existing_ds_uri = create_dataset_uri( | ||
existing_ds.name, existing_ds_version.version | ||
) | ||
if existing_ds_uri == remote_ds_uri: | ||
print(f"Local copy of dataset {remote_ds_uri} already present") | ||
else: | ||
print( | ||
f"Local copy of dataset {remote_ds_uri} already present as" | ||
f" dataset {existing_ds_uri}" | ||
) | ||
_instantiate(existing_ds_uri) | ||
return | ||
except DatasetNotFoundError: | ||
pass | ||
|
||
try: | ||
remote_dataset_version = remote_dataset.get_version(version) | ||
except (DatasetVersionNotFoundError, StopIteration) as exc: | ||
raise DataChainError( | ||
f"Dataset {remote_dataset_name} doesn't have version {version}" | ||
" on server" | ||
) from exc | ||
local_dataset = self.get_dataset(local_ds_name) | ||
if local_dataset and local_dataset.has_version(local_ds_version): | ||
raise DataChainError( | ||
f"Local dataset {local_ds_uri} already exists, please" | ||
" choose different local dataset name or version" | ||
) | ||
except DatasetNotFoundError: | ||
pass | ||
|
||
stats_response = studio_client.dataset_stats(remote_dataset_name, version) | ||
stats_response = studio_client.dataset_stats( | ||
remote_ds_name, remote_ds_version.version | ||
) | ||
if not stats_response.ok: | ||
raise_remote_error(stats_response.message) | ||
dataset_stats = stats_response.data | ||
ds_stats = stats_response.data | ||
|
||
dataset_save_progress_bar = tqdm( | ||
desc=f"Saving dataset {dataset_uri} locally: ", | ||
desc=f"Saving dataset {remote_ds_uri} locally: ", | ||
unit=" rows", | ||
unit_scale=True, | ||
unit_divisor=1000, | ||
total=dataset_stats.num_objects, # type: ignore [union-attr] | ||
total=ds_stats.num_objects, # type: ignore [union-attr] | ||
) | ||
|
||
schema = DatasetRecord.parse_schema(remote_dataset_version.schema) | ||
schema = DatasetRecord.parse_schema(remote_ds_version.schema) | ||
|
||
columns = tuple( | ||
sa.Column(name, typ) for name, typ in schema.items() if name != "sys__id" | ||
) | ||
# creating new dataset (version) locally | ||
dataset = self.create_dataset( | ||
remote_dataset_name, | ||
version, | ||
query_script=remote_dataset_version.query_script, | ||
local_ds = self.create_dataset( | ||
local_ds_name, | ||
local_ds_version, | ||
query_script=remote_ds_version.query_script, | ||
create_rows=True, | ||
columns=columns, | ||
feature_schema=remote_dataset_version.feature_schema, | ||
columns=tuple(sa.Column(n, t) for n, t in schema.items() if n != "sys__id"), | ||
feature_schema=remote_ds_version.feature_schema, | ||
validate_version=False, | ||
uuid=remote_dataset_version.uuid, | ||
uuid=remote_ds_version.uuid, | ||
) | ||
|
||
# asking remote to export dataset rows table to s3 and to return signed | ||
# urls of exported parts, which are in parquet format | ||
export_response = studio_client.export_dataset_table( | ||
remote_dataset_name, version | ||
remote_ds_name, remote_ds_version.version | ||
) | ||
if not export_response.ok: | ||
raise_remote_error(export_response.message) | ||
|
@@ -1439,8 +1459,10 @@ def _instantiate_dataset(): | |
rows_fetcher = DatasetRowsFetcher( | ||
metastore, | ||
warehouse, | ||
dataset.name, | ||
version, | ||
remote_ds_name, | ||
remote_ds_version.version, | ||
local_ds_name, | ||
local_ds_version, | ||
schema, | ||
) | ||
try: | ||
|
@@ -1452,23 +1474,23 @@ def _instantiate_dataset(): | |
dataset_save_progress_bar, | ||
) | ||
except: | ||
self.remove_dataset(dataset.name, version) | ||
self.remove_dataset(local_ds_name, local_ds_version) | ||
raise | ||
|
||
dataset = self.metastore.update_dataset_status( | ||
dataset, | ||
local_ds = self.metastore.update_dataset_status( | ||
local_ds, | ||
DatasetStatus.COMPLETE, | ||
version=version, | ||
error_message=remote_dataset.error_message, | ||
error_stack=remote_dataset.error_stack, | ||
script_output=remote_dataset.error_stack, | ||
version=local_ds_version, | ||
error_message=remote_ds.error_message, | ||
error_stack=remote_ds.error_stack, | ||
script_output=remote_ds.error_stack, | ||
) | ||
self.update_dataset_version_with_warehouse_info(dataset, version) | ||
self.update_dataset_version_with_warehouse_info(local_ds, local_ds_version) | ||
|
||
dataset_save_progress_bar.close() | ||
print(f"Dataset {dataset_uri} saved locally") | ||
print(f"Dataset {remote_ds_uri} saved locally") | ||
|
||
_instantiate_dataset() | ||
_instantiate(local_ds_uri) | ||
|
||
def clone( | ||
self, | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -395,6 +395,18 @@ def get_parser() -> ArgumentParser: # noqa: PLR0915 | |
"--edatachain-file", | ||
help="Use a different filename for the resulting .edatachain file", | ||
) | ||
parse_pull.add_argument( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. [C] I still don't know if it is better to expose these or uuid to the user There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think we should let client work with name / version. |
||
"--local-name", | ||
action="store", | ||
default=None, | ||
help="Name of the local dataset", | ||
) | ||
parse_pull.add_argument( | ||
"--local-version", | ||
action="store", | ||
default=None, | ||
help="Version of the local dataset", | ||
) | ||
|
||
parse_edit_dataset = subp.add_parser( | ||
"edit-dataset", parents=[parent_parser], description="Edit dataset metadata" | ||
|
@@ -1121,6 +1133,8 @@ def main(argv: Optional[list[str]] = None) -> int: # noqa: C901, PLR0912, PLR09 | |
catalog.pull_dataset( | ||
args.dataset, | ||
args.output, | ||
local_ds_name=args.local_name, | ||
local_ds_version=args.local_version, | ||
no_cp=args.no_cp, | ||
force=bool(args.force), | ||
edatachain=args.edatachain, | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[Q] Why not use
uuid
instead ofremote_ds_name
+remote_ds_version
&local_ds_name
+local_ds_version
?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yea, there are other APIs that use
name
+version
as well, but I would leave that to followup issue not to bloat this one too much.In general, user does use explicit name and version when initiating pull, e.g
datachain pull ds://dogs@v3
but the use ofuuid
in the process after initial get dataset info would fix potential issue of someone renaming remote dataset right in the middle of pull process (not sue how likely is this to happen but it's still the potential risk)There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
None of this needs to go into this PR but I wanted to mention it and this seems like an ok place.
I would advocate for hiding local/remote name/version from users as much as possible. (From experience) I think the concept will only confuse users. Only exposing it when we absolutely need to makes sense to me
For example:
If the user tries to pull
cat@v1
under the following circumstances we should throw an error stating there is a dataset with that name already and the uuid doesn't match. Ask them to rename the local dataset (as SaaS is the source of truth):I think we should update the output of
ls datasets
to have the same format as the above or at least give it the ability to show these discrepancies somehow.I am now second-guessing myself and wondering whether the local/remote name and version are that bad.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
That's how it works in this PR:
uuid
exists -> skip pulling and printing "Local copy of dataset cats@v1 already present". If it has different name locally it says "Local copy of dataset cats@v1 already present as dataset other_cats@v5"uuid
and there is already local dataset name + version present that holds completely different data we throw error with message "Local dataset cats@v1 already exists, please choose different local dataset name or version"I didn't want to mention
uuid
anywhere as I'm not sure yet if that's something we should bother user with, but maybe I could enhance message from 2) by mentioning it ... "EDIT: Changed it to "Local dataset cats@v1 already exists with different uuid, please choose different local dataset name or version"