diff --git a/api/environments/identities/models.py b/api/environments/identities/models.py index 9fe99df13df9..60f36f8a4464 100644 --- a/api/environments/identities/models.py +++ b/api/environments/identities/models.py @@ -4,12 +4,12 @@ from django.db import models from django.db.models import Prefetch, Q from django.utils import timezone -from flag_engine.identities.traits.types import TraitValue from flag_engine.segments.evaluator import evaluate_identity_in_segment from environments.identities.managers import IdentityManager from environments.identities.traits.models import Trait from environments.models import Environment +from environments.sdk.types import SDKTraitData from features.models import FeatureState from features.multivariate.models import MultivariateFeatureStateValue from segments.models import Segment @@ -196,7 +196,11 @@ def get_all_user_traits(self): def __str__(self): return "Account %s" % self.identifier - def generate_traits(self, trait_data_items, persist=False): + def generate_traits( + self, + trait_data_items: list[SDKTraitData], + persist: bool = False, + ) -> list[Trait]: """ Given a list of trait data items, validated by TraitSerializerFull, generate a list of TraitModel objects for the given identity. @@ -232,7 +236,7 @@ def generate_traits(self, trait_data_items, persist=False): def update_traits( self, - trait_data_items: list[dict[str, TraitValue]], + trait_data_items: list[SDKTraitData], ) -> list[Trait]: """ Given a list of traits, update any that already exist and create any new ones. diff --git a/api/environments/identities/serializers.py b/api/environments/identities/serializers.py index bd4b1ffaed1f..6c0d9df9229e 100644 --- a/api/environments/identities/serializers.py +++ b/api/environments/identities/serializers.py @@ -59,6 +59,7 @@ class _TraitSerializer(serializers.Serializer): help_text="Can be of type string, boolean, float or integer." 
) + identifier = serializers.CharField() flags = serializers.ListField(child=SDKFeatureStateSerializer()) traits = serializers.ListSerializer(child=_TraitSerializer()) diff --git a/api/environments/identities/traits/fields.py b/api/environments/identities/traits/fields.py index fd5d3d335f9c..ceadc03b69c2 100644 --- a/api/environments/identities/traits/fields.py +++ b/api/environments/identities/traits/fields.py @@ -1,4 +1,5 @@ import logging +from typing import Any from rest_framework import serializers @@ -6,6 +7,7 @@ ACCEPTED_TRAIT_VALUE_TYPES, TRAIT_STRING_VALUE_MAX_LENGTH, ) +from environments.sdk.types import SDKTraitValueData from features.value_types import STRING logger = logging.getLogger(__name__) @@ -16,7 +18,7 @@ class TraitValueField(serializers.Field): Custom field to extract the type of the field on deserialization. """ - def to_internal_value(self, data): + def to_internal_value(self, data: Any) -> SDKTraitValueData: data_type = type(data).__name__ if data_type not in ACCEPTED_TRAIT_VALUE_TYPES: @@ -28,7 +30,7 @@ def to_internal_value(self, data): ) return {"type": data_type, "value": data} - def to_representation(self, value): + def to_representation(self, value: Any) -> Any: return_value = value.get("value") if isinstance(value, dict) else value if return_value is None: diff --git a/api/environments/identities/views.py b/api/environments/identities/views.py index f403b8001dc1..bbf9e2ce566d 100644 --- a/api/environments/identities/views.py +++ b/api/environments/identities/views.py @@ -309,6 +309,7 @@ def _get_all_feature_states_for_user_response( serializer = serializer_class( { "flags": all_feature_states, + "identifier": identity.identifier, "traits": identity.identity_traits.all(), }, context=self.get_serializer_context(), diff --git a/api/environments/sdk/serializers.py b/api/environments/sdk/serializers.py index 8ee5d254b94b..7029daa143b9 100644 --- a/api/environments/sdk/serializers.py +++ b/api/environments/sdk/serializers.py @@ -2,7 +2,6 
@@ from collections import defaultdict from core.constants import BOOLEAN, FLOAT, INTEGER, STRING -from django.utils import timezone from rest_framework import serializers from environments.identities.models import Identity @@ -12,6 +11,12 @@ from environments.identities.traits.fields import TraitValueField from environments.identities.traits.models import Trait from environments.identities.traits.serializers import TraitSerializerBasic +from environments.sdk.services import ( + get_identified_transient_identity_and_traits, + get_persisted_identity_and_traits, + get_transient_identity_and_traits, +) +from environments.sdk.types import SDKTraitData from features.serializers import ( FeatureStateSerializerFull, SDKFeatureStateSerializer, @@ -125,7 +130,11 @@ def create(self, validated_data): class IdentifyWithTraitsSerializer( HideSensitiveFieldsSerializerMixin, serializers.Serializer ): - identifier = serializers.CharField(write_only=True, required=True) + identifier = serializers.CharField( + required=False, + allow_blank=True, + allow_null=True, + ) transient = serializers.BooleanField(write_only=True, default=False) traits = TraitSerializerBasic(required=False, many=True) flags = SDKFeatureStateSerializer(read_only=True, many=True) @@ -137,44 +146,46 @@ def save(self, **kwargs): Create the identity with the associated traits (optionally store traits if flag set on org) """ + identifier = self.validated_data.get("identifier") environment = self.context["environment"] - transient = self.validated_data["transient"] - trait_data_items = self.validated_data.get("traits", []) + sdk_trait_data: list[SDKTraitData] = self.validated_data.get("traits", []) - if transient: - identity = Identity( - created_date=timezone.now(), - identifier=self.validated_data["identifier"], + if not identifier: + # We have a fully transient identity that should never be persisted. 
+ identity, traits = get_transient_identity_and_traits( environment=environment, + sdk_trait_data=sdk_trait_data, ) - trait_models = identity.generate_traits(trait_data_items, persist=False) - else: - identity, created = Identity.objects.get_or_create( - identifier=self.validated_data["identifier"], environment=environment + elif transient: + # Don't persist incoming data but load presently stored + # overrides and traits, if any. + identity, traits = get_identified_transient_identity_and_traits( + environment=environment, + identifier=identifier, + sdk_trait_data=sdk_trait_data, ) - if not created and environment.project.organisation.persist_trait_data: - # if this is an update and we're persisting traits, then we need to - # partially update any traits and return the full list - trait_models = identity.update_traits(trait_data_items) - else: - # generate traits for the identity and store them if configured to do so - trait_models = identity.generate_traits( - trait_data_items, - persist=environment.project.organisation.persist_trait_data, - ) + else: + # Persist the identity in accordance with individual trait transiency + # and persistence settings outside of request context. 
+ identity, traits = get_persisted_identity_and_traits( + environment=environment, + identifier=identifier, + sdk_trait_data=sdk_trait_data, + ) all_feature_states = identity.get_all_feature_states( - traits=trait_models, + traits=traits, additional_filters=self.context.get("feature_states_additional_filters"), ) - identify_integrations(identity, all_feature_states, trait_models) + identify_integrations(identity, all_feature_states, traits) return { "identity": identity, - "traits": trait_models, + "identifier": identity.identifier, + "traits": traits, "flags": all_feature_states, } diff --git a/api/environments/sdk/services.py b/api/environments/sdk/services.py new file mode 100644 index 000000000000..500f35ac8b92 --- /dev/null +++ b/api/environments/sdk/services.py @@ -0,0 +1,120 @@ +import hashlib +import uuid +from itertools import chain +from operator import itemgetter +from typing import TypeAlias + +from django.utils import timezone + +from environments.identities.models import Identity +from environments.identities.traits.models import Trait +from environments.models import Environment +from environments.sdk.types import SDKTraitData + +IdentityAndTraits: TypeAlias = tuple[Identity, list[Trait]] + + +def get_transient_identity_and_traits( + environment: Environment, + sdk_trait_data: list[SDKTraitData], +) -> IdentityAndTraits: + """ + Get a transient `Identity` instance with a randomly generated identifier. + All traits are marked as transient. + """ + return ( + ( + identity := _get_transient_identity( + environment=environment, + identifier=get_transient_identifier(sdk_trait_data), + ) + ), + identity.generate_traits(_ensure_transient(sdk_trait_data), persist=False), + ) + + +def get_identified_transient_identity_and_traits( + environment: Environment, + identifier: str, + sdk_trait_data: list[SDKTraitData], +) -> IdentityAndTraits: + """ + Get a transient `Identity` instance. 
+ If present in storage, it's a previously persisted identity with its traits, + combined with incoming traits provided to `sdk_trait_data` argument. + All traits constructed from `sdk_trait_data` are marked as transient. + """ + sdk_trait_data = _ensure_transient(sdk_trait_data) + if identity := Identity.objects.filter( + environment=environment, + identifier=identifier, + ).first(): + return identity, identity.update_traits(sdk_trait_data) + return ( + identity := _get_transient_identity( + environment=environment, + identifier=identifier, + ) + ), identity.generate_traits(sdk_trait_data, persist=False) + + +def get_persisted_identity_and_traits( + environment: Environment, + identifier: str, + sdk_trait_data: list[SDKTraitData], +) -> IdentityAndTraits: + """ + Retrieve a previously persisted `Identity` instance or persist a new one. + Traits are persisted based on the organisation-level setting or a + `"transient"` attribute provided with each individual trait. + """ + identity, created = Identity.objects.get_or_create( + environment=environment, + identifier=identifier, + ) + persist_trait_data = environment.project.organisation.persist_trait_data + if created: + return identity, identity.generate_traits( + sdk_trait_data, + persist=persist_trait_data, + ) + if persist_trait_data: + return identity, identity.update_traits(sdk_trait_data) + return identity, list( + { + trait.trait_key: trait + for trait in chain( + identity.identity_traits.all(), + identity.generate_traits(sdk_trait_data, persist=False), + ) + }.values() + ) + + +def get_transient_identifier(sdk_trait_data: list[SDKTraitData]) -> str: + if sdk_trait_data: + return hashlib.sha256( + "".join( + f'{trait["trait_key"]}{trait["trait_value"]["value"]}' + for trait in sorted(sdk_trait_data, key=itemgetter("trait_key")) + ).encode(), + usedforsecurity=False, + ).hexdigest() + return uuid.uuid4().hex + + +def _get_transient_identity( + environment: Environment, + identifier: str, +) -> Identity: + 
return Identity( + created_date=timezone.now(), + environment=environment, + identifier=identifier, + ) + + +def _ensure_transient(sdk_trait_data: list[SDKTraitData]) -> list[SDKTraitData]: + for sdk_trait_data_item in sdk_trait_data: + sdk_trait_data_item["transient"] = True + return sdk_trait_data diff --git a/api/environments/sdk/types.py b/api/environments/sdk/types.py new file mode 100644 index 000000000000..cf369f5a1907 --- /dev/null +++ b/api/environments/sdk/types.py @@ -0,0 +1,14 @@ +import typing + +from typing_extensions import NotRequired + + +class SDKTraitValueData(typing.TypedDict): + type: str + value: str + + +class SDKTraitData(typing.TypedDict): + trait_key: str + trait_value: SDKTraitValueData + transient: NotRequired[bool] diff --git a/api/tests/integration/environments/identities/test_integration_identities.py b/api/tests/integration/environments/identities/test_integration_identities.py index 445eedacfbb4..0537fb66f6ef 100644 --- a/api/tests/integration/environments/identities/test_integration_identities.py +++ b/api/tests/integration/environments/identities/test_integration_identities.py @@ -1,8 +1,11 @@ +import hashlib import json +from typing import Any, Generator from unittest import mock import pytest from django.urls import reverse +from pytest_lazyfixture import lazy_fixture from rest_framework import status from rest_framework.test import APIClient @@ -224,13 +227,64 @@ def test_get_feature_states_for_identity_only_makes_one_query_to_get_mv_feature_ assert len(second_identity_response_json["flags"]) == 3 -def test_get_feature_states_for_identity__transient_identity__segment_match_expected( +@pytest.fixture +def existing_identity_identifier_data( + identity_identifier: str, + identity: int, +) -> dict[str, Any]: + return {"identifier": identity_identifier} + + +@pytest.fixture +def transient_identifier( + segment_condition_property: str, + segment_condition_value: str, +) -> Generator[str, None, None]: + return hashlib.sha256( + 
f"avalue_a{segment_condition_property}{segment_condition_value}".encode() + ).hexdigest() + + +@pytest.mark.parametrize( + "transient_data", + [ + pytest.param({"transient": True}, id="with-transient-true"), + pytest.param({"transient": False}, id="with-transient-false"), + pytest.param({}, id="missing-transient"), + ], +) +@pytest.mark.parametrize( + "identifier_data,expected_identifier", + [ + pytest.param( + lazy_fixture("existing_identity_identifier_data"), + lazy_fixture("identity_identifier"), + id="existing-identifier", + ), + pytest.param({"identifier": "unseen"}, "unseen", id="new-identifier"), + pytest.param( + {"identifier": ""}, + lazy_fixture("transient_identifier"), + id="blank-identifier", + ), + pytest.param( + {"identifier": None}, + lazy_fixture("transient_identifier"), + id="null-identifier", + ), + pytest.param({}, lazy_fixture("transient_identifier"), id="missing-identifier"), + ], +) +def test_get_feature_states_for_identity__segment_match_expected( sdk_client: APIClient, feature: int, segment: int, segment_condition_property: str, segment_condition_value: str, segment_featurestate: int, + identifier_data: dict[str, Any], + transient_data: dict[str, Any], + expected_identifier: str, ) -> None: # Given url = reverse("api-v1:sdk-identities") @@ -242,14 +296,15 @@ def test_get_feature_states_for_identity__transient_identity__segment_match_expe url, data=json.dumps( { - "identifier": "unseen", + **identifier_data, + **transient_data, "traits": [ { "trait_key": segment_condition_property, "trait_value": segment_condition_value, - } + }, + {"trait_key": "a", "trait_value": "value_a"}, ], - "transient": True, } ), content_type="application/json", @@ -258,6 +313,7 @@ def test_get_feature_states_for_identity__transient_identity__segment_match_expe # Then assert response.status_code == status.HTTP_200_OK response_json = response.json() + assert response_json["identifier"] == expected_identifier assert ( flag_data := next( ( @@ -272,6 +328,29 @@ def 
def test_get_feature_states_for_identity__empty_traits__random_identifier_expected(
    sdk_client: APIClient,
    environment: int,
) -> None:
    # Two identical requests without an identifier and without traits
    # must each be answered with a different identifier.

    # Given
    url = reverse("api-v1:sdk-identities")
    payload = json.dumps({"traits": []})

    # When
    response_1 = sdk_client.post(
        url,
        data=payload,
        content_type="application/json",
    )
    response_2 = sdk_client.post(
        url,
        data=payload,
        content_type="application/json",
    )

    # Then
    # check the status first so that a failing request is reported as such
    # rather than as a missing "identifier" key
    assert response_1.status_code == status.HTTP_200_OK
    assert response_2.status_code == status.HTTP_200_OK

    identifier_1 = response_1.json()["identifier"]
    identifier_2 = response_2.json()["identifier"]
    assert identifier_1
    assert identifier_2
    assert identifier_1 != identifier_2
@pytest.mark.parametrize(
    "trait_transiency_data",
    [
        pytest.param({"transient": True}, id="trait-transient-true"),
        pytest.param({"transient": False}, id="trait-transient-false"),
        pytest.param({}, id="trait-default"),
    ],
)
def test_post_identities__existing__transient__no_persistence(
    environment: Environment,
    identity: Identity,
    trait: Trait,
    identity_featurestate: FeatureState,
    api_client: APIClient,
    trait_transiency_data: dict[str, Any],
) -> None:
    # A transient request for an already persisted identity must load its
    # stored overrides and traits, and must not persist the incoming trait,
    # whatever per-trait "transient" value the request carries.

    # Given
    feature_state_value = "identity override"
    identity_featurestate.feature_state_value.string_value = feature_state_value
    identity_featurestate.feature_state_value.save()

    trait_key = "trait_key"

    api_client.credentials(HTTP_X_ENVIRONMENT_KEY=environment.api_key)
    url = reverse("api-v1:sdk-identities")
    data = {
        "identifier": identity.identifier,
        "transient": True,
        "traits": [
            {"trait_key": trait_key, "trait_value": "bar", **trait_transiency_data}
        ],
    }

    # When
    response = api_client.post(
        url, data=json.dumps(data), content_type="application/json"
    )

    # Then
    assert response.status_code == status.HTTP_200_OK
    response_json = response.json()

    # identity overrides are correctly loaded
    assert response_json["flags"][0]["feature_state_value"] == feature_state_value

    # look traits up by key so the assertions do not depend on the order
    # in which the response happens to list them
    response_traits = {
        trait_data["trait_key"]: trait_data for trait_data in response_json["traits"]
    }

    # previously persisted traits not provided in the request
    # are not marked as transient in the response
    assert trait.trait_key in response_traits
    assert not response_traits[trait.trait_key].get("transient")

    # every trait provided in the request for a transient identity
    # is marked as transient
    assert trait_key in response_traits
    assert response_traits[trait_key]["transient"]

    # the stored trait is untouched and the incoming one is not persisted
    persisted_trait = Trait.objects.filter(
        identity=identity, trait_key=trait.trait_key
    ).first()
    assert persisted_trait is not None
    assert persisted_trait.trait_value == trait.trait_value
    assert not Trait.objects.filter(identity=identity, trait_key=trait_key).exists()