Skip to content

Commit

Permalink
Add a conversion factor cache
Browse files Browse the repository at this point in the history
  • Loading branch information
hgrecco committed Dec 3, 2023
1 parent 98fbda4 commit 7aa995c
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 11 deletions.
1 change: 1 addition & 0 deletions pint/facets/context/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ def __init__(self, registry_cache) -> None:
self.root_units = {}
self.dimensionality = registry_cache.dimensionality
self.parse_unit = registry_cache.parse_unit
self.conversion_factor = {}


class GenericContextRegistry(
Expand Down
56 changes: 45 additions & 11 deletions pint/facets/plain/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,10 @@ def __init__(self) -> None:
#: Cache the unit name associated to user input. ('mV' -> 'millivolt')
self.parse_unit: dict[str, UnitsContainer] = {}

self.conversion_factor: dict[
tuple[UnitsContainer, UnitsContainer], Scalar | DimensionalityError
] = {}

def __eq__(self, other: Any):
if not isinstance(other, self.__class__):
return False
Expand All @@ -139,6 +143,7 @@ def __eq__(self, other: Any):
"root_units",
"dimensionality",
"parse_unit",
"conversion_factor",
)
return all(getattr(self, attr) == getattr(other, attr) for attr in attrs)

Expand Down Expand Up @@ -801,6 +806,43 @@ def get_root_units(

return f, self.Unit(units)

def _get_conversion_factor(
self, src: UnitsContainer, dst: UnitsContainer
) -> Scalar | DimensionalityError:
"""Get conversion factor in non-multiplicative units.
Parameters
----------
src
Source units
dst
Target units
Returns
-------
Conversion factor or DimensionalityError
"""
cache = self._cache.conversion_factor
try:
return cache[(src, dst)]
except KeyError:
pass

src_dim = self._get_dimensionality(src)
dst_dim = self._get_dimensionality(dst)

# If the source and destination dimensionality are different,
# then the conversion cannot be performed.
if src_dim != dst_dim:
return DimensionalityError(src, dst, src_dim, dst_dim)

# Here src and dst have only multiplicative units left. Thus we can
# convert with a factor.
factor, _ = self._get_root_units(src / dst)

cache[(src, dst)] = factor
return factor

def _get_root_units(
self, input_units: UnitsContainer, check_nonmult: bool = True
) -> tuple[Scalar, UnitsContainer]:
Expand Down Expand Up @@ -1015,18 +1057,10 @@ def _convert(
"""

if check_dimensionality:
src_dim = self._get_dimensionality(src)
dst_dim = self._get_dimensionality(dst)

# If the source and destination dimensionality are different,
# then the conversion cannot be performed.
if src_dim != dst_dim:
raise DimensionalityError(src, dst, src_dim, dst_dim)
factor = self._get_conversion_factor(src, dst)

# Here src and dst have only multiplicative units left. Thus we can
# convert with a factor.
factor, _ = self._get_root_units(src / dst)
if isinstance(factor, DimensionalityError):
raise factor

# factor is type float and if our magnitude is type Decimal then
# must first convert to Decimal before we can '*' the values
Expand Down

0 comments on commit 7aa995c

Please sign in to comment.