Skip to content

Commit

Permalink
feat: memoization improvements (#81)
Browse files Browse the repository at this point in the history
### Summary of Changes

- shared memory
- lazy comparisons
- limited memory 
- value removal strategies
- use deterministic seed for hashing

Closes #44

Depends on Safe-DS/Library#609

---------

Co-authored-by: megalinter-bot <[email protected]>
  • Loading branch information
WinPlay02 and megalinter-bot authored Apr 16, 2024
1 parent 6f820bf commit 6bc2288
Show file tree
Hide file tree
Showing 9 changed files with 1,390 additions and 210 deletions.
32 changes: 29 additions & 3 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ python = "^3.11,<3.13"
safe-ds = ">=0.19,<0.21"
hypercorn = "^0.16.0"
quart = "^0.19.4"
psutil = "^5.9.8"

[tool.poetry.dev-dependencies]
pytest = "^8.1.1"
Expand Down
275 changes: 95 additions & 180 deletions src/safeds_runner/server/_memoization_map.py
Original file line number Diff line number Diff line change
@@ -1,85 +1,23 @@
"""Module that contains the memoization logic and stats."""
"""Module that contains the memoization logic."""

import dataclasses
import inspect
import functools
import logging
import sys
import operator
import time
from collections.abc import Callable
from dataclasses import dataclass
from typing import Any, TypeAlias
from typing import Any

MemoizationKey: TypeAlias = tuple[str, tuple[Any], tuple[Any]]
import psutil


@dataclass(frozen=True)
class MemoizationStats:
    """
    Statistics collected for every memoization call.

    Parameters
    ----------
    access_timestamps
        Absolute timestamps since the unix epoch of the accesses to the memoized value, in nanoseconds
    lookup_times
        Durations the lookups of the value took, in nanoseconds (key comparison + IPC)
    computation_times
        Durations the computations of the value took, in nanoseconds
    memory_sizes
        Amount of memory the memoized value takes up, in bytes
    """

    # The dataclass is frozen, so the fields cannot be rebound -- but the
    # list objects themselves stay mutable and are appended to on every call.
    access_timestamps: list[int] = dataclasses.field(default_factory=list)
    lookup_times: list[int] = dataclasses.field(default_factory=list)
    computation_times: list[int] = dataclasses.field(default_factory=list)
    memory_sizes: list[int] = dataclasses.field(default_factory=list)

    def update_on_hit(self, access_timestamp: int, lookup_time: int) -> None:
        """
        Record a cache hit.

        Parameters
        ----------
        access_timestamp
            Timestamp when this value was last accessed
        lookup_time
            Duration the comparison took in nanoseconds
        """
        self.access_timestamps.append(access_timestamp)
        self.lookup_times.append(lookup_time)

    def update_on_miss(self, access_timestamp: int, lookup_time: int, computation_time: int, memory_size: int) -> None:
        """
        Record a cache miss.

        Parameters
        ----------
        access_timestamp
            Timestamp when this value was last accessed
        lookup_time
            Duration the comparison took in nanoseconds
        computation_time
            Duration the computation of the new value took in nanoseconds
        memory_size
            Memory the newly computed value takes up in bytes
        """
        self.access_timestamps.append(access_timestamp)
        self.lookup_times.append(lookup_time)
        self.computation_times.append(computation_time)
        self.memory_sizes.append(memory_size)

    def __str__(self) -> str:
        """
        Summarize the stats contained in this object.

        Returns
        -------
        Summary of stats
        """
        return (  # pragma: no cover
            f"Last access: {self.access_timestamps}, "
            f"computation time: {self.computation_times}, "
            f"lookup time: {self.lookup_times}, "
            f"memory size: {self.memory_sizes}"
        )
from safeds_runner.server._memoization_stats import MemoizationStats
from safeds_runner.server._memoization_strategies import STAT_ORDER_PRIORITY
from safeds_runner.server._memoization_utils import (
MemoizationKey,
_create_memoization_key,
_get_size_of_value,
_unwrap_value_from_shared_memory,
_wrap_value_to_shared_memory,
)


class MemoizationMap:
Expand All @@ -106,6 +44,71 @@ def __init__(
"""
self._map_values: dict[MemoizationKey, Any] = map_values
self._map_stats: dict[str, MemoizationStats] = map_stats
# Set to half of physical available memory as a guess, in the future this could be set with an option
self.max_size: int | None = psutil.virtual_memory().total // 2
self.value_removal_strategy = STAT_ORDER_PRIORITY

def get_cache_size(self) -> int:
"""
Calculate the current size of the memoization cache.
Returns
-------
Amount of bytes, this cache occupies. This may be an estimate.
"""
return functools.reduce(
operator.add,
[functools.reduce(operator.add, stats.memory_sizes, 0) for stats in self._map_stats.values()],
0,
)

def ensure_capacity(self, needed_capacity: int) -> None:
"""
Ensure that the requested capacity is at least available, by freeing values from the cache.
If the needed capacity is larger than the max capacity, this function will not do anything to ensure further operation.
Parameters
----------
needed_capacity
Amount of free storage space requested, in bytes
"""
if self.max_size is None:
return
free_size = self.max_size - self.get_cache_size()
while free_size < needed_capacity < self.max_size:
self.remove_worst_element(needed_capacity - free_size)
free_size = self.max_size - self.get_cache_size()

def remove_worst_element(self, capacity_to_free: int) -> None:
"""
Remove the worst elements (most useless) from the cache, to free at least the provided amount of bytes.
Parameters
----------
capacity_to_free
Amount of bytes that should be additionally freed, after this function returns
"""
copied_stats = list(self._map_stats.copy().items())
# Sort functions to remove them from the cache in a specific order
copied_stats.sort(key=self.value_removal_strategy)
# Calculate which functions should be removed from the cache
bytes_freed = 0
functions_to_free = []
for function, stats in copied_stats:
if bytes_freed >= capacity_to_free:
break
function_sum_bytes = functools.reduce(operator.add, stats.memory_sizes, 0)
bytes_freed += function_sum_bytes
functions_to_free.append(function)
# Remove references to values, and let the gc handle the actual objects
for key in list(self._map_values.keys()):
for function_to_free in functions_to_free:
if key[0] == function_to_free:
del self._map_values[key]
# Remove stats, as content is gone
for function_to_free in functions_to_free:
del self._map_stats[function_to_free]

def memoized_function_call(
self,
Expand Down Expand Up @@ -140,11 +143,17 @@ def memoized_function_call(

# Lookup memoized value
lookup_time_start = time.perf_counter_ns()
key = self._create_memoization_key(function_name, parameters, hidden_parameters)
key = _create_memoization_key(function_name, parameters, hidden_parameters)
try:
memoized_value = self._lookup_value(key)
# Pickling may raise AttributeError, hashing may raise TypeError
except (AttributeError, TypeError):
except (AttributeError, TypeError) as exception:
# Fallback to executing the call to continue working, but inform user about this failure
logging.exception(
"Could not lookup value for function %s. Falling back to calling the function",
function_name,
exc_info=exception,
)
return function_callable(*parameters)
lookup_time = time.perf_counter_ns() - lookup_time_start

Expand All @@ -155,10 +164,15 @@ def memoized_function_call(

# Miss
computation_time_start = time.perf_counter_ns()
computed_value = self._compute_and_memoize_value(key, function_callable, parameters)
computed_value = function_callable(*parameters)
computation_time = time.perf_counter_ns() - computation_time_start
memory_size = _get_size_of_value(computed_value)

memoizable_value = _wrap_value_to_shared_memory(computed_value)
if self.max_size is not None:
self.ensure_capacity(_get_size_of_value(memoized_value))
self._map_values[key] = memoizable_value

self._update_stats_on_miss(
function_name,
access_timestamp,
Expand All @@ -178,30 +192,6 @@ def memoized_function_call(

return computed_value

def _create_memoization_key(
    self,
    function_name: str,
    parameters: list[Any],
    hidden_parameters: list[Any],
) -> MemoizationKey:
    """
    Convert values provided to a memoized function call to a memoization key.

    Parameters
    ----------
    function_name
        Fully qualified function name
    parameters
        List of parameters passed to the function
    hidden_parameters
        List of parameters not passed to the function

    Returns
    -------
    A memoization key, which contains the lists converted to tuples
    """
    hashable_parameters = _make_hashable(parameters)
    hashable_hidden_parameters = _make_hashable(hidden_parameters)
    return (function_name, hashable_parameters, hashable_hidden_parameters)

def _lookup_value(self, key: MemoizationKey) -> Any | None:
"""
Lookup a potentially existing value from the memoization cache.
Expand All @@ -215,33 +205,8 @@ def _lookup_value(self, key: MemoizationKey) -> Any | None:
-------
The value corresponding to the provided memoization key, if any exists.
"""
return self._map_values.get(key)

def _compute_and_memoize_value(
    self,
    key: MemoizationKey,
    function_callable: Callable,
    parameters: list[Any],
) -> Any:
    """
    Memoize a new function call and return the computed result.

    Parameters
    ----------
    key
        Memoization Key
    function_callable
        Function that will be called
    parameters
        List of parameters passed to the function

    Returns
    -------
    The newly computed value corresponding to the provided memoization key
    """
    computed_value = function_callable(*parameters)
    self._map_values[key] = computed_value
    return computed_value
looked_up_value = self._map_values.get(key)
return _unwrap_value_from_shared_memory(looked_up_value)

def _update_stats_on_hit(self, function_name: str, access_timestamp: int, lookup_time: int) -> None:
"""
Expand Down Expand Up @@ -293,53 +258,3 @@ def _update_stats_on_miss(

stats.update_on_miss(access_timestamp, lookup_time, computation_time, memory_size)
self._map_stats[function_name] = stats


def _make_hashable(value: Any) -> Any:
"""
Make a value hashable.
Parameters
----------
value:
Value to be converted.
Returns
-------
converted_value:
Converted value.
"""
if isinstance(value, dict):
return tuple((_make_hashable(key), _make_hashable(value)) for key, value in value.items())
elif isinstance(value, list):
return tuple(_make_hashable(element) for element in value)
elif callable(value):
# This is a band-aid solution to make callables serializable. Unfortunately, `getsource` returns more than just
# the source code for lambdas.
return inspect.getsource(value)
else:
return value


def _get_size_of_value(value: Any) -> int:
"""
Recursively calculate the memory usage of a given value.
Parameters
----------
value
Any value of which the memory usage should be calculated.
Returns
-------
Size of the provided value in bytes
"""
size_immediate = sys.getsizeof(value)
if isinstance(value, dict):
return (
sum(map(_get_size_of_value, value.keys())) + sum(map(_get_size_of_value, value.values())) + size_immediate
)
elif isinstance(value, frozenset | list | set | tuple):
return sum(map(_get_size_of_value, value)) + size_immediate
else:
return size_immediate
Loading

0 comments on commit 6bc2288

Please sign in to comment.