Skip to content

Commit

Permalink
Stable cache keys in the case of nested dictionaries
Browse files Browse the repository at this point in the history
Signed-off-by: Eduardo Apolinario <[email protected]>
  • Loading branch information
eapolinario committed Oct 10, 2022
1 parent 9b8f255 commit 937db60
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 17 deletions.
22 changes: 6 additions & 16 deletions flytekit/core/local_cache.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
from typing import Optional

import joblib
from diskcache import Cache
from google.protobuf.struct_pb2 import Struct
from joblib.hashing import NumpyHasher

from flytekit.models.literals import Literal, LiteralCollection, LiteralMap

Expand All @@ -28,26 +27,17 @@ def _recursive_hash_placement(literal: Literal) -> Literal:
return literal


class ProtoJoblibHasher(NumpyHasher):
def save(self, obj):
if isinstance(obj, Struct):
obj = dict(
rewrite_rule="google.protobuf.struct_pb2.Struct",
cls=obj.__class__,
obj=dict(sorted(obj.fields.items())),
)
NumpyHasher.save(self, obj)


def _calculate_cache_key(task_name: str, cache_version: str, input_literal_map: LiteralMap) -> str:
# Traverse the literals and replace the literal with a new literal that only contains the hash
literal_map_overridden = {}
for key, literal in input_literal_map.literals.items():
literal_map_overridden[key] = _recursive_hash_placement(literal)

# Generate a hash key of inputs with joblib
hashed_inputs = ProtoJoblibHasher().hash(literal_map_overridden)
return f"{task_name}-{cache_version}-{hashed_inputs}"
# Generate a stable representation of the underlying protobuf by passing `deterministic=True` to the
# protobuf library.
hashed_inputs = LiteralMap(literal_map_overridden).to_flyte_idl().SerializeToString(deterministic=True)
# Use joblib to hash the string representation of the literal into a fixed length string
return f"{task_name}-{cache_version}-{joblib.hash(hashed_inputs)}"


class LocalTaskCache(object):
Expand Down
47 changes: 46 additions & 1 deletion tests/flytekit/unit/core/test_local_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
from typing import Dict, List

import pandas
import pandas as pd
import pytest
from dataclasses_json import dataclass_json
from pytest import fixture
from typing_extensions import Annotated
Expand Down Expand Up @@ -426,6 +428,14 @@ def test_stable_cache_key():
"b": "abcd",
"c": 0.12349,
"d": [1, 2, 3],
"e": {
"e_a": 11,
"e_b": list(range(1000)),
"e_c": {
"e_c_a": 12.34,
"e_c_b": "a string",
},
},
}
lit = TypeEngine.to_literal(ctx, kwargs, Dict, lt)
lm = LiteralMap(
Expand All @@ -437,4 +447,39 @@ def test_stable_cache_key():
}
)
key = _calculate_cache_key("task_name_1", "31415", lm)
assert key == "task_name_1-31415-a291dc6fe0be387c1cfd67b4c6b78259"
assert key == "task_name_1-31415-404b45f8556276183621d4bf37f50049"


def calculate_cache_key_multiple_times(x, n=1000):
series = pd.Series(
[
_calculate_cache_key(
task_name="task_name",
cache_version="cache_version",
input_literal_map=LiteralMap(
literals={
"d": TypeEngine.to_literal(
ctx=FlyteContextManager.current_context(),
expected=TypeEngine.to_literal_type(Dict),
python_type=Dict,
python_val=x,
),
}
),
)
for _ in range(n)
]
).value_counts()
return series


@pytest.mark.parametrize(
"d",
[
dict(a=1, b=2, c=3),
dict(x=dict(a=1, b=2, c=3)),
dict(xs=[dict(a=1, b=2, c=3), dict(y=dict(a=10, b=20, c=30))]),
],
)
def test_cache_key_consistency(d):
assert len(calculate_cache_key_multiple_times(d)) == 1

0 comments on commit 937db60

Please sign in to comment.