Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/hfhub utils #2201

Open
wants to merge 7 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,10 @@ but cannot always guarantee backwards compatibility. Changes that may **break co

## [Unreleased](https://github.com/unit8co/darts/tree/master)

- Added utility functions for Hugging Face Hub integration. Upload/download Darts TimeSeries and ForecastingModel instances. [#2201](https://github.com/unit8co/darts/pull/2201)
by [Ivelin Ivanov](https://github.com/ivelin).


[Full Changelog](https://github.com/unit8co/darts/compare/0.27.2...master)

### For users of the library:
Expand Down
104 changes: 104 additions & 0 deletions darts/utils/hfhub.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
import pandas as pd
from dotenv import load_dotenv
import os
import tempfile
from typing import Optional
from darts import TimeSeries
from darts.models.forecasting.forecasting_model import ForecastingModel
from huggingface_hub import snapshot_download, upload_folder, create_repo
from darts.logging import get_logger

logger = get_logger(__name__)

class HFHub:
    """
    Hugging Face Hub integration using the official ``huggingface_hub`` API.

    Provides upload/download of Darts ``ForecastingModel`` instances (model
    repos) and ``TimeSeries`` instances (dataset repos, stored as parquet).
    https://huggingface.co/docs/huggingface_hub/v0.20.3/en/guides/integrations
    """

    def __init__(self, api_key: Optional[str] = None):
        """
        Parameters
        ----------
        api_key
            Hugging Face access token. If ``None``, it is looked up in a local
            ``.env`` file (via ``python-dotenv``) or the ``HF_TOKEN`` OS
            environment variable.

        Raises
        ------
        ValueError
            If no token was given and ``HF_TOKEN`` cannot be found.
        """
        if api_key is None:
            # load from .env file or OS vars if available
            load_dotenv(override=True)
            api_key = os.getenv("HF_TOKEN")
        # Explicit exception instead of `assert`: asserts are stripped when
        # Python runs with -O, which would silently allow a None token.
        if api_key is None:
            raise ValueError(
                "Could not find HF_TOKEN in OS environment. "
                "Cannot interact with HF Hub."
            )
        self.HF_TOKEN = api_key

    def upload_model(
        self,
        repo_id: str = None,
        model: ForecastingModel = None,
        private: Optional[bool] = True,
    ):
        """
        Save `model` locally and upload it to a Hub model repo.

        Parameters
        ----------
        repo_id
            Target Hub repo id, e.g. ``"user/repo"``. Created if missing.
        model
            Fitted or unfitted Darts forecasting model; its ``model_name``
            attribute is used as the file name in the repo.
        private
            Whether a newly created repo should be private.
        """
        # Create repo if not existing yet (no-op when it already exists).
        create_repo(
            repo_id=repo_id,
            repo_type="model",
            private=private,
            exist_ok=True,
            token=self.HF_TOKEN,
        )
        with tempfile.TemporaryDirectory() as tmpdirname:
            model.save(path=f"{tmpdirname}/{model.model_name}")
            upload_folder(
                repo_id=repo_id, folder_path=tmpdirname, token=self.HF_TOKEN
            )

    def download_model(
        self,
        repo_id: str = None,
        model_name: str = None,
        model_class: object = None,
    ) -> ForecastingModel:
        """
        Download a model snapshot from the Hub and load it.

        Parameters
        ----------
        repo_id
            Hub model repo id to download from.
        model_name
            File name of the saved model within the repo.
        model_class
            The Darts model class whose ``load`` classmethod deserializes
            the file (must match the class used at upload time).

        Returns
        -------
        ForecastingModel
            The loaded model instance.
        """
        with tempfile.TemporaryDirectory() as tmpdirname:
            snapshot_download(
                repo_id=repo_id, local_dir=tmpdirname, token=self.HF_TOKEN
            )
            # Log after the download completes — logging before it would
            # always show an empty directory. Use lazy %-formatting so the
            # extra argument is actually rendered into the message.
            logger.info(
                "HFHub model files downloaded to local temp dir: %s",
                os.listdir(tmpdirname),
            )
            model = model_class.load(path=f"{tmpdirname}/{model_name}")
            return model

    def upload_timeseries(
        self,
        repo_id: str = None,
        series: TimeSeries = None,
        series_name: str = None,
        private: Optional[bool] = True,
    ):
        """
        Upload a ``TimeSeries`` to a Hub dataset repo as a parquet file.

        Parameters
        ----------
        repo_id
            Target Hub dataset repo id. Created if missing.
        series
            The time series to upload.
        series_name
            Base name for the parquet file stored in the repo.
        private
            Whether a newly created repo should be private.
        """
        # Create repo if not existing yet (no-op when it already exists).
        create_repo(
            repo_id=repo_id,
            repo_type="dataset",
            private=private,
            exist_ok=True,
            token=self.HF_TOKEN,
        )
        df = series.pd_dataframe()
        with tempfile.TemporaryDirectory() as tmpdirname:
            df.to_parquet(path=f"{tmpdirname}/{series_name}.parquet")
            upload_folder(
                repo_id=repo_id,
                repo_type="dataset",
                folder_path=tmpdirname,
                token=self.HF_TOKEN,
            )

    def download_timeseries(
        self,
        repo_id: str = None,
        series_name: str = None,
    ) -> TimeSeries:
        """
        Download a parquet file from a Hub dataset repo as a ``TimeSeries``.

        Parameters
        ----------
        repo_id
            Hub dataset repo id to download from.
        series_name
            Base name of the parquet file within the repo.

        Returns
        -------
        TimeSeries
            The reconstructed time series.
        """
        with tempfile.TemporaryDirectory() as tmpdirname:
            snapshot_download(
                repo_id=repo_id,
                repo_type="dataset",
                local_dir=tmpdirname,
                token=self.HF_TOKEN,
            )
            # Lazy %-formatting: passing the list as a bare extra positional
            # argument without a placeholder raises a formatting error.
            logger.info(
                "HFHub data files downloaded to local temp dir: %s",
                os.listdir(tmpdirname),
            )
            df = pd.read_parquet(
                f"{tmpdirname}/{series_name}.parquet", engine="pyarrow"
            )
            ts = TimeSeries.from_dataframe(df)
            return ts