upgrade to 0.2.0, remove deprecated mode switching (#15544)
mattf authored Aug 23, 2024
1 parent 68b90a3 commit 0bfc824
Showing 6 changed files with 17 additions and 139 deletions.

@@ -2,7 +2,6 @@

 from typing import Any, List, Literal, Optional
 import warnings
-from deprecated import deprecated

 from llama_index.core.base.embeddings.base import (
     DEFAULT_EMBED_BATCH_SIZE,

@@ -129,9 +128,7 @@ def __init__(
         self._is_hosted = base_url in KNOWN_URLS

         if self._is_hosted and api_key == "NO_API_KEY_PROVIDED":
-            warnings.warn(
-                "An API key is required for hosted NIM. This will become an error in 0.2.0."
-            )
+            raise ValueError("An API key is required for hosted NIM.")

         self._client = OpenAI(
             api_key=api_key,
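
With this hunk, a missing key for the hosted endpoint fails fast at construction time instead of warning and then failing later with a 401. A minimal sketch of the new behavior, assuming the usual llama_index.embeddings.nvidia import path and that NVIDIA_API_KEY is not set in the environment:

from llama_index.embeddings.nvidia import NVIDIAEmbedding  # import path assumed

# 0.2.0: constructing against the hosted NIM endpoint without credentials raises.
try:
    NVIDIAEmbedding()  # no api_key argument, no NVIDIA_API_KEY in the environment
except ValueError as err:
    print(err)  # An API key is required for hosted NIM.

# Passing a key (or exporting NVIDIA_API_KEY) constructs the client as before.
emb = NVIDIAEmbedding(api_key="nvapi-...")  # placeholder key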

@@ -185,45 +182,6 @@ def available_models(self) -> List[Model]:
     def class_name(cls) -> str:
         return "NVIDIAEmbedding"

-    @deprecated(
-        version="0.1.2",
-        reason="Will be removed in 0.2. Construct with `base_url` instead.",
-    )
-    def mode(
-        self,
-        mode: Optional[Literal["nvidia", "nim"]] = "nvidia",
-        *,
-        base_url: Optional[str] = None,
-        model: Optional[str] = None,
-        api_key: Optional[str] = None,
-    ) -> "NVIDIAEmbedding":
-        """
-        Deprecated: use NVIDIAEmbedding(base_url="...") instead.
-        """
-        if mode == "nim":
-            if not base_url:
-                raise ValueError("base_url is required for nim mode")
-        if mode == "nvidia":
-            api_key = get_from_param_or_env("api_key", api_key, "NVIDIA_API_KEY")
-            if not base_url:
-                # TODO: we should not assume unknown models are at the base url
-                base_url = MODEL_ENDPOINT_MAP.get(model or self.model, BASE_URL)
-
-        self._mode = mode
-        self._is_hosted = base_url in KNOWN_URLS
-        if base_url:
-            self._client.base_url = base_url
-            self._aclient.base_url = base_url
-        if model:
-            self.model = model
-            self._client.model = model
-            self._aclient.model = model
-        if api_key:
-            self._client.api_key = api_key
-            self._aclient.api_key = api_key
-
-        return self
-
     def _get_query_embedding(self, query: str) -> List[float]:
         """Get query embedding."""
         return (
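
With mode() removed, endpoint selection now happens only through the constructor, as the deleted docstring already advised ("use NVIDIAEmbedding(base_url=...) instead"). A minimal before/after sketch; the import path, local URL, and model name are illustrative placeholders, not values taken from this diff:

from llama_index.embeddings.nvidia import NVIDIAEmbedding  # import path assumed

# 0.1.x (removed in this commit): switch an existing client to a self-hosted NIM
# emb = NVIDIAEmbedding().mode("nim", base_url="http://localhost:8080/v1")

# 0.2.0: pass the endpoint (and optionally the model) at construction time
emb = NVIDIAEmbedding(
    base_url="http://localhost:8080/v1",  # placeholder NIM endpoint
    model="my-embedding-model",  # placeholder model name
)

# Hosted usage keeps the default base_url and requires an API key, as above.
hosted = NVIDIAEmbedding(api_key="nvapi-...")  # placeholder key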

@@ -12,8 +12,9 @@ def get_api_key(instance: Any) -> str:


 def test_create_default_url_without_api_key(masked_env_var: str) -> None:
-    with pytest.warns(UserWarning):
+    with pytest.raises(ValueError) as e:
         Interface()
+    assert "API key is required" in str(e.value)


 def test_create_unknown_url_without_api_key(masked_env_var: str) -> None:

@@ -38,16 +39,6 @@ def test_api_key_priority(masked_env_var: str) -> None:
     del os.environ["NVIDIA_API_KEY"]


-@pytest.mark.integration()
-def test_missing_api_key_error(masked_env_var: str) -> None:
-    with pytest.warns(UserWarning):
-        client = Interface()
-    with pytest.raises(Exception) as exc_info:
-        client.get_query_embedding("Hello, world!")
-    message = str(exc_info.value)
-    assert "401" in message
-
-
 @pytest.mark.integration()
 def test_bogus_api_key_error(masked_env_var: str) -> None:
     client = Interface(nvidia_api_key="BOGUS")

@@ -65,4 +65,4 @@ def test_base_url_valid_not_hosted(base_url: str, mock_v1_local_models2: None) -
     ],
 )
 def test_base_url_valid_hosted(base_url: str, mock_v1_local_models2: None) -> None:
-    Interface(base_url=base_url)
+    Interface(api_key="BOGUS", base_url=base_url)

@@ -9,12 +9,13 @@


 def test_embedding_class():
-    emb = NVIDIAEmbedding()
+    emb = NVIDIAEmbedding(api_key="BOGUS")
     assert isinstance(emb, BaseEmbedding)


 def test_nvidia_embedding_param_setting():
     emb = NVIDIAEmbedding(
+        api_key="BOGUS",
         model="test-model",
         truncate="END",
         timeout=20,

@@ -37,7 +38,7 @@ def test_nvidia_embedding_throws_on_batches_larger_than_259():


 def test_nvidia_embedding_async():
-    emb = NVIDIAEmbedding()
+    emb = NVIDIAEmbedding(api_key="BOGUS")

     assert inspect.iscoroutinefunction(emb._aget_query_embedding)
     query_emb = emb._aget_query_embedding("hi")

This file was deleted.


@@ -20,7 +20,7 @@ def mocked_route() -> respx.Route:
 @pytest.mark.parametrize("truncate", ["END", "START", "NONE"])
 def test_single_tuncate(mocked_route: respx.Route, method_name: str, truncate: str):
     # call the method_name method
-    getattr(NVIDIAEmbedding(truncate=truncate), method_name)("nvidia")
+    getattr(NVIDIAEmbedding(api_key="BOGUS", truncate=truncate), method_name)("nvidia")

     assert mocked_route.called
     request = mocked_route.calls.last.request

Expand All @@ -36,7 +36,9 @@ async def test_asingle_tuncate(
     mocked_route: respx.Route, method_name: str, truncate: str
 ):
     # call the method_name method
-    await getattr(NVIDIAEmbedding(truncate=truncate), method_name)("nvidia")
+    await getattr(NVIDIAEmbedding(api_key="BOGUS", truncate=truncate), method_name)(
+        "nvidia"
+    )

     assert mocked_route.called
     request = mocked_route.calls.last.request

Expand All @@ -49,7 +51,9 @@ async def test_asingle_tuncate(
 @pytest.mark.parametrize("truncate", ["END", "START", "NONE"])
 def test_batch_tuncate(mocked_route: respx.Route, method_name: str, truncate: str):
     # call the method_name method
-    getattr(NVIDIAEmbedding(truncate=truncate), method_name)(["nvidia"])
+    getattr(NVIDIAEmbedding(api_key="BOGUS", truncate=truncate), method_name)(
+        ["nvidia"]
+    )

     assert mocked_route.called
     request = mocked_route.calls.last.request

Expand All @@ -65,7 +69,9 @@ async def test_abatch_tuncate(
     mocked_route: respx.Route, method_name: str, truncate: str
 ):
     # call the method_name method
-    await getattr(NVIDIAEmbedding(truncate=truncate), method_name)(["nvidia"])
+    await getattr(NVIDIAEmbedding(api_key="BOGUS", truncate=truncate), method_name)(
+        ["nvidia"]
+    )

     assert mocked_route.called
     request = mocked_route.calls.last.request
