Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

CLI improvements #32

Merged
merged 6 commits into from
Aug 21, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ jobs:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
run: |
pip install pytest pytest-cov mypy black isort torch
pip install pytest pytest-cov mypy black isort==5.1.4 torch
pip install .
- name: Test
run: bash scripts/test.sh
Expand Down
1 change: 1 addition & 0 deletions docs/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,4 @@ mkdocs-material-extensions==1.0
Pygments==2.6.1
pymdown-extensions==7.1
mkdocs-click
rich
4 changes: 2 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
#!/usr/bin/env python

from setuptools import setup, find_packages
from setuptools import find_packages, setup

with open('README.md') as readme_file:
readme = readme_file.read().split('<!--- marker-for-pypi-to-trim --->')[0]

requirements = ['Click>=7.0', 'click_didyoumean', 'hangar>=0.5.0']
requirements = ['Click>=7.0', 'click_didyoumean', 'hangar>=0.5.0', 'rich']

setup(
author="Sherin Thomas",
Expand Down
16 changes: 13 additions & 3 deletions stockroom/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,11 @@
import click
from click_didyoumean import DYMGroup # type: ignore
from hangar import Repository
from rich.progress import Progress
from stockroom import __version__, external
from stockroom.core import StockRoom
from stockroom.keeper import init_repo
from stockroom.utils import print_columns_added


@click.group(
Expand Down Expand Up @@ -111,24 +113,32 @@ def import_data(dataset_name, download_dir):
co = stock_obj.accessor
importers = external.get_importers(dataset_name, download_dir)
total_len = sum([len(importer) for importer in importers])
with click.progressbar(label="Adding data to StockRoom", length=total_len) as bar:
splits_added = {}

with Progress() as progress:
stock_add_bar = progress.add_task("Adding to Stockroom: ", total=total_len)
for importer in importers:
column_names = importer.column_names()
dtypes = importer.dtypes()
shapes = importer.shapes()
splits_added[importer.split] = (column_names, len(importer))

for colname, dtype, shape in zip(column_names, dtypes, shapes):
if colname not in co.keys():
# TODO: this assuming importer always return a numpy flat array
co.add_ndarray_column(colname, dtype=dtype, shape=shape)

columns = [co[name] for name in column_names]
with ExitStack() as stack:
for col in columns:
stack.enter_context(col)
for i, data in enumerate(importer):
bar.update(1)
progress.advance(stock_add_bar)
for col, dt in zip(columns, data):
# TODO: use the keys from importer
col[i] = dt

stock_obj.commit(f"Data from {dataset_name} added through stock import")
stock_obj.close()
click.echo(f"The {dataset_name} dataset has been added to StockRoom")
click.echo(f"The {dataset_name} dataset has been added to StockRoom.")
print_columns_added(splits_added)
5 changes: 3 additions & 2 deletions stockroom/external/importer/torchvision_importers.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
import os

import numpy as np
from stockroom.external.importer.base import BaseImporter

try:
from torchvision import datasets # type: ignore
except ModuleNotFoundError:
pass
import numpy as np
from stockroom.external.importer.base import BaseImporter


class TorchvisionCommon(BaseImporter):
Expand Down
33 changes: 33 additions & 0 deletions stockroom/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,39 @@
import types
from pathlib import Path

from rich import box
from rich.console import Console
from rich.table import Table

# init console object to be used throught.
console = Console()
hhsecond marked this conversation as resolved.
Show resolved Hide resolved


def print_columns_added(splits_added: dict):
"""
Builds a Rich Table with the infor about the new columns created.

Parameters
----------
splits_added : dict containing the column_names and length of each split

Returns
-------
Table
The final generated table ready to be displayed

"""
table = Table(box=box.MINIMAL)

table.add_column("Split [len]", no_wrap=True, justify="right", style="bold green")
table.add_column("Column Names")

for split in splits_added:
column_names, num_samples = splits_added[split]
table.add_row(split + f" [{num_samples}]", ", ".join(column_names))

console.print(table)


def get_stock_root(path: Path) -> Path:
"""
Expand Down