Skip to content

Commit

Permalink
added dummy huggingface token for pytests
Browse files Browse the repository at this point in the history
  • Loading branch information
haeussma committed Feb 7, 2025
1 parent e44b7f4 commit d8603d5
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 8 deletions.
28 changes: 20 additions & 8 deletions src/pyeed/embedding.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,33 @@
import gc
import os
from typing import Tuple, Union

import numpy as np
import torch
from esm.models.esmc import ESMC
from esm.sdk.api import ESMProtein, LogitsConfig
from huggingface_hub import HfApi, HfFolder, login
from huggingface_hub import HfFolder, login
from numpy.typing import NDArray
from transformers import EsmModel, EsmTokenizer

from pyeed.dbconnect import DatabaseConnector

# Prompt for HuggingFace login credentials
# Ensure that you have your API key saved or use login() to authenticate interactively
hf_folder = HfFolder()
api = HfApi()
token = (
hf_folder.get_token() or login()
) # This will prompt for login if no token is saved

def get_hf_token() -> str:
"""Get or request Hugging Face token."""
if os.getenv("PYTEST_DISABLE_HF_LOGIN"): # Disable Hugging Face login in tests
return "dummy_token_for_tests"

hf_folder = HfFolder()
token = hf_folder.get_token()
if not token:
login() # Login returns None, get token after login
token = hf_folder.get_token()

if isinstance(token, str):
return token
else:
raise RuntimeError("Failed to get Hugging Face token")


def load_model_and_tokenizer(
Expand All @@ -37,6 +47,8 @@ def load_model_and_tokenizer(
Returns:
Tuple of (model, tokenizer, device)
"""
# Get token only when loading model
token = get_hf_token()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Check if this is an ESM-3 variant
Expand Down
13 changes: 13 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
import os
from typing import Generator

import pytest


@pytest.fixture(autouse=True)
def skip_hf_login() -> Generator[None, None, None]:
"""Skip Hugging Face login during tests."""
os.environ["PYTEST_DISABLE_HF_LOGIN"] = "1"
yield
if "PYTEST_DISABLE_HF_LOGIN" in os.environ:
del os.environ["PYTEST_DISABLE_HF_LOGIN"]

0 comments on commit d8603d5

Please sign in to comment.