# Source code for policyengine_core.tracers.trace_node
from __future__ import annotations
import dataclasses
import typing
# Names below are only needed by static type checkers: typing.TYPE_CHECKING
# is False at runtime, so this block is skipped during normal execution and
# neither the imports nor the aliases exist as runtime objects. They may
# therefore only be used inside annotations, which this module keeps
# unevaluated via `from __future__ import annotations`.
if typing.TYPE_CHECKING:
    import numpy
    from policyengine_core.enums import EnumArray
    # NOTE(review): Period is not referenced in the visible part of this
    # module — confirm whether an annotation was meant to use it.
    from policyengine_core.periods import Period
    # Annotation alias for a traced value: an EnumArray or anything
    # array-like according to numpy.typing.
    Array = typing.Union[EnumArray, numpy.typing.ArrayLike]
    # Annotation alias for a time measurement (float or int).
    Time = typing.Union[float, int]
@dataclasses.dataclass
class TraceNode:
    """A node in a tree of traced calculations.

    Each node stores an identifying ``name``/``period``/``branch_name``,
    links to its ``parent`` and ``children``, an optional ``value``, and the
    ``start``/``end`` timestamps from which timings are derived.

    NOTE(review): the dataclass-generated ``__eq__`` compares every field,
    including ``parent`` and ``children``. When nodes reference each other
    through those fields, prefer identity checks (``is``) over ``==`` —
    confirm that no caller relies on structural equality.
    """

    name: str
    period: str
    branch_name: str = "default"
    # Link to the enclosing node; ``append_child`` does NOT set it, so the
    # caller is responsible for keeping it consistent.
    parent: typing.Optional[TraceNode] = None
    # Sub-nodes; their calculation times are subtracted in ``formula_time``.
    children: typing.List[TraceNode] = dataclasses.field(default_factory=list)
    # Additional nodes kept separately from ``children``; no method of this
    # class reads them.
    parameters: typing.List[TraceNode] = dataclasses.field(
        default_factory=list
    )
    value: typing.Optional[Array] = None
    # Timestamps delimiting the calculation; only their difference is used
    # here, so the unit is whatever the caller's clock provides.
    start: float = 0
    end: float = 0

    def calculation_time(self, round_: bool = True) -> Time:
        """Return the elapsed time between ``start`` and ``end``.

        Args:
            round_: When true (the default), the result is passed through
                :meth:`round` and limited to 4 significant figures; when
                false, the raw difference is returned.

        Returns:
            ``end - start``, rounded or not depending on ``round_``.
        """
        result = self.end - self.start

        if round_:
            return self.round(result)

        return result

    def formula_time(self) -> float:
        """Return this node's own time, excluding time spent in children.

        The unrounded calculation times of all direct children are summed
        and subtracted from this node's unrounded calculation time; rounding
        happens once, on the final result, to avoid accumulating rounding
        error.

        Returns:
            The remaining time, rounded to 4 significant figures.
        """
        children_calculation_time = sum(
            child.calculation_time(round_=False) for child in self.children
        )
        result = (
            self.calculation_time(round_=False) - children_calculation_time
        )

        return self.round(result)

    def append_child(self, node: TraceNode) -> None:
        """Append ``node`` to ``children`` (``node.parent`` is left as is)."""
        self.children.append(node)

    @staticmethod
    def round(time: Time) -> float:
        """Return ``time`` as a float limited to 4 significant figures.

        The value is formatted with the ``.4g`` format specifier and parsed
        back, so both ints and floats are accepted and a float is always
        returned.
        """
        return float(f"{time:.4g}")  # Keep only 4 significant figures