diff --git a/backend/adapter_processor_v2/adapter_processor.py b/backend/adapter_processor_v2/adapter_processor.py index 062d26700..33fccec88 100644 --- a/backend/adapter_processor_v2/adapter_processor.py +++ b/backend/adapter_processor_v2/adapter_processor.py @@ -3,7 +3,7 @@ from typing import Any, Optional from account_v2.models import User -from adapter_processor_v2.constants import AdapterKeys +from adapter_processor_v2.constants import AdapterKeys, AllowedDomains from adapter_processor_v2.exceptions import ( InternalServiceError, InValidAdapterId, @@ -44,17 +44,25 @@ def get_json_schema(adapter_id: str) -> dict[str, Any]: return schema_details @staticmethod - def get_all_supported_adapters(type: str) -> list[dict[Any, Any]]: + def get_all_supported_adapters(user_email: str, type: str) -> list[dict[Any, Any]]: """Function to return list of all supported adapters.""" supported_adapters = [] updated_adapters = [] updated_adapters = AdapterProcessor.__fetch_adapters_by_key_value( AdapterKeys.ADAPTER_TYPE, type ) + is_special_user = any( + identifier in user_email for identifier in AllowedDomains.list() + ) + for each_adapter in updated_adapters: + adapter_id = each_adapter.get(AdapterKeys.ID) + if not is_special_user and adapter_id.startswith("noOp"): + continue + supported_adapters.append( { - AdapterKeys.ID: each_adapter.get(AdapterKeys.ID), + AdapterKeys.ID: adapter_id, AdapterKeys.NAME: each_adapter.get(AdapterKeys.NAME), AdapterKeys.DESCRIPTION: each_adapter.get(AdapterKeys.DESCRIPTION), AdapterKeys.ICON: each_adapter.get(AdapterKeys.ICON), diff --git a/backend/adapter_processor_v2/constants.py b/backend/adapter_processor_v2/constants.py index 6557491b9..3b849a72b 100644 --- a/backend/adapter_processor_v2/constants.py +++ b/backend/adapter_processor_v2/constants.py @@ -1,3 +1,6 @@ +from enum import Enum + + class AdapterKeys: JSON_SCHEMA = "json_schema" ADAPTER_TYPE = "adapter_type" @@ -27,3 +30,12 @@ class AdapterKeys: ADAPTER_NAME = "adapter_name" ADAPTER_CREATED_BY = "created_by_email" ADAPTER_CONTEXT_WINDOW_SIZE = "context_window_size" + + +class AllowedDomains(Enum): + ZIPSTACK = "@zipstack.com" + UNSTRACT = "@unstract.com" + + @staticmethod + def list(): + return list(map(lambda c: c.value, AllowedDomains)) diff --git a/backend/adapter_processor_v2/views.py b/backend/adapter_processor_v2/views.py index 92eb5f2e7..299716d87 100644 --- a/backend/adapter_processor_v2/views.py +++ b/backend/adapter_processor_v2/views.py @@ -95,7 +95,7 @@ def list( or adapter_type == AdapterKeys.OCR ): json_schema = AdapterProcessor.get_all_supported_adapters( - type=adapter_type + type=adapter_type, user_email=request.user.email ) return Response(json_schema, status=status.HTTP_200_OK) else: