from typing import Dict, Union
import numpy
from numpy.typing import ArrayLike
from policyengine_core import periods
from policyengine_core.periods import Period
[docs]class InMemoryStorage:
"""
Low-level class responsible for storing and retrieving calculated vectors in memory
"""
_arrays: Dict[Period, ArrayLike]
is_eternal: bool
def __init__(self, is_eternal: bool):
self._arrays = {}
self.is_eternal = is_eternal
[docs] def clone(self) -> "InMemoryStorage":
clone = InMemoryStorage(self.is_eternal)
clone._arrays = {
period: array.copy() for period, array in self._arrays.items()
}
return clone
[docs] def get(self, period: Period, branch_name: str = "default") -> ArrayLike:
if self.is_eternal:
period = periods.period(periods.ETERNITY)
period = periods.period(period)
values = self._arrays.get(f"{branch_name}:{period}")
if values is None:
return None
return values
[docs] def put(
self, value: ArrayLike, period: Period, branch_name: str = "default"
) -> None:
if self.is_eternal:
period = periods.period(periods.ETERNITY)
period = periods.period(period)
self._arrays[f"{branch_name}:{period}"] = value
[docs] def delete(
self, period: Period = None, branch_name: str = "default"
) -> None:
if period is None:
self._arrays = {}
return
if self.is_eternal:
period = periods.period(periods.ETERNITY)
period = periods.period(period)
self._arrays = {
period_item: value
for period_item, value in self._arrays.items()
if not period.contains(periods.period(period_item.split(":")[1]))
}
[docs] def get_known_periods(self) -> list:
return list(
map(lambda x: periods.period(x.split(":")[1]), self._arrays.keys())
)
[docs] def get_known_branch_periods(self) -> list:
return [
(branch_name, periods.period(period))
for branch_name, period in map(
lambda x: x.split(":"), self._arrays.keys()
)
]
[docs] def get_memory_usage(self) -> dict:
if not self._arrays:
return dict(
nb_arrays=0,
total_nb_bytes=0,
cell_size=numpy.nan,
)
nb_arrays = len(self._arrays)
array = next(iter(self._arrays.values()))
return dict(
nb_arrays=nb_arrays,
total_nb_bytes=array.nbytes * nb_arrays,
cell_size=array.itemsize,
)