# Source code for policyengine_core.tracers.computation_log

from __future__ import annotations

import typing
from typing import List, Optional, Union

import numpy

from policyengine_core.enums import EnumArray

from .. import tracers

if typing.TYPE_CHECKING:
    from numpy.typing import ArrayLike

    Array = Union[EnumArray, ArrayLike]


class ComputationLog:
    """Render the trace trees recorded by a ``FullTracer`` as text lines."""

    # Tracer whose ``trees`` are rendered by ``lines`` / ``print_log``.
    _full_tracer: tracers.FullTracer

    # Placeholder printed when a node has no value, or when its value cannot
    # be aggregated (non-numeric dtype, empty vector).
    _UNKNOWN_AGGREGATE = "{'avg': '?', 'max': '?', 'min': '?'}"

    def __init__(self, full_tracer: tracers.FullTracer) -> None:
        self._full_tracer = full_tracer

    def display(
        self,
        value: Optional[Array],
    ) -> str:
        """Format a computed vector as a single-line string.

        ``EnumArray`` values are decoded to their string names before
        formatting.
        """
        if isinstance(value, EnumArray):
            value = value.decode_to_str()

        # An infinite line width keeps the whole vector on one line.
        return numpy.array2string(value, max_line_width=float("inf"))

    def lines(
        self,
        aggregate: bool = False,
        max_depth: Optional[int] = None,
    ) -> List[str]:
        """Return the computation log as a flat list of lines.

        Args:
            aggregate: If ``True``, show avg/max/min instead of full vectors.
            max_depth: If not ``None``, omit nodes deeper than this; root
                nodes are at depth 1, so ``0`` yields no lines.

        Returns:
            One line per traced node, in depth-first pre-order, tree by tree.
        """
        depth = 1

        lines_by_tree = [
            self._get_node_log(node, depth, aggregate, max_depth)
            for node in self._full_tracer.trees
        ]

        return self._flatten(lines_by_tree)

    def print_log(self, aggregate=False, max_depth=100) -> None:
        """
        Print the computation log of a simulation.

        If ``aggregate`` is ``False`` (default), print the value of each
        computed vector.

        If ``aggregate`` is ``True``, only print the minimum, maximum, and
        average value of each computed vector. This mode is more suited for
        simulations on a large population.

        ``max_depth`` defaults to ``100``: only computed vectors up to a
        depth of 100 are printed. If ``max_depth`` is ``None``, print the
        entire computation. If ``max_depth`` is set, for example to ``3``,
        only print computed vectors up to a depth of ``max_depth``.
        """
        for line in self.lines(aggregate, max_depth):
            print(line)

    def _get_node_log(
        self,
        node: tracers.TraceNode,
        depth: int,
        aggregate: bool,
        max_depth: Optional[int],
    ) -> List[str]:
        """Return the lines for ``node`` followed by those of its children.

        Nodes deeper than ``max_depth`` (when it is not ``None``) are pruned
        together with their whole subtree.
        """
        if max_depth is not None and depth > max_depth:
            return []

        node_log = [self._print_line(depth, node, aggregate, max_depth)]

        children_logs = [
            self._get_node_log(child, depth + 1, aggregate, max_depth)
            for child in node.children
        ]

        return node_log + self._flatten(children_logs)

    def _print_line(
        self,
        depth: int,
        node: tracers.TraceNode,
        aggregate: bool,
        max_depth: Optional[int],
    ) -> str:
        """Format a single node as ``<indent>name<period, (branch)> = value``.

        ``max_depth`` is unused here; it is kept for signature compatibility.
        """
        indent = " " * depth
        value = node.value

        if value is None:
            formatted_value = self._UNKNOWN_AGGREGATE
        elif aggregate:
            formatted_value = self._format_aggregate(value)
        else:
            formatted_value = self.display(value)

        return f"{indent}{node.name}<{node.period}, ({node.branch_name})> = {formatted_value}"

    def _format_aggregate(self, value) -> str:
        """Return ``value``'s avg/max/min as a dict-literal string, or the
        unknown placeholder when they cannot be computed."""
        # numpy.max/min raise ValueError on a zero-size vector (and
        # numpy.mean warns), so an empty vector is reported as unknown.
        if numpy.size(value) == 0:
            return self._UNKNOWN_AGGREGATE

        try:
            stats = {
                "avg": numpy.mean(value),
                "max": numpy.max(value),
                "min": numpy.min(value),
            }
        except TypeError:
            # Non-numeric vectors (e.g. strings) cannot be reduced.
            return self._UNKNOWN_AGGREGATE

        return str({key: self._to_builtin(stat) for key, stat in stats.items()})

    @staticmethod
    def _to_builtin(scalar):
        """Convert a NumPy scalar to the equivalent Python scalar.

        Since NumPy 2, ``repr(numpy.float64(2.0))`` is ``np.float64(2.0)``;
        converting keeps the log reading ``2.0`` as under NumPy 1.x.
        Objects without an ``item`` method are returned unchanged.
        """
        item = getattr(scalar, "item", None)
        return item() if callable(item) else scalar

    def _flatten(
        self,
        lists: List[List[str]],
    ) -> List[str]:
        """Concatenate a list of line lists into a single list."""
        return [item for list_ in lists for item in list_]