import datetime
import inspect
import re
import textwrap
from typing import Callable, List, Type
import numpy
import sortedcontainers
from policyengine_core import periods, tools
from policyengine_core.entities import Entity
from policyengine_core.enums import Enum, EnumArray
from policyengine_core.periods import Period
from policyengine_core.holders import (
set_input_dispatch_by_period,
set_input_divide_by_period,
)
from policyengine_core.periods import DAY, ETERNITY
from . import config, helpers
class QuantityType:
STOCK = "stock"
FLOW = "flow"
class VariableStage:
INPUT = "input"
INTERMEDIATE = "intermediate"
OUTPUT = "output"
class VariableCategory:
TAX = "tax"
BENEFIT = "benefit"
INCOME = "income"
CONSUMPTION = "consumption"
WEALTH = "wealth"
DEMOGRAPHIC = "demographic"
[docs]class Variable:
"""
A `variable <https://openfisca.org/doc/key-concepts/variables.html>`_ of the legislation.
"""
name: str
"""Name of the variable"""
value_type: type
"""The value type of the variable. Possible value types in OpenFisca are ``int`` ``float`` ``bool`` ``str`` ``date`` and ``Enum``."""
entity: Entity
"""`Entity <https://openfisca.org/doc/key-concepts/person,_entities,_role.html>`_ the variable is defined for. For instance : ``Person``, ``Household``."""
definition_period: str
"""`Period <https://openfisca.org/doc/coding-the-legislation/35_periods.html>`_ the variable is defined for. Possible value: ``MONTH``, ``YEAR``, ``ETERNITY``."""
formulas: List[Callable]
"""Formulas used to calculate the variable"""
label: str
"""Description of the variable"""
reference: str
"""Legislative reference describing the variable."""
default_value: object
"""`Default value <https://openfisca.org/doc/key-concepts/variables.html#default-values>`_ of the variable."""
baseline_variable: str
"""If the variable has been introduced in a `reform <https://openfisca.org/doc/key-concepts/reforms.html>`_ to replace another variable, baseline_variable is the replaced variable."""
dtype: numpy.dtype
"""Numpy `dtype <https://docs.scipy.org/doc/numpy-1.13.0/reference/generated/numpy.dtype.html>`_ used under the hood for the variable."""
end: datetime.date
"""`Date <https://openfisca.org/doc/coding-the-legislation/40_legislation_evolutions.html#variable-end>`_ when the variable disappears from the legislation."""
is_neutralized: bool
"""True if the variable is neutralized. Neutralized variables never use their formula, and only return their default values when calculated."""
json_type: str
"""JSON type corresponding to the variable."""
max_length: int
"""If the value type of the variable is ``str``, max length of the string allowed. ``None`` if there is no limit."""
possible_values: EnumArray
"""If the value type of the variable is ``Enum``, contains the values the variable can take."""
set_input: Callable
"""Function used to automatically process variable inputs defined for periods not matching the definition_period of the variable. See more on the `documentation <https://openfisca.org/doc/coding-the-legislation/35_periods.html#set-input-automatically-process-variable-inputs-defined-for-periods-not-matching-the-definition-period>`_. Possible values are ``set_input_dispatch_by_period``, ``set_input_divide_by_period``, or nothing."""
unit: str
"""Free text field describing the unit of the variable. Only used as metadata."""
documentation: str
"""Free multilines text field describing the variable context and usage."""
quantity_type: str
"""Categorical attribute describing whether the variable is a stock or a flow."""
defined_for: str = None
"""The name of another variable, nonzero values of which are used to define the set of entities for which this variable is defined."""
metadata: dict = None
"""Free dictionary field used to store any metadata."""
module_name: str = None
"""The name of the module it is defined in."""
index_in_module: int = None
"""Index of the variable in the module it is defined in."""
adds: List[str] = None
"""List of variables that are added to the variable. Alternatively, can be a parameter name."""
subtracts: List[str] = None
"""List of variables that are subtracted from the variable. Alternatively, can be a parameter name."""
uprating: str = None
"""Name of a parameter used to uprate the variable."""
hidden_input: bool = False
"""Whether the variable is hidden from the input screen entirely on PolicyEngine."""
requires_computation_after: str = None
"""Name of a variable that must be computed before this variable."""
exhaustive_parameter_dependencies: List[str] = None
"""If these parameters (plus the dataset, branch and period) haven't changed, Core will use caching on this variable."""
min_value: (float, int) = None
"""Minimum value of the variable."""
max_value: (float, int) = None
"""Maximum value of the variable."""
def __init__(self, baseline_variable=None):
self.name = self.__class__.__name__
attr = {
name: value
for name, value in self.__class__.__dict__.items()
if not name.startswith("__")
}
# Allow inheritance for some properties
INHERITED_ALLOWED_PROPERTIES = (
"label",
"value_type",
"entity",
"definition_period",
)
for property_name in INHERITED_ALLOWED_PROPERTIES:
if not attr.get(property_name) and property_name in dir(
self.__class__
):
attr[property_name] = getattr(self, property_name)
self.baseline_variable = baseline_variable
self.value_type = self.set(
attr,
"value_type",
required=True,
allowed_values=config.VALUE_TYPES.keys(),
)
self.dtype = config.VALUE_TYPES[self.value_type]["dtype"]
self.json_type = config.VALUE_TYPES[self.value_type]["json_type"]
if self.value_type == Enum:
self.possible_values: Type[Enum] = self.set(
attr,
"possible_values",
required=True,
setter=self.set_possible_values,
)
if self.value_type == str:
self.max_length = self.set(attr, "max_length", allowed_type=int)
if self.max_length:
self.dtype = "|S{}".format(self.max_length)
if self.value_type == Enum:
self.default_value = self.set(
attr,
"default_value",
allowed_type=self.possible_values,
required=True,
)
else:
self.default_value = self.set(
attr,
"default_value",
allowed_type=self.value_type,
default=config.VALUE_TYPES[self.value_type].get("default"),
)
self.entity = self.set(
attr, "entity", required=True, setter=self.set_entity
)
self.definition_period = self.set(
attr,
"definition_period",
required=True,
allowed_values=(
periods.DAY,
periods.MONTH,
periods.YEAR,
periods.ETERNITY,
),
)
self.label = self.set(
attr, "label", allowed_type=str, setter=self.set_label
)
if self.label is None:
raise ValueError(
'Variable "{name}" has no label'.format(name=self.name)
)
self.end = self.set(attr, "end", allowed_type=str, setter=self.set_end)
self.reference = self.set(attr, "reference", setter=self.set_reference)
self.cerfa_field = self.set(
attr, "cerfa_field", allowed_type=(str, dict)
)
self.unit = self.set(attr, "unit", allowed_type=str)
self.quantity_type = self.set(
attr,
"quantity_type",
required=False,
allowed_values=(QuantityType.STOCK, QuantityType.FLOW),
default=(
QuantityType.STOCK
if (
self.value_type in (bool, int, Enum, str, datetime.date)
or self.unit == "/1"
)
else QuantityType.FLOW
),
)
self.documentation = self.set(
attr,
"documentation",
allowed_type=str,
setter=self.set_documentation,
)
self.set_input = self.set_set_input(
attr.pop(
"set_input",
(
set_input_dispatch_by_period
if self.quantity_type == QuantityType.STOCK
else set_input_divide_by_period
),
)
)
if self.definition_period in (DAY, ETERNITY):
self.set_input = None
self.calculate_output = self.set_calculate_output(
attr.pop("calculate_output", None)
)
self.is_period_size_independent = self.set(
attr,
"is_period_size_independent",
allowed_type=bool,
default=config.VALUE_TYPES[self.value_type][
"is_period_size_independent"
],
)
self.defined_for = self.set_defined_for(attr.pop("defined_for", None))
self.metadata = self.set(attr, "metadata", allowed_type=dict)
self.category = self.set(
attr,
"category",
allowed_type=str,
default=None,
)
self.module_name = self.set(
attr,
"module_name",
allowed_type=str,
default=None,
)
self.index_in_module = self.set(
attr,
"index_in_module",
allowed_type=int,
default=None,
)
self.adds = self.set(attr, "adds")
self.subtracts = self.set(attr, "subtracts")
self.uprating = self.set(attr, "uprating", allowed_type=str)
self.hidden_input = self.set(
attr, "hidden_input", allowed_type=bool, default=False
)
self.requires_computation_after = self.set(
attr, "requires_computation_after", allowed_type=str
)
self.exhaustive_parameter_dependencies = self.set(
attr, "exhaustive_parameter_dependencies"
)
if isinstance(self.exhaustive_parameter_dependencies, str):
self.exhaustive_parameter_dependencies = [
self.exhaustive_parameter_dependencies
]
self.min_value = self.set(
attr,
"min_value",
allowed_type=(float, int),
setter=self.set_min_value,
)
self.max_value = self.set(
attr,
"max_value",
allowed_type=(float, int),
setter=self.set_max_value,
)
formulas_attr, unexpected_attrs = helpers._partition(
attr,
lambda name, value: name.startswith(config.FORMULA_NAME_PREFIX),
)
self.formulas = self.set_formulas(formulas_attr)
if unexpected_attrs:
raise ValueError(
'Unexpected attributes in definition of variable "{}": {!r}'.format(
self.name, ", ".join(sorted(unexpected_attrs.keys()))
)
)
self.is_neutralized = False
# ----- Setters used to build the variable ----- #
def set(
self,
attributes,
attribute_name,
required=False,
allowed_values=None,
allowed_type=None,
setter=None,
default=None,
):
value = attributes.pop(attribute_name, None)
if value is None and self.baseline_variable:
return getattr(self.baseline_variable, attribute_name)
if required and value is None:
raise ValueError(
"Missing attribute '{}' in definition of variable '{}'.".format(
attribute_name, self.name
)
)
if (
required
and allowed_values is not None
and value not in allowed_values
):
raise ValueError(
"Invalid value '{}' for attribute '{}' in variable '{}'. Allowed values are '{}'.".format(
value, attribute_name, self.name, allowed_values
)
)
if (
allowed_type is not None
and value is not None
and not isinstance(value, allowed_type)
):
if allowed_type == float and isinstance(value, int):
value = float(value)
else:
raise ValueError(
"Invalid value '{}' for attribute '{}' in variable '{}'. Must be of type '{}'.".format(
value, attribute_name, self.name, allowed_type
)
)
if setter is not None:
value = setter(value)
if value is None and default is not None:
return default
return value
def set_entity(self, entity):
if not isinstance(entity, Entity):
raise ValueError(
f"Invalid value '{entity}' for attribute 'entity' in variable '{self.name}'. Must be an instance of Entity."
)
return entity
def set_possible_values(self, possible_values):
if not issubclass(possible_values, Enum):
raise ValueError(
"Invalid value '{}' for attribute 'possible_values' in variable '{}'. Must be a subclass of {}.".format(
possible_values, self.name, Enum
)
)
return possible_values
def set_label(self, label):
if label:
return label
def set_end(self, end):
if end:
try:
return datetime.datetime.strptime(end, "%Y-%m-%d").date()
except ValueError:
raise ValueError(
"Incorrect 'end' attribute format in '{}'. 'YYYY-MM-DD' expected where YYYY, MM and DD are year, month and day. Found: {}".format(
self.name, end
)
)
def set_reference(self, reference):
if reference:
if isinstance(reference, str):
reference = [reference]
elif isinstance(reference, list):
pass
elif isinstance(reference, tuple):
reference = list(reference)
elif isinstance(reference, dict):
reference = [reference]
return reference
def set_documentation(self, documentation):
if documentation:
return textwrap.dedent(documentation)
def set_set_input(self, set_input):
if not set_input and self.baseline_variable:
return self.baseline_variable.set_input
return set_input
def set_calculate_output(self, calculate_output):
if not calculate_output and self.baseline_variable:
return self.baseline_variable.calculate_output
return calculate_output
def set_formulas(self, formulas_attr):
formulas = sortedcontainers.sorteddict.SortedDict()
for formula_name, formula in formulas_attr.items():
starting_date = self.parse_formula_name(formula_name)
if self.end is not None and starting_date > self.end:
raise ValueError(
'You declared that "{}" ends on "{}", but you wrote a formula to calculate it from "{}" ({}). The "end" attribute of a variable must be posterior to the start dates of all its formulas.'.format(
self.name, self.end, starting_date, formula_name
)
)
formulas[str(starting_date)] = formula
# If the variable is reforming a baseline variable, keep the formulas from the latter when they are not overridden by new formulas.
if self.baseline_variable is not None:
first_reform_formula_date = (
formulas.peekitem(0)[0] if formulas else None
)
formulas.update(
{
baseline_start_date: baseline_formula
for baseline_start_date, baseline_formula in self.baseline_variable.formulas.items()
if first_reform_formula_date is None
or baseline_start_date < first_reform_formula_date
}
)
return formulas
def set_defined_for(self, defined_for):
if isinstance(defined_for, Enum):
defined_for = defined_for.value
return defined_for
def set_min_value(self, min_value):
if min_value is not None:
if self.max_value is not None and min_value > self.max_value:
raise ValueError("min_value cannot be greater than max_value")
return min_value
def set_max_value(self, max_value):
if max_value is not None:
if self.min_value is not None and max_value < self.min_value:
raise ValueError("max_value cannot be smaller than min_value")
return max_value
# ----- Methods ----- #
[docs] @classmethod
def get_introspection_data(cls, tax_benefit_system):
"""
Get instrospection data about the code of the variable.
:returns: (comments, source file path, source code, start line number)
:rtype: tuple
"""
comments = inspect.getcomments(cls)
# Handle dynamically generated variable classes or Jupyter Notebooks, which have no source.
try:
absolute_file_path = inspect.getsourcefile(cls)
except TypeError:
source_file_path = None
else:
source_file_path = absolute_file_path.replace(
tax_benefit_system.get_package_metadata()["location"], ""
)
try:
source_lines, start_line_number = inspect.getsourcelines(cls)
source_code = textwrap.dedent("".join(source_lines))
except (IOError, TypeError):
source_code, start_line_number = None, None
return comments, source_file_path, source_code, start_line_number
def clone(self):
clone = self.__class__()
return clone
def check_set_value(self, value):
if self.value_type == Enum and isinstance(value, str):
try:
value = self.possible_values[value].index
except KeyError:
possible_values = [item.name for item in self.possible_values]
raise ValueError(
"'{}' is not a known value for '{}'. Possible values are ['{}'].".format(
value, self.name, "', '".join(possible_values)
)
)
if self.value_type in (float, int) and isinstance(value, str):
try:
value = tools.eval_expression(value)
except SyntaxError:
raise ValueError(
"I couldn't understand '{}' as a value for '{}'".format(
value, self.name
)
)
try:
value = numpy.array([value], dtype=self.dtype)[0]
except (TypeError, ValueError):
if self.value_type == datetime.date:
error_message = "Can't deal with date: '{}'.".format(value)
else:
error_message = "Can't deal with value: expected type {}, received '{}'.".format(
self.json_type, value
)
raise ValueError(error_message)
except OverflowError:
error_message = "Can't deal with value: '{}', it's too large for type '{}'.".format(
value, self.json_type
)
raise ValueError(error_message)
return value
def default_array(self, array_size):
array = numpy.empty(array_size, dtype=self.dtype)
if self.value_type == Enum:
array.fill(self.default_value.index)
return EnumArray(array, self.possible_values)
array.fill(self.default_value)
return array