Source code for policyengine_core.reforms.reform
from __future__ import annotations
import copy
from typing import Callable, Union, TYPE_CHECKING
from policyengine_core.parameters import ParameterNode, Parameter
from policyengine_core.taxbenefitsystems import TaxBenefitSystem
if TYPE_CHECKING:
from policyengine_core.simulations import Simulation
from policyengine_core.periods import (
period as period_,
instant as instant_,
Period,
)
import requests
class classproperty(object):
def __init__(self, f):
self.f = f
def __get__(self, obj, owner):
return self.f(owner)
[docs]class Reform(TaxBenefitSystem):
"""A modified TaxBenefitSystem
All reforms must subclass `Reform` and implement a method `apply()`.
In this method, the reform can add or replace variables and call `modify_parameters` to modify the parameters of the legislation.
Example:
>>> from policyengine_core import reforms
>>> from policyengine_core.parameters import load_parameter_file
>>>
>>> def modify_my_parameters(parameters):
>>> # Add new parameters
>>> new_parameters = load_parameter_file(name='reform_name', file_path='path_to_yaml_file.yaml')
>>> parameters.add_child('reform_name', new_parameters)
>>>
>>> # Update a value
>>> parameters.taxes.some_tax.some_param.update(period=some_period, value=1000.0)
>>>
>>> return parameters
>>>
>>> class MyReform(reforms.Reform):
>>> def apply(self):
>>> self.add_variable(some_variable)
>>> self.update_variable(some_other_variable)
>>> self.modify_parameters(modifier_function = modify_my_parameters)
"""
name: str = None
"""The name of the reform. This is used to identify the reform in the UI."""
country_id: str = None
"""The country id of the reform. This is used to inform any calls to the PolicyEngine API."""
parameter_values: dict = None
"""The parameter values of the reform. This is used to inform any calls to the PolicyEngine API."""
simulation: "Simulation" = None
def __init__(self, baseline: TaxBenefitSystem):
"""
:param baseline: Baseline TaxBenefitSystem.
"""
super().__init__(baseline.entities)
self.baseline = baseline
self.parameters = baseline.parameters
self._parameters_at_instant_cache = (
baseline._parameters_at_instant_cache
)
self.variables = baseline.variables.copy()
self.decomposition_file_path = baseline.decomposition_file_path
self.key = self.__class__.__name__
if not hasattr(self, "apply"):
raise Exception(
"Reform {} must define an `apply` function".format(self.key)
)
self.apply()
def __getattr__(self, attribute):
return getattr(self.baseline, attribute)
@property
def full_key(self) -> str:
key = self.key
assert (
key is not None
), "key was not set for reform {} (name: {!r})".format(self, self.name)
if self.baseline is not None and hasattr(self.baseline, "key"):
baseline_full_key = self.baseline.full_key
key = ".".join([baseline_full_key, key])
return key
[docs] def modify_parameters(
self, modifier_function: Callable[[ParameterNode], ParameterNode]
) -> None:
"""Make modifications on the parameters of the legislation.
Call this function in `apply()` if the reform asks for legislation parameter modifications.
Args:
modifier_function: A function that takes a :obj:`.ParameterNode` and should return an object of the same type.
"""
baseline_parameters = self.baseline.parameters
baseline_parameters_copy = copy.deepcopy(baseline_parameters)
reform_parameters = modifier_function(baseline_parameters_copy)
if not isinstance(reform_parameters, ParameterNode):
return ValueError(
"modifier_function {} in module {} must return a ParameterNode".format(
modifier_function.__name__,
modifier_function.__module__,
)
)
self.parameters = reform_parameters
self._parameters_at_instant_cache = {}
[docs] @staticmethod
def from_dict(
parameter_values: dict,
country_id: str = None,
name: str = None,
) -> Reform:
"""Create a reform from a dictionary of parameters.
Args:
parameters: A dictionary of parameter -> { period -> value } pairs.
Returns:
A reform.
"""
class reform(Reform):
def apply(self):
for path, period_values in parameter_values.items():
parameter = self.parameters.get_child(path)
if not isinstance(period_values, dict):
parameter.update(
start="0000-01-01", value=period_values
)
else:
for period, value in period_values.items():
if "." in period:
start, stop = period.split(".")
start = instant_(start)
stop = instant_(stop)
parameter.update(
start=start, stop=stop, value=value
)
else:
parameter = parameter.update(
period=period, value=value
)
reform.country_id = country_id
reform.parameter_values = parameter_values
reform.name = name
return reform
[docs] @staticmethod
def from_api(
api_id: str,
country_id: str = None,
) -> Reform:
"""Create a reform from a dictionary of parameters.
Args:
parameters: A dictionary of parameter -> { period -> value } pairs.
Returns:
A reform.
"""
data = requests.get(
f"https://api.policyengine.org/{country_id}/policy/{api_id}"
).json()
parameter_values = data.get("result", {}).get("policy_json", {})
for path in parameter_values:
keys_to_remove = []
for start_stop_str in list(parameter_values[path].keys()):
start, stop = start_stop_str.split(".")
time_period = str(
period_("year:2000:100").intersection(
instant_(start), instant_(stop)
)
)
parameter_values[path][time_period] = parameter_values[path][
start_stop_str
]
keys_to_remove.append(start_stop_str)
for key in keys_to_remove:
del parameter_values[path][key]
return Reform.from_dict(
parameter_values,
country_id,
data.get("result", {}).get("label", None),
)
@classproperty
def api_id(self):
if self.country_id is None:
raise ValueError(
"`country_id` is not set. This is required to use the API."
)
if self.parameter_values is None:
raise ValueError(
"`parameter_values` is not set. This is required to use the API."
)
sanitised_parameter_values = {}
for path, period_values in self.parameter_values.items():
sanitised_period_values = {}
for period, value in period_values.items():
period = period_(period)
sanitised_period_values[f"{period.start}.{period.stop}"] = (
value
)
sanitised_parameter_values[path] = sanitised_period_values
response = requests.post(
f"https://api.policyengine.org/{self.country_id}/policy",
json={
"data": sanitised_parameter_values,
"name": self.name,
},
)
return response.json().get("result", {}).get("policy_id")
def set_parameter(
path: Union[Parameter, str],
value: float,
period: str = None,
start: str = None,
stop: str = None,
return_modifier=False,
) -> Reform:
if stop is not None:
stop = instant_(stop)
if start is not None:
start = instant_(start)
if isinstance(path, Parameter):
path = path.name
def modifier(parameters: ParameterNode):
node = parameters
for name in path.split("."):
try:
if "[" not in name:
node = node.children[name]
else:
try:
name, index = name.split("[")
index = int(index[:-1])
node = node.children[name].brackets[index]
except:
raise ValueError(
"Invalid bracket syntax (should be e.g. tax.brackets[3].rate"
)
except:
raise ValueError(
f"Could not find the parameter (failed at {name})."
)
node.update(period=period, value=value, start=start, stop=stop)
return parameters
if return_modifier:
return modifier
class reform(Reform):
def apply(self):
self.modify_parameters(modifier)
return reform