Source code for policyengine_core.parameters.parameter_node

import copy
import os
import typing
from typing import Iterable, List, Type, Union

from policyengine_core import commons, parameters, tools
from policyengine_core.data_structures import Reference
from policyengine_core.periods.instant_ import Instant
from policyengine_core.tracers import TracingParameterNodeAtInstant

from .at_instant_like import AtInstantLike
from .parameter import Parameter
from .parameter_node_at_instant import ParameterNodeAtInstant
from .config import COMMON_KEYS, FILE_EXTENSIONS
from .helpers import (
    load_parameter_file,
    _compose_name,
    _validate_parameter,
    _parse_child,
    _load_yaml_file,
)

# Child names that ParameterNode never loads as sub-parameters, whether they
# appear as YAML file names in a parameter directory or as keys of a node's
# data dict.
EXCLUDED_PARAMETER_CHILD_NAMES = [
    "reference",
    "__pycache__",
]


class ParameterNode(AtInstantLike):
    """
    A node in the legislation `parameter tree <https://openfisca.org/doc/coding-the-legislation/legislation_parameters.html>`_.
    """

    _allowed_keys: typing.Optional[typing.Iterable[str]] = (
        None  # By default, no restriction on the keys
    )

    parent: "ParameterNode" = None
    """The parent of the node, or None if the node is the root of the tree."""

    def __init__(
        self,
        name: str = "",
        directory_path: str = None,
        data: dict = None,
        file_path: str = None,
    ):
        """
        Instantiate a ParameterNode either from a dict, (using `data`), or from a directory containing YAML files (using `directory_path`).

        :param str name: Name of the node, eg "taxes.some_tax".
        :param str directory_path: Directory containing YAML files describing the node.
        :param dict data: Object representing the parameter node. It usually has been extracted from a YAML file.
        :param str file_path: YAML file from which the `data` has been extracted from.

        When `directory_path` is given, `data` is ignored.

        Instantiate a ParameterNode from a dict:

        >>> node = ParameterNode('basic_income', data = {
            'amount': {
              'values': {
                "2015-01-01": {'value': 550},
                "2016-01-01": {'value': 600}
                }
              },
            'min_age': {
              'values': {
                "2015-01-01": {'value': 25},
                "2016-01-01": {'value': 18}
                }
              },
            })

        Instantiate a ParameterNode from a directory containing YAML parameter files:

        >>> node = ParameterNode('benefits', directory_path = '/path/to/country_package/parameters/benefits')
        """
        self.name: str = name
        self.children: typing.Dict[
            str,
            typing.Union[ParameterNode, Parameter, parameters.ParameterScale],
        ] = {}
        self.description: str = None
        self.documentation: str = None
        self.file_path: str = None
        self.metadata: dict = {}
        self.trace: bool = False
        self.tracer = None
        self.branch_name = None
        self._at_instant_cache: typing.Dict[
            Instant, ParameterNodeAtInstant
        ] = {}
        self.parent = None

        if directory_path:
            self._load_from_directory(directory_path)
        else:
            self._load_from_data(data, file_path)

        self.modified: bool = False

    def _load_from_directory(self, directory_path: str) -> None:
        """Populate this node from the files and sub-directories of `directory_path`."""
        self.file_path = directory_path
        # os.listdir returns entries in arbitrary order. Sorting keeps the
        # child order, and the precedence between README and index metadata,
        # deterministic across machines.
        for entry_name in sorted(os.listdir(directory_path)):
            entry_path = os.path.join(directory_path, entry_name)
            if os.path.isfile(entry_path):
                self._load_directory_file(entry_path, entry_name)
            elif os.path.isdir(entry_path):
                child_name = os.path.basename(entry_path)
                # Excluded names (e.g. `__pycache__`) are directories in
                # practice, so they must be skipped here too.
                if child_name in EXCLUDED_PARAMETER_CHILD_NAMES:
                    continue
                child = ParameterNode(
                    _compose_name(self.name, child_name),
                    directory_path=entry_path,
                )
                self.add_child(child_name, child)

    def _load_directory_file(self, entry_path: str, entry_name: str) -> None:
        """Handle one regular file found while loading a parameter directory."""
        child_name, ext = os.path.splitext(entry_name)
        if child_name.upper() == "README":
            self._load_readme(entry_path)
        # We ignore non-YAML files
        if ext not in FILE_EXTENSIONS:
            return
        if child_name == "index":
            # index.<ext> holds the node's own description/metadata, not a child.
            index_data = _load_yaml_file(entry_path) or {}
            _validate_parameter(self, index_data, allowed_keys=COMMON_KEYS)
            self.description = index_data.get("description")
            self.documentation = index_data.get("documentation")
            self.metadata.update(index_data.get("metadata", {}))
        elif child_name not in EXCLUDED_PARAMETER_CHILD_NAMES:
            child = load_parameter_file(
                entry_path, _compose_name(self.name, child_name)
            )
            self.add_child(child_name, child)

    def _load_readme(self, readme_path: str) -> None:
        """
        Read a README file into `self.metadata`.

        The first line (with any leading hashes removed) becomes the label.
        The remaining lines become the description. An empty README is
        ignored rather than crashing.
        """
        with open(readme_path, "r", encoding="utf-8") as readme:
            lines = readme.readlines()
        if not lines:
            return
        self.metadata.update(
            {
                # Strip only the leading markdown hashes; "# " occurring later
                # in the title is kept.
                "label": lines[0].strip().lstrip("#").strip(),
                "description": "".join(lines[1:]).strip(),
            }
        )

    def _load_from_data(self, data: dict, file_path: str) -> None:
        """Populate this node from a dict, typically extracted from a YAML file."""
        self.file_path = file_path
        _validate_parameter(
            self, data, data_type=dict, allowed_keys=self._allowed_keys
        )
        self.description = data.get("description")
        self.documentation = data.get("documentation")
        self.metadata.update(data.get("metadata", {}))
        for key, value in data.items():
            if key in COMMON_KEYS or key in EXCLUDED_PARAMETER_CHILD_NAMES:
                continue  # do not treat reserved keys as subparameters.
            child_name = str(key)
            child = _parse_child(
                _compose_name(self.name, child_name), value, file_path
            )
            self.add_child(child_name, child)

    def merge(self, other: "ParameterNode") -> None:
        """
        Merges another ParameterNode into the current node.

        In case of child name conflict, the other node child will replace the current node child.
        Merged children are re-parented to this node.
        """
        for child_name, child in other.children.items():
            self._set_child(child_name, child)

    def add_child(self, name: str, child: Union["ParameterNode", Parameter]):
        """
        Add a new child to the node.

        :param name: Name of the child that must be used to access that child. Should not contain anything that could interfere with the operator `.` (dot).
        :param child: The new child, an instance of :class:`.ParameterScale` or :class:`.Parameter` or :class:`.ParameterNode`.
        :raises ValueError: If a child with this name already exists.
        :raises TypeError: If `child` is not a ParameterNode, Parameter or ParameterScale.
        """
        if name in self.children:
            raise ValueError(
                "{} has already a child named {}".format(self.name, name)
            )
        self._set_child(name, child)

    def _set_child(
        self, name: str, child: Union["ParameterNode", Parameter]
    ) -> None:
        """Type-check `child` and register it under `name`, replacing any existing child."""
        if not isinstance(
            child, (ParameterNode, Parameter, parameters.ParameterScale)
        ):
            raise TypeError(
                "child must be of type ParameterNode, Parameter, or Scale. Instead got {}".format(
                    type(child)
                )
            )
        self.children[name] = child
        # Expose the child as an attribute so `node.child_name` works.
        setattr(self, name, child)
        child.parent = self

    def __repr__(self) -> str:
        result = os.linesep.join(
            [
                os.linesep.join(["{}:", "{}"]).format(
                    name, tools.indent(repr(value))
                )
                for name, value in sorted(self.children.items())
            ]
        )
        return result

    def get_descendants(self) -> Iterable[Union["ParameterNode", Parameter]]:
        """
        Return a generator containing all the parameters and nodes recursively contained in this `ParameterNode`
        """
        for child in self.children.values():
            yield child
            yield from child.get_descendants()

    def clone(self) -> "ParameterNode":
        """
        Return a deep copy of this node and all of its children.

        Cloned children point at the cloned node as their parent. Changes
        propagated through `parent` therefore stay inside the cloned tree.
        """
        clone = commons.empty_clone(self)
        clone.__dict__ = self.__dict__.copy()

        clone.metadata = copy.deepcopy(self.metadata)
        clone.children = {
            key: child.clone() for key, child in self.children.items()
        }
        for child_key, child in clone.children.items():
            setattr(clone, child_key, child)
            # child.clone() copies the original `parent` reference; without
            # this, mark_as_modified / clear_parent_cache on a cloned child
            # would reach the ORIGINAL tree.
            child.parent = clone

        clone._at_instant_cache = {}
        return clone

    def _get_at_instant(self, instant: Instant) -> ParameterNodeAtInstant:
        """
        Return (and memoise) this node's view at `instant`.

        When `self.trace` is set, the view is wrapped in a
        TracingParameterNodeAtInstant.
        """
        # NOTE(review): the cached object does not depend on `self.trace` at
        # lookup time; if tracing is toggled after a first access, the stale
        # (un)traced view is returned — confirm callers clear the cache.
        if instant in self._at_instant_cache:
            return self._at_instant_cache[instant]
        node_at_instant = ParameterNodeAtInstant(self.name, self, instant)
        if self.trace:
            at_instant = TracingParameterNodeAtInstant(
                node_at_instant, self.tracer, self.branch_name
            )
        else:
            at_instant = node_at_instant
        self._at_instant_cache[instant] = at_instant
        return at_instant

    def attach_to_parent(self, parent: "ParameterNode"):
        """Set `parent` as this node's parent (does not register it as a child)."""
        self.parent = parent

    def clear_parent_cache(self):
        """Clear the at-instant cache of this node and of all its ancestors."""
        if self.parent is not None:
            self.parent.clear_parent_cache()
        self._at_instant_cache.clear()

    def mark_as_modified(self):
        """Flag this node and all of its ancestors as modified."""
        self.modified = True
        if self.parent is not None:
            self.parent.mark_as_modified()

    def get_child(self, path: str) -> "ParameterNode":
        """
        Return the descendant designated by a dotted `path`.

        Each dot-separated segment is a child name. A segment of the form
        `name[i]` selects bracket `i` of the scale child `name`, e.g.
        `tax[3].rate`. The result may be a node, a parameter, a scale or a
        bracket.

        :raises ValueError: If a segment has malformed bracket syntax, or if
            a segment cannot be found.
        """
        node = self
        for segment in path.split("."):
            if "[" not in segment:
                try:
                    node = node.children[segment]
                except (AttributeError, KeyError, TypeError) as error:
                    raise ValueError(
                        f"Could not find the parameter (failed at {segment})."
                    ) from error
                continue

            try:
                name, index_text = segment.split("[")
                if not index_text.endswith("]"):
                    raise ValueError(segment)
                index = int(index_text[:-1])
            except ValueError as error:
                raise ValueError(
                    "Invalid bracket syntax (should be e.g. tax[3].rate)"
                ) from error

            try:
                node = node.children[name].brackets[index]
            except (AttributeError, KeyError, IndexError, TypeError) as error:
                raise ValueError(
                    f"Could not find the parameter (failed at {name})."
                ) from error
        return node