Skip to content

Commit

Permalink
fix azure embeddings pydantic validator (#15603)
Browse files Browse the repository at this point in the history
  • Loading branch information
logan-markewich authored Aug 23, 2024
1 parent cfb72ab commit 1a26031
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 11 deletions.
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from typing import Any, Dict, Optional

import httpx
from llama_index.core.bridge.pydantic import Field, PrivateAttr, root_validator
from llama_index.core.bridge.pydantic import Field, PrivateAttr, model_validator
from llama_index.core.callbacks.base import CallbackManager
from llama_index.core.constants import DEFAULT_EMBED_BATCH_SIZE
from llama_index.core.base.llms.generic_utils import get_from_param_or_env
Expand All @@ -20,15 +20,21 @@

class AzureOpenAIEmbedding(OpenAIEmbedding):
azure_endpoint: Optional[str] = Field(
default=None, description="The Azure endpoint to use."
default=None, description="The Azure endpoint to use.", validate_default=True
)
azure_deployment: Optional[str] = Field(
default=None, description="The Azure deployment to use."
default=None, description="The Azure deployment to use.", validate_default=True
)

api_base: str = Field(default="", description="The base URL for Azure deployment.")
api_base: str = Field(
default="",
description="The base URL for Azure deployment.",
validate_default=True,
)
api_version: str = Field(
default="", description="The version for Azure OpenAI API."
default="",
description="The version for Azure OpenAI API.",
validate_default=True,
)

azure_ad_token_provider: AzureADTokenProvider = Field(
Expand Down Expand Up @@ -94,18 +100,19 @@ def __init__(
**kwargs,
)

@root_validator(pre=True)
@model_validator(mode="before")
@classmethod
def validate_env(cls, values: Dict[str, Any]) -> Dict[str, Any]:
"""Validate necessary credentials are set."""
if (
values["api_base"] == "https://api.openai.com/v1"
and values["azure_endpoint"] is None
values.get("api_base") == "https://api.openai.com/v1"
and values.get("azure_endpoint") is None
):
raise ValueError(
"You must set OPENAI_API_BASE to your Azure endpoint. "
"It should look like https://YOUR_RESOURCE_NAME.openai.azure.com/"
)
if values["api_version"] is None:
if values.get("api_version") is None:
raise ValueError("You must set OPENAI_API_VERSION for Azure OpenAI.")

return values
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,12 +27,12 @@ exclude = ["**/BUILD"]
license = "MIT"
name = "llama-index-embeddings-azure-openai"
readme = "README.md"
version = "0.2.1"
version = "0.2.2"

[tool.poetry.dependencies]
python = ">=3.8.1,<3.12"
llama-index-llms-azure-openai = "^0.2.0"
llama-index-embeddings-openai = "^0.2.0"
llama-index-embeddings-openai = "^0.2.3"
llama-index-core = "^0.11.0"

[tool.poetry.group.dev.dependencies]
Expand Down

0 comments on commit 1a26031

Please sign in to comment.