diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 5108ef8..e2771ec 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -23,7 +23,7 @@ jobs: python-version: ${{ matrix.python-version }} - name: Install dependencies run: | - pip install pytest pytest-cov mypy black isort torch + pip install pytest pytest-cov mypy black isort==5.1.4 torch pip install . - name: Test run: bash scripts/test.sh diff --git a/docs/requirements.txt b/docs/requirements.txt index 1f44dfd..5e219c4 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -5,3 +5,4 @@ mkdocs-material-extensions==1.0 Pygments==2.6.1 pymdown-extensions==7.1 mkdocs-click +rich diff --git a/setup.py b/setup.py index 9e2c41a..410de44 100644 --- a/setup.py +++ b/setup.py @@ -1,11 +1,11 @@ #!/usr/bin/env python -from setuptools import setup, find_packages +from setuptools import find_packages, setup with open('README.md') as readme_file: readme = readme_file.read().split('')[0] -requirements = ['Click>=7.0', 'click_didyoumean', 'hangar>=0.5.0'] +requirements = ['Click>=7.0', 'click_didyoumean', 'hangar>=0.5.0', 'rich'] setup( author="Sherin Thomas", diff --git a/stockroom/cli.py b/stockroom/cli.py index 9f13ce1..bca16db 100644 --- a/stockroom/cli.py +++ b/stockroom/cli.py @@ -4,9 +4,11 @@ import click from click_didyoumean import DYMGroup # type: ignore from hangar import Repository +from rich.progress import Progress from stockroom import __version__, external from stockroom.core import StockRoom from stockroom.keeper import init_repo +from stockroom.utils import print_columns_added @click.group( @@ -111,24 +113,32 @@ def import_data(dataset_name, download_dir): co = stock_obj.accessor importers = external.get_importers(dataset_name, download_dir) total_len = sum([len(importer) for importer in importers]) - with click.progressbar(label="Adding data to StockRoom", length=total_len) as bar: + splits_added = {} + + with Progress() as progress: + stock_add_bar = progress.add_task("Adding to Stockroom: ", total=total_len) for importer in importers: column_names = importer.column_names() dtypes = importer.dtypes() shapes = importer.shapes() + splits_added[importer.split] = (column_names, len(importer)) + for colname, dtype, shape in zip(column_names, dtypes, shapes): if colname not in co.keys(): # TODO: this assuming importer always return a numpy flat array co.add_ndarray_column(colname, dtype=dtype, shape=shape) + columns = [co[name] for name in column_names] with ExitStack() as stack: for col in columns: stack.enter_context(col) for i, data in enumerate(importer): - bar.update(1) + progress.advance(stock_add_bar) for col, dt in zip(columns, data): # TODO: use the keys from importer col[i] = dt + stock_obj.commit(f"Data from {dataset_name} added through stock import") stock_obj.close() - click.echo(f"The {dataset_name} dataset has been added to StockRoom") + click.echo(f"The {dataset_name} dataset has been added to StockRoom.") + print_columns_added(splits_added) diff --git a/stockroom/external/importer/torchvision_importers.py b/stockroom/external/importer/torchvision_importers.py index d4e1342..dde1bca 100644 --- a/stockroom/external/importer/torchvision_importers.py +++ b/stockroom/external/importer/torchvision_importers.py @@ -1,11 +1,12 @@ import os +import numpy as np +from stockroom.external.importer.base import BaseImporter + try: from torchvision import datasets # type: ignore except ModuleNotFoundError: pass -import numpy as np -from stockroom.external.importer.base import BaseImporter class TorchvisionCommon(BaseImporter): diff --git a/stockroom/utils.py b/stockroom/utils.py index 406cab9..0f22823 100644 --- a/stockroom/utils.py +++ b/stockroom/utils.py @@ -2,6 +2,39 @@ import types from pathlib import Path +from rich import box +from rich.console import Console +from rich.table import Table + +# init console object to be used throught. +console = Console() + + +def print_columns_added(splits_added: dict): + """ + Builds a Rich Table with the infor about the new columns created. + + Parameters + ---------- + splits_added : dict containing the column_names and length of each split + + Returns + ------- + Table + The final generated table ready to be displayed + + """ + table = Table(box=box.MINIMAL) + + table.add_column("Split [len]", no_wrap=True, justify="right", style="bold green") + table.add_column("Column Names") + + for split in splits_added: + column_names, num_samples = splits_added[split] + table.add_row(split + f" [{num_samples}]", ", ".join(column_names)) + + console.print(table) + def get_stock_root(path: Path) -> Path: """