Source code for policyengine_core.simulations.simulation_builder

import copy
import typing
from typing import TYPE_CHECKING, Any, List

import dpath.util
import numpy as np
from numpy.typing import ArrayLike

from policyengine_core import periods
from policyengine_core.entities import Entity, Role
from policyengine_core.errors import (
    PeriodMismatchError,
    SituationParsingError,
    VariableNotFoundError,
)
from policyengine_core.periods import Period
from policyengine_core.populations import GroupPopulation, Population
from policyengine_core.simulations.simulation import Simulation
from policyengine_core.simulations.helpers import (
    check_type,
    _get_person_count,
    transform_to_strict_syntax,
)

if TYPE_CHECKING:
    from policyengine_core.taxbenefitsystems.tax_benefit_system import (
        TaxBenefitSystem,
    )

from policyengine_core.variables import Variable


[docs]class SimulationBuilder: def __init__(self): self.default_period: Period = ( None # Simulation period used for variables when no period is defined ) self.persons_plural: str = ( None # Plural name for person entity in current tax and benefits system ) # JSON input - Memory of known input values. Indexed by variable or axis name. self.input_buffer: typing.Dict[ Variable.name, typing.Dict[str(periods.period), np.array] ] = {} self.populations: typing.Dict[Entity.key, Population] = {} # JSON input - Number of items of each entity type. Indexed by entities plural names. Should be consistent with ``entity_ids``, including axes. self.entity_counts: typing.Dict[Entity.plural, int] = {} # JSON input - typing.List of items of each entity type. Indexed by entities plural names. Should be consistent with ``entity_counts``. self.entity_ids: typing.Dict[Entity.plural, typing.List[int]] = {} # Links entities with persons. For each person index in persons ids list, set entity index in entity ids id. E.g.: self.memberships[entity.plural][person_index] = entity_ids.index(instance_id) self.memberships: typing.Dict[Entity.plural, typing.List[int]] = {} self.roles: typing.Dict[Entity.plural, typing.List[int]] = {} self.variable_entities: typing.Dict[Variable.name, Entity] = {} self.axes = [[]] self.has_axes = False self.axes_entity_counts: typing.Dict[Entity.plural, int] = {} self.axes_entity_ids: typing.Dict[Entity.plural, typing.List[int]] = {} self.axes_memberships: typing.Dict[Entity.plural, typing.List[int]] = ( {} ) self.axes_roles: typing.Dict[Entity.plural, typing.List[int]] = {}
[docs] def build_from_dict( self, tax_benefit_system: "TaxBenefitSystem", input_dict: dict, simulation: Simulation = None, ) -> Simulation: """ Build a simulation from ``input_dict``. Optionally overwrites an existing simulation. This method uses :any:`build_from_entities` if entities are fully specified, or :any:`build_from_variables` if not. """ input_dict = self.explicit_singular_entities( tax_benefit_system, input_dict ) if any( key in tax_benefit_system.entities_plural() for key in input_dict.keys() ): simulation = self.build_from_entities( tax_benefit_system, input_dict, simulation ) else: simulation = self.build_from_variables( tax_benefit_system, input_dict, simulation ) simulation.input_variables = [ variable.name for variable in simulation.tax_benefit_system.variables.values() if len(simulation.get_holder(variable.name).get_known_periods()) > 0 ] return simulation
[docs] def build_from_entities( self, tax_benefit_system: "TaxBenefitSystem", input_dict: dict, simulation: Simulation = None, ) -> Simulation: """ Build a simulation from a Python dict ``input_dict`` fully specifying entities. Examples: >>> simulation_builder.build_from_entities({ 'persons': {'Javier': { 'salary': {'2018-11': 2000}}}, 'households': {'household': {'parents': ['Javier']}} }) """ input_dict = copy.deepcopy(input_dict) if simulation is None: simulation = Simulation( tax_benefit_system=tax_benefit_system, populations=tax_benefit_system.instantiate_entities(), ) # Register variables so get_variable_entity can find them for variable_name, _variable in tax_benefit_system.variables.items(): self.register_variable( variable_name, simulation.get_variable_population(variable_name).entity, ) check_type(input_dict, dict, ["error"]) axes = input_dict.pop("axes", None) unexpected_entities = [ entity for entity in input_dict if entity not in tax_benefit_system.entities_plural() ] if unexpected_entities: unexpected_entity = unexpected_entities[0] raise SituationParsingError( [unexpected_entity], "".join( [ "Some entities in the situation are not defined in the loaded tax and benefit system.", "These entities are not found: {0}.", "The defined entities are: {1}.", ] ).format( ", ".join(unexpected_entities), ", ".join(tax_benefit_system.entities_plural()), ), ) persons_json = input_dict.get( tax_benefit_system.person_entity.plural, None ) if not persons_json: raise SituationParsingError( [tax_benefit_system.person_entity.plural], "No {0} found. At least one {0} must be defined to run a simulation.".format( tax_benefit_system.person_entity.key ), ) persons_ids = self.add_person_entity( simulation.persons.entity, persons_json ) for entity_class in tax_benefit_system.group_entities: instances_json = input_dict.get(entity_class.plural) if instances_json is not None: self.add_group_entity( self.persons_plural, persons_ids, entity_class, instances_json, ) else: self.add_default_group_entity(persons_ids, entity_class) if axes: self.axes = axes self.expand_axes() try: self.finalize_variables_init(simulation.persons) except PeriodMismatchError as e: self.raise_period_mismatch( simulation.persons.entity, persons_json, e ) for entity_class in tax_benefit_system.group_entities: try: population = simulation.populations[entity_class.key] self.finalize_variables_init(population) except PeriodMismatchError as e: self.raise_period_mismatch( population.entity, instances_json, e ) return simulation
[docs] def build_from_variables( self, tax_benefit_system: "TaxBenefitSystem", input_dict: dict, simulation: Simulation = None, ) -> Simulation: """ Build a simulation from a Python dict ``input_dict`` describing variables values without expliciting entities. This method uses :any:`build_default_simulation` to infer an entity structure Example: >>> simulation_builder.build_from_variables( {'salary': {'2016-10': 12000}} ) """ count = _get_person_count(input_dict) simulation = self.build_default_simulation( tax_benefit_system, count, simulation ) for variable, value in input_dict.items(): if not isinstance(value, dict): if self.default_period is None: raise SituationParsingError( [variable], "Can't deal with type: expected object. Input variables should be set for specific periods. For instance: {'salary': {'2017-01': 2000, '2017-02': 2500}}, or {'birth_date': {'ETERNITY': '1980-01-01'}}.", ) simulation.set_input(variable, self.default_period, value) else: for period_str, dated_value in value.items(): simulation.set_input(variable, period_str, dated_value) return simulation
[docs] def build_default_simulation( self, tax_benefit_system: "TaxBenefitSystem", count: int = 1, simulation: Simulation = None, ) -> Simulation: """ Build a simulation where: - There are ``count`` persons - There are ``count`` instances of each group entity, containing one person - Every person has, in each entity, the first role """ if simulation is None: simulation = Simulation( tax_benefit_system=tax_benefit_system, populations=tax_benefit_system.instantiate_entities(), ) for population in simulation.populations.values(): population.count = count population.ids = np.array(range(count)) if not population.entity.is_person: population.members_entity_id = ( population.ids ) # Each person is its own group entity return simulation
[docs] def create_entities(self, tax_benefit_system: "TaxBenefitSystem") -> None: self.populations = tax_benefit_system.instantiate_entities()
[docs] def declare_person_entity( self, person_singular: str, persons_ids: typing.Iterable ) -> None: person_instance = self.populations[person_singular] person_instance.ids = np.array(persons_ids) person_instance.count = len(person_instance.ids) self.persons_plural = person_instance.entity.plural
[docs] def declare_entity( self, entity_singular: str, entity_ids: typing.Iterable ) -> Population: entity_instance = self.populations[entity_singular] entity_instance.ids = np.array(entity_ids) entity_instance.count = len(entity_instance.ids) return entity_instance
[docs] def nb_persons(self, entity_singular: str, role: Role = None) -> ArrayLike: return self.populations[entity_singular].nb_persons(role=role)
[docs] def join_with_persons( self, group_population: GroupPopulation, persons_group_assignment: ArrayLike, roles: typing.Iterable[str], ) -> None: # Maps group's identifiers to a 0-based integer range, for indexing into members_roles (see PR#876) group_sorted_indices = np.unique( persons_group_assignment, return_inverse=True )[1] group_population.members_entity_id = np.argsort(group_population.ids)[ group_sorted_indices ] flattened_roles = group_population.entity.flattened_roles roles_array = np.array(roles) if np.issubdtype(roles_array.dtype, np.integer): group_population.members_role = np.array(flattened_roles)[ roles_array ] else: if len(flattened_roles) == 0: group_population.members_role = np.int64(0) else: group_population.members_role = np.select( [roles_array == role.key for role in flattened_roles], flattened_roles, )
[docs] def build(self, tax_benefit_system: "TaxBenefitSystem") -> Simulation: return Simulation( tax_benefit_system=tax_benefit_system, populations=self.populations )
[docs] def explicit_singular_entities( self, tax_benefit_system: "TaxBenefitSystem", input_dict: dict ) -> dict: """ Preprocess ``input_dict`` to explicit entities defined using the single-entity shortcut Example: >>> simulation_builder.explicit_singular_entities( {'persons': {'Javier': {}, }, 'household': {'parents': ['Javier']}} ) >>> {'persons': {'Javier': {}}, 'households': {'household': {'parents': ['Javier']}} """ axes = input_dict.get("axes") singular_keys = set(input_dict).intersection( tax_benefit_system.entities_by_singular() ) if not singular_keys: return input_dict result = { entity_id: entity_description for (entity_id, entity_description) in input_dict.items() if entity_id in tax_benefit_system.entities_plural() } # filter out the singular entities for singular in singular_keys: plural = tax_benefit_system.entities_by_singular()[singular].plural result[plural] = {singular: input_dict[singular]} if axes is not None: result["axes"] = axes return result
[docs] def add_person_entity( self, entity: Entity, instances_json: dict ) -> List[int]: """ Add the simulation's instances of the persons entity as described in ``instances_json``. """ check_type(instances_json, dict, [entity.plural]) entity_ids = list(map(str, instances_json.keys())) self.persons_plural = entity.plural self.entity_ids[self.persons_plural] = entity_ids self.entity_counts[self.persons_plural] = len(entity_ids) for instance_id, instance_object in instances_json.items(): check_type(instance_object, dict, [entity.plural, instance_id]) self.init_variable_values( entity, instance_object, str(instance_id) ) return self.get_ids(entity.plural)
[docs] def add_default_group_entity( self, persons_ids: ArrayLike, entity: Entity ) -> None: persons_count = len(persons_ids) self.entity_ids[entity.plural] = [entity.key] self.entity_counts[entity.plural] = 1 self.memberships[entity.plural] = np.zeros( persons_count, dtype=np.int32 ) self.roles[entity.plural] = np.repeat( entity.flattened_roles[0], persons_count )
[docs] def add_group_entity( self, persons_plural: str, persons_ids: ArrayLike, entity: Entity, instances_json: dict, ) -> None: """ Add all instances of one of the model's entities as described in ``instances_json``. """ check_type(instances_json, dict, [entity.plural]) entity_ids = list(map(str, instances_json.keys())) self.entity_ids[entity.plural] = entity_ids self.entity_counts[entity.plural] = len(entity_ids) persons_count = len(persons_ids) persons_to_allocate = set(persons_ids) self.memberships[entity.plural] = np.empty( persons_count, dtype=np.int32 ) self.roles[entity.plural] = np.empty(persons_count, dtype=object) self.entity_ids[entity.plural] = entity_ids self.entity_counts[entity.plural] = len(entity_ids) for instance_id, instance_object in instances_json.items(): check_type(instance_object, dict, [entity.plural, instance_id]) variables_json = ( instance_object.copy() ) # Don't mutate function input roles_json = { role.plural or role.key: transform_to_strict_syntax( variables_json.pop(role.plural or role.key, []) ) for role in entity.roles } for role_id, role_definition in roles_json.items(): check_type( role_definition, list, [entity.plural, instance_id, role_id], ) for index, person_id in enumerate(role_definition): entity_plural = entity.plural self.check_persons_to_allocate( persons_plural, entity_plural, persons_ids, person_id, instance_id, role_id, persons_to_allocate, index, ) persons_to_allocate.discard(person_id) entity_index = entity_ids.index(instance_id) role_by_plural = { role.plural or role.key: role for role in entity.roles } for role_plural, persons_with_role in roles_json.items(): role = role_by_plural[role_plural] if role.max is not None and len(persons_with_role) > role.max: raise SituationParsingError( [entity.plural, instance_id, role_plural], f"There can be at most {role.max} {role_plural} in a {entity.key}. {len(persons_with_role)} were declared in '{instance_id}'.", ) for index_within_role, person_id in enumerate( persons_with_role ): person_index = persons_ids.index(person_id) self.memberships[entity.plural][ person_index ] = entity_index person_role = ( role.subroles[index_within_role] if role.subroles else role ) self.roles[entity.plural][person_index] = person_role self.init_variable_values(entity, variables_json, instance_id) if persons_to_allocate: entity_ids = entity_ids + list(persons_to_allocate) for person_id in persons_to_allocate: person_index = persons_ids.index(person_id) self.memberships[entity.plural][person_index] = ( entity_ids.index(person_id) ) self.roles[entity.plural][person_index] = ( entity.flattened_roles[0] ) # Adjust previously computed ids and counts self.entity_ids[entity.plural] = entity_ids self.entity_counts[entity.plural] = len(entity_ids) # Convert back to Python array self.roles[entity.plural] = self.roles[entity.plural].tolist() self.memberships[entity.plural] = self.memberships[ entity.plural ].tolist()
[docs] def set_default_period(self, period_str: str) -> None: if period_str: self.default_period = str(periods.period(period_str))
[docs] def get_input(self, variable: str, period_str: str) -> Any: if variable not in self.input_buffer: self.input_buffer[variable] = {} return self.input_buffer[variable].get(period_str)
[docs] def check_persons_to_allocate( self, persons_plural, entity_plural, persons_ids, person_id, entity_id, role_id, persons_to_allocate, index, ): check_type( person_id, str, [entity_plural, entity_id, role_id, str(index)] ) if person_id not in persons_ids: raise SituationParsingError( [entity_plural, entity_id, role_id], "Unexpected value: {0}. {0} has been declared in {1} {2}, but has not been declared in {3}.".format( person_id, entity_id, role_id, persons_plural ), ) if person_id not in persons_to_allocate: raise SituationParsingError( [entity_plural, entity_id, role_id], "{} has been declared more than once in {}".format( person_id, entity_plural ), )
[docs] def init_variable_values(self, entity, instance_object, instance_id): for variable_name, variable_values in instance_object.items(): path_in_json = [entity.plural, instance_id, variable_name] try: entity.check_variable_defined_for_entity(variable_name) except ( ValueError ) as e: # The variable is defined for another entity raise SituationParsingError(path_in_json, e.args[0]) except VariableNotFoundError as e: # The variable doesn't exist raise SituationParsingError(path_in_json, str(e), code=404) instance_index = self.get_ids(entity.plural).index(instance_id) if not isinstance(variable_values, dict): if self.default_period is None: raise SituationParsingError( path_in_json, "Can't deal with type: expected object. Input variables should be set for specific periods. For instance: {'salary': {'2017-01': 2000, '2017-02': 2500}}, or {'birth_date': {'ETERNITY': '1980-01-01'}}.", ) variable_values = {self.default_period: variable_values} for period_str, value in variable_values.items(): try: periods.period(period_str) except ValueError as e: raise SituationParsingError(path_in_json, e.args[0]) variable = entity.get_variable(variable_name) self.add_variable_value( entity, variable, instance_index, instance_id, period_str, value, )
[docs] def add_variable_value( self, entity, variable, instance_index, instance_id, period_str, value ): path_in_json = [entity.plural, instance_id, variable.name, period_str] if value is None: return array = self.get_input(variable.name, str(period_str)) if array is None: array_size = self.get_count(entity.plural) array = variable.default_array(array_size) try: value = variable.check_set_value(value) except ValueError as error: raise SituationParsingError(path_in_json, *error.args) array[instance_index] = value self.input_buffer[variable.name][ str(periods.period(period_str)) ] = array
[docs] def finalize_variables_init(self, population): # Due to set_input mechanism, we must bufferize all inputs, then actually set them, # so that the months are set first and the years last. plural_key = population.entity.plural if plural_key in self.entity_counts: population.count = self.get_count(plural_key) population.ids = self.get_ids(plural_key) if plural_key in self.memberships: population.members_entity_id = np.array( self.get_memberships(plural_key) ) population.members_role = np.array(self.get_roles(plural_key)) for variable_name in self.input_buffer.keys(): try: holder = population.get_holder(variable_name) except ValueError: # Wrong entity, we can just ignore that continue buffer = self.input_buffer[variable_name] unsorted_periods = [ periods.period(period_str) for period_str in self.input_buffer[variable_name].keys() ] # We need to handle small periods first for set_input to work sorted_periods = sorted( unsorted_periods, key=periods.key_period_size ) for period_value in sorted_periods: values = buffer[str(period_value)] # Hack to replicate the values in the persons entity # when we have an axis along a group entity but not persons array = np.tile(values, population.count // len(values)) variable = holder.variable # TODO - this duplicates the check in Simulation.set_input, but # fixing that requires improving Simulation's handling of entities if (variable.end is None) or ( period_value.start.date <= variable.end ): holder.set_input(period_value, array)
[docs] def raise_period_mismatch(self, entity, json, e): # This error happens when we try to set a variable value for a period that doesn't match its definition period # It is only raised when we consume the buffer. We thus don't know which exact key caused the error. # We do a basic research to find the culprit path culprit_path = next( dpath.util.search( json, "*/{}/{}".format(e.variable_name, str(e.period)), yielded=True, ), None, ) if culprit_path: path = [entity.plural] + culprit_path[0].split("/") else: path = [ entity.plural ] # Fallback: if we can't find the culprit, just set the error at the entities level raise SituationParsingError(path, e.message)
# Returns the total number of instances of this entity, including when there is replication along axes
[docs] def get_count(self, entity_name): return self.axes_entity_counts.get( entity_name, self.entity_counts[entity_name] )
# Returns the ids of instances of this entity, including when there is replication along axes
[docs] def get_ids(self, entity_name): return self.axes_entity_ids.get( entity_name, self.entity_ids[entity_name] )
# Returns the memberships of individuals in this entity, including when there is replication along axes
[docs] def get_memberships(self, entity_name): # Return empty array for the "persons" entity return self.axes_memberships.get( entity_name, self.memberships.get(entity_name, []) )
# Returns the roles of individuals in this entity, including when there is replication along axes
[docs] def get_roles(self, entity_name): # Return empty array for the "persons" entity return self.axes_roles.get( entity_name, self.roles.get(entity_name, []) )
[docs] def add_parallel_axis(self, axis): # All parallel axes have the same count and entity. # Search for a compatible axis, if none exists, error out self.axes[0].append(axis)
[docs] def add_perpendicular_axis(self, axis): # This adds an axis perpendicular to all previous dimensions self.axes.append([axis])
[docs] def expand_axes(self): # This method should be idempotent & allow change in axes perpendicular_dimensions = self.axes cell_count = 1 for parallel_axes in perpendicular_dimensions: first_axis = parallel_axes[0] axis_count = first_axis["count"] cell_count *= axis_count # Scale the "prototype" situation, repeating it cell_count times for entity_name in self.entity_counts.keys(): # Adjust counts self.axes_entity_counts[entity_name] = ( self.get_count(entity_name) * cell_count ) # Adjust ids original_ids = self.get_ids(entity_name) * cell_count # indices should be [0, 0, 0, 1, 1, 1, 2, 2, 2, ...] where each number is repeated axis_entity_count times indices = np.repeat( np.arange(0, cell_count), self.entity_counts[entity_name] ) adjusted_ids = [ id + str(ix) for id, ix in zip(original_ids, indices) ] self.axes_entity_ids[entity_name] = adjusted_ids # Adjust roles original_roles = self.get_roles(entity_name) adjusted_roles = np.repeat(original_roles, cell_count) self.axes_roles[entity_name] = adjusted_roles # Adjust memberships, for group entities only if entity_name != self.persons_plural: original_memberships = self.get_memberships(entity_name) # repeat membership, e.g. [1, 0] -> [1, 0, 1, 0, ...] repeated_memberships = np.tile( original_memberships, cell_count ) indices = ( np.repeat( np.arange(0, cell_count), len(original_memberships) ) * self.entity_counts[entity_name] ) adjusted_memberships = ( np.array(repeated_memberships) + indices ).tolist() self.axes_memberships[entity_name] = adjusted_memberships # Now generate input values along the specified axes # TODO - factor out the common logic here if len(self.axes) == 1 and len(self.axes[0]): parallel_axes = self.axes[0] first_axis = parallel_axes[0] axis_count: int = first_axis["count"] axis_entity = self.get_variable_entity(first_axis["name"]) axis_entity_step_size = self.entity_counts[axis_entity.plural] # Distribute values along axes for axis in parallel_axes: axis_index = axis.get("index", 0) axis_period = axis.get("period", self.default_period) axis_name = axis["name"] variable = axis_entity.get_variable(axis_name) array = self.get_input(axis_name, str(axis_period)) if array is None: array = variable.default_array( axis_count * axis_entity_step_size ) elif array.size == axis_entity_step_size: array = np.tile(array, axis_count) array[axis_index::axis_entity_step_size] = np.linspace( axis["min"], axis["max"], num=axis_count, ) # Set input self.input_buffer[axis_name][str(axis_period)] = array else: first_axes_count: typing.List[int] = ( parallel_axes[0]["count"] for parallel_axes in self.axes ) axes_linspaces = [ np.linspace(0, axis_count - 1, num=axis_count) for axis_count in first_axes_count ] axes_meshes = np.meshgrid(*axes_linspaces) for parallel_axes, mesh in zip(self.axes, axes_meshes): first_axis = parallel_axes[0] axis_count = first_axis["count"] axis_entity = self.get_variable_entity(first_axis["name"]) axis_entity_step_size = self.entity_counts[axis_entity.plural] # Distribute values along the grid for axis in parallel_axes: axis_index = axis.get("index", 0) axis_period = axis.get("period") or self.default_period axis_name = axis["name"] variable = axis_entity.get_variable(axis_name) array = self.get_input(axis_name, str(axis_period)) if array is None: array = variable.default_array( cell_count * axis_entity_step_size ) elif array.size == axis_entity_step_size: array = np.tile(array, cell_count) array[axis_index::axis_entity_step_size] = axis[ "min" ] + mesh.reshape(cell_count) * ( axis["max"] - axis["min"] ) / ( axis_count - 1 ) self.input_buffer[axis_name][str(axis_period)] = array self.has_axes = True
[docs] def get_variable_entity(self, variable_name): return self.variable_entities[variable_name]
[docs] def register_variable(self, variable_name, entity): self.variable_entities[variable_name] = entity