Skip to content

Commit

Permalink
CLI improvements
Browse files Browse the repository at this point in the history
  • Loading branch information
jjmachan committed Aug 14, 2020
1 parent afd01e0 commit ccea282
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 7 deletions.
24 changes: 18 additions & 6 deletions stockroom/cli.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,16 @@
from contextlib import ExitStack
from pathlib import Path
from contextlib import ExitStack

import click
from click_didyoumean import DYMGroup # type: ignore
from hangar import Repository
from stockroom import __version__, external
from stockroom.core import StockRoom
from rich.progress import Progress

from stockroom.keeper import init_repo
from stockroom.core import StockRoom
from stockroom.utils import console, new_columns_table
from stockroom import __version__
from stockroom import external


@click.group(
Expand Down Expand Up @@ -111,24 +115,32 @@ def import_data(dataset_name, download_dir):
co = stock_obj.accessor
importers = external.get_importers(dataset_name, download_dir)
total_len = sum([len(importer) for importer in importers])
with click.progressbar(label="Adding data to StockRoom", length=total_len) as bar:
splits_added = {}

with Progress() as progress:
stock_add_bar = progress.add_task("Adding to Stockroom: ", total=total_len)
for importer in importers:
column_names = importer.column_names()
dtypes = importer.dtypes()
shapes = importer.shapes()
splits_added[importer.split] = (column_names, len(importer))

for colname, dtype, shape in zip(column_names, dtypes, shapes):
if colname not in co.keys():
# TODO: this assuming importer always return a numpy flat array
co.add_ndarray_column(colname, dtype=dtype, shape=shape)

columns = [co[name] for name in column_names]
with ExitStack() as stack:
for col in columns:
stack.enter_context(col)
for i, data in enumerate(importer):
bar.update(1)
progress.advance(stock_add_bar)
for col, dt in zip(columns, data):
# TODO: use the keys from importer
col[i] = dt

stock_obj.commit(f"Data from {dataset_name} added through stock import")
stock_obj.close()
click.echo(f"The {dataset_name} dataset has been added to StockRoom")
click.echo(f"The {dataset_name} dataset has been added to StockRoom.")
console.print(new_columns_table(splits_added))
36 changes: 35 additions & 1 deletion stockroom/utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,41 @@
import importlib
import types
import importlib
from pathlib import Path

from rich.console import Console
from rich.table import Table
from rich import box


# init console object to be used throught.
console = Console()


def new_columns_table(splits_added: dict) -> Table:
"""
Builds a Rich Table with the infor about the new columns created.
Parameters
----------
splits_added : dict containing the column_names and length of each split
Returns
-------
Table
The final generated table ready to be displayed
"""
table = Table(box=box.MINIMAL)

table.add_column("Split [len]", no_wrap=True, justify="right", style="bold green")
table.add_column("Column Names")

for split in splits_added:
column_names, num_samples = splits_added[split]
table.add_row(split + f" [{num_samples}]", ", ".join(column_names))

return table


def get_stock_root(path: Path) -> Path:
"""
Expand Down

0 comments on commit ccea282

Please sign in to comment.