diff --git a/flytekit/core/local_cache.py b/flytekit/core/local_cache.py index 48b6f9c7da..11cb3b926c 100644 --- a/flytekit/core/local_cache.py +++ b/flytekit/core/local_cache.py @@ -1,8 +1,7 @@ from typing import Optional +import joblib from diskcache import Cache -from google.protobuf.struct_pb2 import Struct -from joblib.hashing import NumpyHasher from flytekit.models.literals import Literal, LiteralCollection, LiteralMap @@ -28,26 +27,17 @@ def _recursive_hash_placement(literal: Literal) -> Literal: return literal -class ProtoJoblibHasher(NumpyHasher): - def save(self, obj): - if isinstance(obj, Struct): - obj = dict( - rewrite_rule="google.protobuf.struct_pb2.Struct", - cls=obj.__class__, - obj=dict(sorted(obj.fields.items())), - ) - NumpyHasher.save(self, obj) - - def _calculate_cache_key(task_name: str, cache_version: str, input_literal_map: LiteralMap) -> str: # Traverse the literals and replace the literal with a new literal that only contains the hash literal_map_overridden = {} for key, literal in input_literal_map.literals.items(): literal_map_overridden[key] = _recursive_hash_placement(literal) - # Generate a hash key of inputs with joblib - hashed_inputs = ProtoJoblibHasher().hash(literal_map_overridden) - return f"{task_name}-{cache_version}-{hashed_inputs}" + # Generate a stable representation of the underlying protobuf by passing `deterministic=True` to the + # protobuf library. + hashed_inputs = LiteralMap(literal_map_overridden).to_flyte_idl().SerializeToString(deterministic=True) + # Use joblib to hash the string representation of the literal into a fixed length string + return f"{task_name}-{cache_version}-{joblib.hash(hashed_inputs)}" class LocalTaskCache(object): diff --git a/tests/flytekit/unit/core/test_local_cache.py b/tests/flytekit/unit/core/test_local_cache.py index 4648c87e2b..c3baebeace 100644 --- a/tests/flytekit/unit/core/test_local_cache.py +++ b/tests/flytekit/unit/core/test_local_cache.py @@ -4,6 +4,8 @@ from typing import Dict, List import pandas +import pandas as pd +import pytest from dataclasses_json import dataclass_json from pytest import fixture from typing_extensions import Annotated @@ -426,6 +428,14 @@ def test_stable_cache_key(): "b": "abcd", "c": 0.12349, "d": [1, 2, 3], + "e": { + "e_a": 11, + "e_b": list(range(1000)), + "e_c": { + "e_c_a": 12.34, + "e_c_b": "a string", + }, + }, } lit = TypeEngine.to_literal(ctx, kwargs, Dict, lt) lm = LiteralMap( @@ -437,4 +447,39 @@ def test_stable_cache_key(): } ) key = _calculate_cache_key("task_name_1", "31415", lm) - assert key == "task_name_1-31415-a291dc6fe0be387c1cfd67b4c6b78259" + assert key == "task_name_1-31415-404b45f8556276183621d4bf37f50049" + + +def calculate_cache_key_multiple_times(x, n=1000): + series = pd.Series( + [ + _calculate_cache_key( + task_name="task_name", + cache_version="cache_version", + input_literal_map=LiteralMap( + literals={ + "d": TypeEngine.to_literal( + ctx=FlyteContextManager.current_context(), + expected=TypeEngine.to_literal_type(Dict), + python_type=Dict, + python_val=x, + ), + } + ), + ) + for _ in range(n) + ] + ).value_counts() + return series + + +@pytest.mark.parametrize( + "d", + [ + dict(a=1, b=2, c=3), + dict(x=dict(a=1, b=2, c=3)), + dict(xs=[dict(a=1, b=2, c=3), dict(y=dict(a=10, b=20, c=30))]), + ], +) +def test_cache_key_consistency(d): + assert len(calculate_cache_key_multiple_times(d)) == 1