import os
import shutil
import numpy
from numpy.typing import ArrayLike
from policyengine_core import periods
from policyengine_core.enums import EnumArray
from policyengine_core.periods import Period
[docs]class OnDiskStorage:
"""
Low-level class responsible for storing and retrieving calculated vectors on disk
"""
def __init__(
self,
storage_dir: str,
is_eternal: bool = False,
preserve_storage_dir: bool = False,
):
self._files = {}
self._enums = {}
self.is_eternal = is_eternal
self.preserve_storage_dir = preserve_storage_dir
self.storage_dir = storage_dir
def _decode_file(self, file: str) -> ArrayLike:
enum = self._enums.get(file)
if enum is not None:
return EnumArray(numpy.load(file), enum)
else:
return numpy.load(file)
[docs] def get(self, period: Period, branch_name: str = "default") -> ArrayLike:
if self.is_eternal:
period = periods.period(periods.ETERNITY)
period = periods.period(period)
values = self._files.get(f"{branch_name}_{period}")
if values is None:
return None
return self._decode_file(values)
[docs] def put(
self, value: ArrayLike, period: Period, branch_name: str = "default"
) -> None:
if self.is_eternal:
period = periods.period(periods.ETERNITY)
period = periods.period(period)
filename = f"{branch_name}_{period}"
path = os.path.join(self.storage_dir, filename) + ".npy"
if isinstance(value, EnumArray):
self._enums[path] = value.possible_values
value = value.view(numpy.ndarray)
numpy.save(path, value)
self._files[filename] = path
[docs] def delete(
self, period: Period = None, branch_name: str = "default"
) -> None:
if period is None:
self._files = {}
return
if self.is_eternal:
period = periods.period(periods.ETERNITY)
period = periods.period(period)
if period is not None:
self._files = {
period_item: value
for period_item, value in self._files.items()
if not period_item == f"{branch_name}_{period}"
}
[docs] def get_known_periods(self) -> list:
return list(
[periods.period(x.split("_")[1]) for x in self._files.keys()]
)
[docs] def get_known_branch_periods(self) -> list:
return [
(branch_name, periods.period(period))
for branch_name, period in map(
lambda x: x.split("_"), self._files.keys()
)
]
[docs] def restore(self) -> None:
self._files = files = {}
# Restore self._files from content of storage_dir.
for filename in os.listdir(self.storage_dir):
if not filename.endswith(".npy"):
continue
path = os.path.join(self.storage_dir, filename)
filename_core = filename.rsplit(".", 1)[0]
files[filename_core] = path
def __del__(self) -> None:
if self.preserve_storage_dir:
return
shutil.rmtree(self.storage_dir) # Remove the holder temporary files
# If the simulation temporary directory is empty, remove it
parent_dir = os.path.abspath(os.path.join(self.storage_dir, os.pardir))
if not os.listdir(parent_dir):
shutil.rmtree(parent_dir)