Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make artifact downloads more robust #41

Merged
merged 3 commits into from
Jul 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions requirements-dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -8,5 +8,6 @@ black==23.1.0
pre-commit==3.4.0
virtualenv==20.26.3
ipdb==0.13.11
tenacity==8.5.0
# https://gitlab-master.nvidia.com/clara-discovery/infra-bionemo
#infra-bionemo==0.3.1
11 changes: 11 additions & 0 deletions scripts/artifact_paths.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -26,18 +26,24 @@ models:
symlink:
source: "../../geneformer-10M-240530-step-115430-wandb-4ij9ghox.nemo"
target: "singlecell/geneformer/geneformer-10M-240530.nemo"
md5sum: 375ebb9431419f4936fa3aa2bce6e7d6 # pragma: allowlist secret

geneformer:
# A QA model for geneformer with randomly initialized weights.
pbss: "s3://bionemo-ci/models/geneformer-qa.nemo"
symlink:
source: "../../geneformer-qa.nemo"
target: "singlecell/geneformer/geneformer-qa.nemo"
md5sum: 349a5ca34f19969dead3f85978e6c65d # pragma: allowlist secret

esm2nv_650m:
ngc: "nvidia/clara/esm2nv650m:1.0"
pbss: "s3://bionemo-ci/models/esm2nv_650M_converted.nemo"
symlink:
source: "../../esm2nv_650M_converted.nemo"
target: "protein/esm2nv/esm2nv_650M_converted.nemo"
md5sum: f1d926c4ed38ce16be962c79459c4abf # pragma: allowlist secret

# TODO (@fahadrgh) uncomment when we add hf -> nemo conversion test
# esm2nv_3b:
# ngc: "nvidia/clara/esm2nv3b:1.0"
Expand All @@ -49,11 +55,16 @@ data:
single_cell:
pbss: "s3://bionemo-ci/test-data/singlecell/singlecell-testdata-20240506.tar.gz"
relative_download_dir: "test_data/"
md5sum: dd6b0d791bf2b3301d9793a1d6663c75 # pragma: allowlist secret

single_cell_nemo1_geneformer_per_layer_outputs:
unpack: false
pbss: "s3://bionemo-ci/test-data/singlecell/nemo1-test-outputs-geneformer-qa.pt"
relative_download_dir: "test_data/"
md5sum: 5115c1b50998f62d26383d2cf87f597d # pragma: allowlist secret

single_cell_nemo1_geneformer_golden_vals:
unpack: false
pbss: "s3://bionemo-ci/test-data/singlecell/nemo1_geneformer_qa_test_golden_values.pt"
relative_download_dir: "test_data/"
md5sum: 3e9e95b4ff05dfe825d5b850c3de1ee3 # pragma: allowlist secret
47 changes: 44 additions & 3 deletions scripts/download_artifacts.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
"""Script to download pretrained models from NGC or PBSS."""

import argparse
import hashlib
import logging
import os
import sys
Expand All @@ -26,6 +27,7 @@

import yaml
from pydantic import BaseModel
from tenacity import retry, retry_if_exception_type, wait_exponential


ALL_KEYWORD = "all"
Expand All @@ -51,6 +53,7 @@ class ArtifactConfig(BaseModel):
extra_args: Optional[str] = None
untar_dir: Optional[str] = None
unpack: bool = True
md5sum: str


class Config(BaseModel):
Expand Down Expand Up @@ -228,9 +231,8 @@ def download_artifacts(
extra_args = conf[download_artifact].extra_args
command = f"{command} {extra_args}"

_, stderr, retcode = streamed_subprocess_call(command, stream_stdout)
if retcode != 0:
raise ValueError(f"Failed to download {download_artifact=}! {stderr=}")
execute_download(stream_stdout, conf, download_artifact, complete_download_dir, command, file_name)

if artifact_type == "data":
unpack: bool = getattr(conf[download_artifact], "unpack", True)
if unpack:
Expand All @@ -256,6 +258,29 @@ def download_artifacts(
raise ValueError(f"Failed to symlink {source_file=} to {target_file=}; {stderr=}")


@retry(wait=wait_exponential(multiplier=1, max=10), retry=retry_if_exception_type(ValueError))
def execute_download(
stream_stdout: bool,
conf: Dict[str, ArtifactConfig],
download_artifact: str,
complete_download_dir: Path,
command: List[str],
file_name: str,
) -> None:
"""Execute the download command and check the MD5 checksum of the downloaded file."""

_, stderr, retcode = streamed_subprocess_call(command, stream_stdout)
if retcode != 0:
raise ValueError(f"Failed to download {download_artifact=}! {stderr=}")

downloaded_md5sum = _md5_checksum(Path(complete_download_dir) / file_name)
if downloaded_md5sum != conf[download_artifact].md5sum:
raise ValueError(
f"MD5 checksum mismatch for {download_artifact=}! Expected "
f"{conf[download_artifact].md5sum}, got {downloaded_md5sum}"
)


def load_config(config_file: Path = DATA_SOURCE_CONFIG) -> Config:
"""Loads the artifacts file into a dictionary.

Expand Down Expand Up @@ -333,5 +358,21 @@ def main():
logging.warning("No models or data were selected to download.")


def _md5_checksum(file_path: Path) -> str:
"""Calculate the MD5 checksum of a file.

Args:
file_path (Path): The path to the file to checksum.

Returns:
str: The MD5 checksum of the file.
"""
md5 = hashlib.md5()
with open(file_path, "rb") as f:
for chunk in iter(lambda: f.read(4096), b""):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Neat trick :) TIL

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

github copilot 😄 , I can't take too much credit

md5.update(chunk)
return md5.hexdigest()


if __name__ == "__main__":
main()