import datetime
import inspect
import re
import textwrap
from typing import Callable, List, Type
import numpy
import sortedcontainers
from policyengine_core import periods, tools
from policyengine_core.entities import Entity
from policyengine_core.enums import Enum, EnumArray
from policyengine_core.periods import Period
from policyengine_core.holders import (
    set_input_dispatch_by_period,
    set_input_divide_by_period,
)
from policyengine_core.periods import DAY, ETERNITY
from . import config, helpers
class QuantityType:
    """
    Allowed values of a variable's ``quantity_type`` attribute.

    When a variable does not declare ``set_input`` itself, ``STOCK`` variables
    default to ``set_input_dispatch_by_period`` and ``FLOW`` variables default
    to ``set_input_divide_by_period``.
    """

    STOCK = "stock"
    FLOW = "flow"
class VariableStage:
    """String labels naming the stage of a variable: input, intermediate or output."""

    INPUT = "input"
    INTERMEDIATE = "intermediate"
    OUTPUT = "output"
class VariableCategory:
    """
    Standard string labels for a variable's ``category``.

    ``Variable`` itself only checks that ``category`` is a ``str``; these
    constants are the conventional values.
    """

    TAX = "tax"
    BENEFIT = "benefit"
    INCOME = "income"
    CONSUMPTION = "consumption"
    WEALTH = "wealth"
    DEMOGRAPHIC = "demographic"
[docs]class Variable:
    """
    A `variable <https://openfisca.org/doc/key-concepts/variables.html>`_ of the legislation.
    """
    name: str
    """Name of the variable"""
    value_type: type
    """The value type of the variable. Possible value types in OpenFisca are ``int`` ``float`` ``bool`` ``str`` ``date`` and ``Enum``."""
    entity: Entity
    """`Entity <https://openfisca.org/doc/key-concepts/person,_entities,_role.html>`_ the variable is defined for. For instance : ``Person``, ``Household``."""
    definition_period: str
    """`Period <https://openfisca.org/doc/coding-the-legislation/35_periods.html>`_ the variable is defined for. Possible value: ``MONTH``, ``YEAR``, ``ETERNITY``."""
    formulas: List[Callable]
    """Formulas used to calculate the variable"""
    label: str
    """Description of the variable"""
    reference: str
    """Legislative reference describing the variable."""
    default_value: object
    """`Default value <https://openfisca.org/doc/key-concepts/variables.html#default-values>`_ of the variable."""
    baseline_variable: str
    """If the variable has been introduced in a `reform <https://openfisca.org/doc/key-concepts/reforms.html>`_ to replace another variable, baseline_variable is the replaced variable."""
    dtype: numpy.dtype
    """Numpy `dtype <https://docs.scipy.org/doc/numpy-1.13.0/reference/generated/numpy.dtype.html>`_ used under the hood for the variable."""
    end: datetime.date
    """`Date <https://openfisca.org/doc/coding-the-legislation/40_legislation_evolutions.html#variable-end>`_  when the variable disappears from the legislation."""
    is_neutralized: bool
    """True if the variable is neutralized. Neutralized variables never use their formula, and only return their default values when calculated."""
    json_type: str
    """JSON type corresponding to the variable."""
    max_length: int
    """If the value type of the variable is ``str``, max length of the string allowed. ``None`` if there is no limit."""
    possible_values: EnumArray
    """If the value type of the variable is ``Enum``, contains the values the variable can take."""
    set_input: Callable
    """Function used to automatically process variable inputs defined for periods not matching the definition_period of the variable. See more on the `documentation <https://openfisca.org/doc/coding-the-legislation/35_periods.html#set-input-automatically-process-variable-inputs-defined-for-periods-not-matching-the-definition-period>`_. Possible values are ``set_input_dispatch_by_period``, ``set_input_divide_by_period``, or nothing."""
    unit: str
    """Free text field describing the unit of the variable. Only used as metadata."""
    documentation: str
    """Free multilines text field describing the variable context and usage."""
    quantity_type: str
    """Categorical attribute describing whether the variable is a stock or a flow."""
    defined_for: str = None
    """The name of another variable, nonzero values of which are used to define the set of entities for which this variable is defined."""
    metadata: dict = None
    """Free dictionary field used to store any metadata."""
    module_name: str = None
    """The name of the module it is defined in."""
    index_in_module: int = None
    """Index of the variable in the module it is defined in."""
    adds: List[str] = None
    """List of variables that are added to the variable. Alternatively, can be a parameter name."""
    subtracts: List[str] = None
    """List of variables that are subtracted from the variable. Alternatively, can be a parameter name."""
    uprating: str = None
    """Name of a parameter used to uprate the variable."""
    hidden_input: bool = False
    """Whether the variable is hidden from the input screen entirely on PolicyEngine."""
    requires_computation_after: str = None
    """Name of a variable that must be computed before this variable."""
    exhaustive_parameter_dependencies: List[str] = None
    """If these parameters (plus the dataset, branch and period) haven't changed, Core will use caching on this variable."""
    min_value: (float, int) = None
    """Minimum value of the variable."""
    max_value: (float, int) = None
    """Maximum value of the variable."""
    def __init__(self, baseline_variable=None):
        """
        Build the variable from the attributes declared on its class.

        Every ``self.set(...)`` call below pops its attribute from ``attr``;
        once all known attributes have been consumed, the remaining entries
        must be formulas, otherwise ``ValueError`` is raised. Statement order
        matters: several attributes are validated or defaulted using values
        assigned earlier in this method.

        :param baseline_variable: The variable this one replaces when it is
            introduced by a reform, or ``None``. Attributes that are not
            redefined here are generally taken from it (see ``set``).
        :raises ValueError: If a required attribute is missing or invalid, the
            variable ends up with no label, or its class declares unexpected
            attributes.
        """
        # The variable is identified by its class name.
        self.name = self.__class__.__name__
        # Non-dunder attributes declared directly on the concrete class
        # (a class ``__dict__`` does not include attributes of parent classes).
        attr = {
            name: value
            for name, value in self.__class__.__dict__.items()
            if not name.startswith("__")
        }
        # Allow inheritance for some properties
        # (taken from the class hierarchy when the subclass leaves them unset or falsy).
        INHERITED_ALLOWED_PROPERTIES = (
            "label",
            "value_type",
            "entity",
            "definition_period",
        )
        for property_name in INHERITED_ALLOWED_PROPERTIES:
            if not attr.get(property_name) and property_name in dir(
                self.__class__
            ):
                attr[property_name] = getattr(self, property_name)
        # Must be assigned before any self.set(...) call, which reads it.
        self.baseline_variable = baseline_variable
        # value_type comes first: dtype, json_type, possible_values and
        # default_value all depend on it.
        self.value_type = self.set(
            attr,
            "value_type",
            required=True,
            allowed_values=config.VALUE_TYPES.keys(),
        )
        self.dtype = config.VALUE_TYPES[self.value_type]["dtype"]
        self.json_type = config.VALUE_TYPES[self.value_type]["json_type"]
        if self.value_type == Enum:
            self.possible_values: Type[Enum] = self.set(
                attr,
                "possible_values",
                required=True,
                setter=self.set_possible_values,
            )
        if self.value_type == str:
            self.max_length = self.set(attr, "max_length", allowed_type=int)
            # A max_length switches the dtype to a fixed-width numpy byte string.
            if self.max_length:
                self.dtype = "|S{}".format(self.max_length)
        # Enum defaults are required and must be members of possible_values;
        # other types fall back to the per-type default in config.VALUE_TYPES.
        if self.value_type == Enum:
            self.default_value = self.set(
                attr,
                "default_value",
                allowed_type=self.possible_values,
                required=True,
            )
        else:
            self.default_value = self.set(
                attr,
                "default_value",
                allowed_type=self.value_type,
                default=config.VALUE_TYPES[self.value_type].get("default"),
            )
        self.entity = self.set(
            attr, "entity", required=True, setter=self.set_entity
        )
        self.definition_period = self.set(
            attr,
            "definition_period",
            required=True,
            allowed_values=(
                periods.DAY,
                periods.MONTH,
                periods.YEAR,
                periods.ETERNITY,
            ),
        )
        self.label = self.set(
            attr, "label", allowed_type=str, setter=self.set_label
        )
        # set_label turns an empty label into None, so empty labels are rejected too.
        if self.label is None:
            raise ValueError(
                'Variable "{name}" has no label'.format(name=self.name)
            )
        # set_end parses a "YYYY-MM-DD" string into a datetime.date.
        self.end = self.set(attr, "end", allowed_type=str, setter=self.set_end)
        self.reference = self.set(attr, "reference", setter=self.set_reference)
        self.cerfa_field = self.set(
            attr, "cerfa_field", allowed_type=(str, dict)
        )
        # unit must be set before quantity_type, whose default reads it.
        self.unit = self.set(attr, "unit", allowed_type=str)
        # Booleans, integers, enums, strings, dates and "/1"-unit variables
        # default to STOCK; everything else defaults to FLOW.
        self.quantity_type = self.set(
            attr,
            "quantity_type",
            required=False,
            allowed_values=(QuantityType.STOCK, QuantityType.FLOW),
            default=(
                QuantityType.STOCK
                if (
                    self.value_type in (bool, int, Enum, str, datetime.date)
                    or self.unit == "/1"
                )
                else QuantityType.FLOW
            ),
        )
        self.documentation = self.set(
            attr,
            "documentation",
            allowed_type=str,
            setter=self.set_documentation,
        )
        # The default set_input depends on quantity_type: stocks are
        # dispatched across sub-periods, flows are divided across them.
        # NOTE(review): the pop() default is always a function (truthy), so
        # set_set_input only falls back to the baseline's set_input when the
        # class explicitly declares a falsy set_input.
        self.set_input = self.set_set_input(
            attr.pop(
                "set_input",
                (
                    set_input_dispatch_by_period
                    if self.quantity_type == QuantityType.STOCK
                    else set_input_divide_by_period
                ),
            )
        )
        # set_input is always disabled for DAY and ETERNITY variables
        # (presumably because there is no sub-period to spread input over — confirm).
        if self.definition_period in (DAY, ETERNITY):
            self.set_input = None
        self.calculate_output = self.set_calculate_output(
            attr.pop("calculate_output", None)
        )
        self.is_period_size_independent = self.set(
            attr,
            "is_period_size_independent",
            allowed_type=bool,
            default=config.VALUE_TYPES[self.value_type][
                "is_period_size_independent"
            ],
        )
        # defined_for bypasses self.set, so it is never inherited from baseline_variable.
        self.defined_for = self.set_defined_for(attr.pop("defined_for", None))
        self.metadata = self.set(attr, "metadata", allowed_type=dict)
        self.category = self.set(
            attr,
            "category",
            allowed_type=str,
            default=None,
        )
        self.module_name = self.set(
            attr,
            "module_name",
            allowed_type=str,
            default=None,
        )
        self.index_in_module = self.set(
            attr,
            "index_in_module",
            allowed_type=int,
            default=None,
        )
        self.adds = self.set(attr, "adds")
        self.subtracts = self.set(attr, "subtracts")
        self.uprating = self.set(attr, "uprating", allowed_type=str)
        self.hidden_input = self.set(
            attr, "hidden_input", allowed_type=bool, default=False
        )
        self.requires_computation_after = self.set(
            attr, "requires_computation_after", allowed_type=str
        )
        self.exhaustive_parameter_dependencies = self.set(
            attr, "exhaustive_parameter_dependencies"
        )
        # A single parameter name is normalised to a one-element list.
        if isinstance(self.exhaustive_parameter_dependencies, str):
            self.exhaustive_parameter_dependencies = [
                self.exhaustive_parameter_dependencies
            ]
        # min_value is set before max_value: set_min_value therefore compares
        # against the class-level max_value declaration, and set_max_value
        # then compares against the min_value just stored on the instance.
        self.min_value = self.set(
            attr,
            "min_value",
            allowed_type=(float, int),
            setter=self.set_min_value,
        )
        self.max_value = self.set(
            attr,
            "max_value",
            allowed_type=(float, int),
            setter=self.set_max_value,
        )
        # Everything self.set did not consume is either a formula (its name
        # starts with config.FORMULA_NAME_PREFIX) or an unexpected attribute.
        formulas_attr, unexpected_attrs = helpers._partition(
            attr,
            lambda name, value: name.startswith(config.FORMULA_NAME_PREFIX),
        )
        self.formulas = self.set_formulas(formulas_attr)
        if unexpected_attrs:
            raise ValueError(
                'Unexpected attributes in definition of variable "{}": {!r}'.format(
                    self.name, ", ".join(sorted(unexpected_attrs.keys()))
                )
            )
        # A freshly built variable is never neutralized.
        self.is_neutralized = False
    # ----- Setters used to build the variable ----- #
    def set(
        self,
        attributes,
        attribute_name,
        required=False,
        allowed_values=None,
        allowed_type=None,
        setter=None,
        default=None,
    ):
        """
        Pop ``attribute_name`` from ``attributes``, validate it and return the value to store.

        :param attributes: Mapping of attributes declared on the variable class.
            The entry for ``attribute_name`` is removed from it.
        :param attribute_name: Name of the attribute to read.
        :param required: If True, a missing (``None``) value raises ``ValueError``.
        :param allowed_values: Optional collection the value must belong to.
            It is checked whenever a value is provided.
        :param allowed_type: Optional type (or tuple of types) the value must be
            an instance of. An ``int`` is promoted when ``allowed_type`` is
            exactly ``float``.
        :param setter: Optional callable applied to the validated value.
        :param default: Returned when the final value is ``None``.
        :returns: The validated, post-processed value. If the value is missing
            and the variable has a baseline variable, the baseline's attribute
            is returned as-is, without validation.
        :raises ValueError: If a required value is missing, or the value is not
            among ``allowed_values`` or not of ``allowed_type``.
        """
        value = attributes.pop(attribute_name, None)
        # Reform variables inherit any attribute they do not redefine.
        if value is None and self.baseline_variable:
            return getattr(self.baseline_variable, attribute_name)
        if required and value is None:
            raise ValueError(
                "Missing attribute '{}' in definition of variable '{}'.".format(
                    attribute_name, self.name
                )
            )
        # Validate any provided value, not only required ones. Otherwise
        # optional enumerated attributes (e.g. quantity_type) would silently
        # accept values outside the allowed set. Required attributes are
        # unaffected: a missing required value has already raised above.
        if (
            value is not None
            and allowed_values is not None
            and value not in allowed_values
        ):
            raise ValueError(
                "Invalid value '{}' for attribute '{}' in variable '{}'. Allowed values are '{}'.".format(
                    value, attribute_name, self.name, allowed_values
                )
            )
        if (
            allowed_type is not None
            and value is not None
            and not isinstance(value, allowed_type)
        ):
            if allowed_type == float and isinstance(value, int):
                value = float(value)
            else:
                raise ValueError(
                    "Invalid value '{}' for attribute '{}' in variable '{}'. Must be of type '{}'.".format(
                        value, attribute_name, self.name, allowed_type
                    )
                )
        if setter is not None:
            value = setter(value)
        if value is None and default is not None:
            return default
        return value
    def set_entity(self, entity):
        """
        Validate the ``entity`` attribute.

        :returns: ``entity`` unchanged.
        :raises ValueError: If ``entity`` is not an ``Entity`` instance.
        """
        if isinstance(entity, Entity):
            return entity
        raise ValueError(
            f"Invalid value '{entity}' for attribute 'entity' in variable '{self.name}'. Must be an instance of Entity."
        )
    def set_possible_values(self, possible_values):
        """
        Validate the ``possible_values`` attribute of an ``Enum`` variable.

        :param possible_values: The class listing the values the variable can take.
        :returns: ``possible_values`` unchanged.
        :raises ValueError: If ``possible_values`` is not a subclass of ``Enum``.
        """
        # issubclass() raises TypeError when its first argument is not a class
        # (e.g. an Enum member or a list). Check for a class first so callers
        # get the intended ValueError instead.
        if not (
            isinstance(possible_values, type)
            and issubclass(possible_values, Enum)
        ):
            raise ValueError(
                "Invalid value '{}' for attribute 'possible_values' in variable '{}'. Must be a subclass of {}.".format(
                    possible_values, self.name, Enum
                )
            )
        return possible_values
    def set_label(self, label):
        """Return ``label`` when it is non-empty, otherwise ``None``."""
        return label or None
    def set_end(self, end):
        """
        Parse the ``end`` attribute into a ``datetime.date``.

        :param end: A ``YYYY-MM-DD`` string, or a falsy value for no end date.
        :returns: The parsed date, or ``None`` when ``end`` is falsy.
        :raises ValueError: If ``end`` does not match the ``YYYY-MM-DD`` format.
        """
        if not end:
            return None
        try:
            return datetime.datetime.strptime(end, "%Y-%m-%d").date()
        except ValueError as error:
            # Chain the parsing error so the original strptime message is kept
            # as the explicit cause.
            raise ValueError(
                "Incorrect 'end' attribute format in '{}'. 'YYYY-MM-DD' expected where YYYY, MM and DD are year, month and day. Found: {}".format(
                    self.name, end
                )
            ) from error
    def set_reference(self, reference):
        """
        Normalise the ``reference`` attribute to a list where possible.

        A string or dict is wrapped in a one-element list, and a tuple is
        converted to a list. Lists, falsy values and any other type are
        returned unchanged.
        """
        if not reference:
            return reference
        if isinstance(reference, (str, dict)):
            return [reference]
        if isinstance(reference, tuple):
            return list(reference)
        return reference
    def set_documentation(self, documentation):
        """Return ``documentation`` with its common indentation removed, or ``None`` if it is empty."""
        return textwrap.dedent(documentation) if documentation else None
    def set_set_input(self, set_input):
        """
        Resolve the ``set_input`` attribute.

        A falsy ``set_input`` on a variable with a baseline falls back to the
        baseline's ``set_input``. Otherwise ``set_input`` is returned unchanged.
        """
        if set_input or not self.baseline_variable:
            return set_input
        return self.baseline_variable.set_input
    def set_calculate_output(self, calculate_output):
        """
        Resolve the ``calculate_output`` attribute.

        A falsy ``calculate_output`` on a variable with a baseline falls back to
        the baseline's ``calculate_output``. Otherwise the argument is returned
        unchanged.
        """
        if calculate_output or not self.baseline_variable:
            return calculate_output
        return self.baseline_variable.calculate_output
    def set_formulas(self, formulas_attr):
        """
        Build the start-date-sorted mapping of this variable's formulas.

        Each formula name is turned into a start date by
        ``self.parse_formula_name``, and its string form is used as the key.
        For a reform variable, baseline formulas that start strictly before
        the reform's earliest formula are kept. If the reform declares no
        formulas, all baseline formulas are kept.

        :param formulas_attr: Mapping of formula names to formula callables.
        :returns: A ``SortedDict`` keyed by start date string.
        :raises ValueError: If a formula starts after the variable's ``end`` date.
        """
        formulas = sortedcontainers.sorteddict.SortedDict()
        for name, function in formulas_attr.items():
            start = self.parse_formula_name(name)
            if self.end is not None and start > self.end:
                raise ValueError(
                    'You declared that "{}" ends on "{}", but you wrote a formula to calculate it from "{}" ({}). The "end" attribute of a variable must be posterior to the start dates of all its formulas.'.format(
                        self.name, self.end, start, name
                    )
                )
            formulas[str(start)] = function

        if self.baseline_variable is None:
            return formulas

        # SortedDict iterates in key order, so the first key is the earliest
        # reform start date (None when the reform declares no formula).
        earliest = next(iter(formulas), None)
        for start_key, function in self.baseline_variable.formulas.items():
            if earliest is None or start_key < earliest:
                formulas[start_key] = function
        return formulas
    def set_defined_for(self, defined_for):
        """
        Normalise the ``defined_for`` attribute.

        An ``Enum`` member is replaced by its ``value``. Anything else is
        returned unchanged.
        """
        return defined_for.value if isinstance(defined_for, Enum) else defined_for
    def set_min_value(self, min_value):
        """
        Validate ``min_value`` against ``self.max_value``.

        ``self.max_value`` is read at call time. When called from ``__init__``,
        this is the class-level declaration, since the instance's max_value is
        assigned afterwards.

        :returns: ``min_value`` (``None`` when ``min_value`` is ``None``).
        :raises ValueError: If ``self.max_value`` is set and below ``min_value``.
        """
        if min_value is None:
            return None
        ceiling = self.max_value
        if ceiling is not None and min_value > ceiling:
            raise ValueError("min_value cannot be greater than max_value")
        return min_value
    def set_max_value(self, max_value):
        """
        Validate ``max_value`` against ``self.min_value``.

        :returns: ``max_value`` (``None`` when ``max_value`` is ``None``).
        :raises ValueError: If ``self.min_value`` is set and above ``max_value``.
        """
        if max_value is None:
            return None
        floor = self.min_value
        if floor is not None and max_value < floor:
            raise ValueError("max_value cannot be smaller than min_value")
        return max_value
    # ----- Methods ----- #
[docs]    @classmethod
    def get_introspection_data(cls, tax_benefit_system):
        """
        Get instrospection data about the code of the variable.
        :returns: (comments, source file path, source code, start line number)
        :rtype: tuple
        """
        comments = inspect.getcomments(cls)
        # Handle dynamically generated variable classes or Jupyter Notebooks, which have no source.
        try:
            absolute_file_path = inspect.getsourcefile(cls)
        except TypeError:
            source_file_path = None
        else:
            source_file_path = absolute_file_path.replace(
                tax_benefit_system.get_package_metadata()["location"], ""
            )
        try:
            source_lines, start_line_number = inspect.getsourcelines(cls)
            source_code = textwrap.dedent("".join(source_lines))
        except (IOError, TypeError):
            source_code, start_line_number = None, None
        return comments, source_file_path, source_code, start_line_number 
    def clone(self):
        """Return a new instance of this variable's class, built without a ``baseline_variable``."""
        return self.__class__()
    def check_set_value(self, value):
        """
        Validate a value being set on this variable and convert it to ``self.dtype``.

        ``Enum`` variables accept a member name, which is converted to the
        member's ``index``. ``float``/``int`` variables accept a string, which
        is evaluated with ``tools.eval_expression``.

        :param value: The raw value.
        :returns: The value as a numpy scalar of ``self.dtype``.
        :raises ValueError: If the value cannot be interpreted for this variable.
        """
        if self.value_type == Enum and isinstance(value, str):
            try:
                value = self.possible_values[value].index
            except KeyError as error:
                possible_values = [item.name for item in self.possible_values]
                raise ValueError(
                    "'{}' is not a known value for '{}'. Possible values are ['{}'].".format(
                        value, self.name, "', '".join(possible_values)
                    )
                ) from error
        if self.value_type in (float, int) and isinstance(value, str):
            # NOTE(review): this evaluates a caller-supplied string; confirm
            # tools.eval_expression is safe if values can come from untrusted input.
            try:
                value = tools.eval_expression(value)
            except SyntaxError as error:
                raise ValueError(
                    "I couldn't understand '{}' as a value for '{}'".format(
                        value, self.name
                    )
                ) from error
        try:
            value = numpy.array([value], dtype=self.dtype)[0]
        except (TypeError, ValueError) as error:
            if self.value_type == datetime.date:
                error_message = "Can't deal with date: '{}'.".format(value)
            else:
                error_message = "Can't deal with value: expected type {}, received '{}'.".format(
                    self.json_type, value
                )
            raise ValueError(error_message) from error
        except OverflowError as error:
            error_message = "Can't deal with value: '{}', it's too large for type '{}'.".format(
                value, self.json_type
            )
            raise ValueError(error_message) from error
        return value
    def default_array(self, array_size):
        """
        Return an array of ``array_size`` copies of the variable's default value.

        For ``Enum`` variables the array holds the default member's ``index``
        and is wrapped in an ``EnumArray`` over ``possible_values``.
        """
        values = numpy.empty(array_size, dtype=self.dtype)
        if self.value_type != Enum:
            values.fill(self.default_value)
            return values
        values.fill(self.default_value.index)
        return EnumArray(values, self.possible_values)