Source code for policyengine_core.tools.test_runner

# -*- coding: utf-8 -*-

import logging
import sys
import os
import traceback
import textwrap
from typing import Dict, List

import pytest

from policyengine_core.enums import EnumArray
from policyengine_core.tools import (
    assert_enum_equals,
    assert_datetime_equals,
    eval_expression,
)
from policyengine_core.simulations import SimulationBuilder
from policyengine_core.errors import (
    SituationParsingError,
    VariableNotFoundError,
)
from policyengine_core.scripts import build_tax_benefit_system
from policyengine_core.reforms import Reform, set_parameter
from policyengine_core.populations import ADD, DIVIDE

# Module-level logger named after this module; used below to warn when the
# libyaml-backed YAML loader is unavailable.
log = logging.getLogger(__name__)


def import_yaml():
    """
    Import PyYAML and pick the fastest *safe* loader available.

    Prefers the libyaml-backed ``CSafeLoader`` and falls back to the
    pure-Python ``SafeLoader`` (logging a warning) when libyaml is missing.
    Both loaders only construct standard YAML types, so test files are parsed
    with the same semantics whether or not libyaml is installed.

    :return: a ``(yaml_module, loader_class)`` tuple.
    """
    import yaml

    try:
        # CSafeLoader rather than CLoader: CLoader is the C version of PyYAML's
        # full (unsafe) ``Loader`` and accepts tags such as ``!!python/object``
        # that the SafeLoader fallback below rejects.
        from yaml import CSafeLoader as Loader
    except ImportError:
        log.warning(
            " "
            "libyaml is not installed in your environment, this can make your "
            "test suite slower to run. Once you have installed libyaml, run `pip "
            "uninstall pyyaml && pip install pyyaml --no-cache-dir` so that it is used in your "
            "Python environment."
        )
        from yaml import SafeLoader as Loader
    return yaml, Loader


# Keys allowed at the top level of a YAML test. YamlItem.runtest raises a
# ValueError listing any key found in a test that is not in this set.
TEST_KEYWORDS = {
    "absolute_error_margin",
    "description",
    "extensions",
    "ignore_variables",
    "input",
    "keywords",
    "name",
    "only_variables",
    "output",
    "period",
    "reforms",
    "relative_error_margin",
}

# Resolved once, at import time: the PyYAML module and the loader class that
# YamlFile.collect passes to yaml.load.
yaml, Loader = import_yaml()

# Memoizes tax-benefit systems built by _get_tax_benefit_system, so tests that
# share a baseline, reforms and extensions reuse one derived system.
_tax_benefit_system_cache: Dict = {}


def run_tests(tax_benefit_system, paths, options=None):
    """
    Runs all the YAML tests contained in a file or a directory.

    If a path is a directory, subdirectories will be recursively explored.

    :param TaxBenefitSystem tax_benefit_system: the tax-benefit system to use to run the tests
    :param (str/os.PathLike/list) paths: A path, or a list of paths, towards the files or directories containing the tests to run. If a path is a directory, subdirectories will be recursively explored.
    :param dict options: See more details below.

    :return: the exit code returned by :func:`pytest.main` (``0`` when every
        collected test passed). Failing tests are reported by pytest and
        reflected in this code; they do not raise.

    **Testing options**:

    +-------------------------------+-----------+-------------------------------------------+
    | Key                           | Type      | Role                                      |
    +===============================+===========+===========================================+
    | verbose                       | ``bool``  |                                           |
    +-------------------------------+-----------+ See :any:`openfisca_test` options doc     +
    | name_filter                   | ``str``   |                                           |
    +-------------------------------+-----------+-------------------------------------------+

    Other keys read by this runner: ``pdb`` (adds ``--pdb`` to pytest),
    ``performance_graph``, ``performance_tables``, ``visualize``,
    ``only_variables`` and ``ignore_variables``.
    """
    if options is None:
        options = {}

    # Pytest command-line arguments. "--tb short" keeps tracebacks brief
    # instead of printing tons of lines; remove it to get pytest's default.
    argv = ["--capture", "no", "--maxfail", "0", "--tb", "short"]
    if options.get("pdb"):
        argv.append("--pdb")

    if isinstance(paths, (str, os.PathLike)):
        paths = [paths]
    # pytest parses its arguments with argparse, which needs plain strings.
    path_args = [os.fspath(path) for path in paths]

    return pytest.main(
        [*argv, *path_args],
        plugins=[OpenFiscaPlugin(tax_benefit_system, options)],
    )
class YamlFile(pytest.File):
    """
    Pytest collector for one YAML test file.

    The file holds either a single test (a mapping) or a list of tests; each
    test not filtered out by :meth:`should_ignore` becomes a :class:`YamlItem`.
    """

    def __init__(self, *, tax_benefit_system, options, **kwargs):
        super().__init__(**kwargs)
        self.tax_benefit_system = tax_benefit_system
        self.options = options

    def collect(self):
        """
        Parse the file and yield one :class:`YamlItem` per selected test.

        An empty file yields no tests.

        :raises ValueError: if the file cannot be parsed as YAML.
        """
        try:
            # Close the handle deterministically instead of leaving it open
            # until garbage collection.
            with self.path.open() as yaml_file:
                tests = yaml.load(yaml_file, Loader=Loader)
        except (yaml.scanner.ScannerError, yaml.parser.ParserError, TypeError):
            message = os.linesep.join(
                [
                    traceback.format_exc(),
                    f"'{self.path}' is not a valid YAML file. Check the stack trace above for more details.",
                ]
            )
            raise ValueError(message)
        if tests is None:
            # An empty YAML document loads as None: there is nothing to run.
            return
        if not isinstance(tests, list):
            tests = [tests]
        for test in tests:
            if not self.should_ignore(test):
                yield YamlItem.from_parent(
                    self,
                    name="",
                    baseline_tax_benefit_system=self.tax_benefit_system,
                    test=test,
                    options=self.options,
                )

    def should_ignore(self, test):
        """
        Return True when a ``name_filter`` option is set and it matches neither
        the file name (without extension), the test name, nor a test keyword.
        """
        name_filter = self.options.get("name_filter")
        return (
            name_filter is not None
            and name_filter not in os.path.splitext(self.path.name)[0]
            and name_filter not in test.get("name", "")
            and name_filter not in test.get("keywords", [])
        )


class YamlItem(pytest.Item):
    """
    Terminal nodes of the test collection tree.

    Each item runs one YAML test: it builds a simulation from the test's
    ``input`` (after applying reforms, inline parameter overrides and
    extensions) and checks every value listed under ``output``.
    """

    def __init__(
        self, *, baseline_tax_benefit_system, test, options, **kwargs
    ):
        super().__init__(**kwargs)
        self.baseline_tax_benefit_system = baseline_tax_benefit_system
        self.options = options
        self.test = test
        self.simulation = None
        self.tax_benefit_system = None

    def runtest(self):
        """
        Build the simulation described by the test and check its outputs.

        :raises ValueError: if the test is malformed or its situation cannot
            be built for an unexpected reason.
        """
        self.name = self.test.get("name", "")
        self._check_test_keys()

        input_data, inline_reform, reform_key = self._split_test_input()

        reforms = self.test.get("reforms", [])
        if isinstance(reforms, str):
            reforms = [reforms]

        self.tax_benefit_system = _get_tax_benefit_system(
            self.baseline_tax_benefit_system,
            reforms + inline_reform,
            self.test.get("extensions", []),
            reform_key=reform_key,
        )

        verbose = self.options.get("verbose")
        performance_graph = self.options.get("performance_graph")
        performance_tables = self.options.get("performance_tables")
        visualize = self.options.get("visualize")

        builder = SimulationBuilder()
        try:
            builder.set_default_period(self.test.get("period"))
            self.simulation = builder.build_from_dict(
                self.tax_benefit_system, input_data
            )
        except (VariableNotFoundError, SituationParsingError):
            raise
        except Exception as e:
            error_message = os.linesep.join(
                [
                    str(e),
                    "",
                    f"Unexpected error raised while parsing '{self.path}'",
                ]
            )
            # Keep the stack trace from the root error.
            raise ValueError(error_message).with_traceback(
                sys.exc_info()[2]
            ) from e

        try:
            self.simulation.trace = (
                verbose or performance_graph or performance_tables or visualize
            )
            self.check_output()
        finally:
            # Emit diagnostics even when an output check failed.
            tracer = self.simulation.tracer
            if verbose:
                self.print_computation_log(tracer)
            if performance_graph:
                self.generate_performance_graph(tracer)
            if performance_tables:
                self.generate_performance_tables(tracer)
            if visualize:
                self.generate_variable_graph(tracer)

    def _check_test_keys(self):
        """Raise ValueError if ``output`` is missing/empty or a key is unknown."""
        if not self.test.get("output"):
            raise ValueError(
                "Missing key 'output' in test '{}' in file '{}'".format(
                    self.name, self.path
                )
            )
        unexpected_keys = set(self.test.keys()).difference(TEST_KEYWORDS)
        if unexpected_keys:
            raise ValueError(
                "Unexpected keys {} in test '{}' in file '{}'".format(
                    unexpected_keys, self.name, self.path
                )
            )

    def _split_test_input(self):
        """
        Separate parameter overrides from the simulation situation.

        Keys of ``input`` containing a dot are parameter paths; they are turned
        into one inline Reform class. All other keys form the situation.

        :return: ``(input_data, inline_reform, reform_key)`` where
            ``inline_reform`` is a list holding zero or one Reform class and
            ``reform_key`` encodes the overrides for the system cache.
        """
        unsafe_input = self.test.get("input")
        if unsafe_input is None:
            unsafe_input = {}

        input_data = {}
        parameter_modifiers = []
        parametric_reform_items = []
        for key, value in unsafe_input.items():
            if "." in key:
                parameter_modifiers.append(
                    set_parameter(
                        key,
                        value,
                        return_modifier=True,
                        period="year:2000:40",
                    )
                )
                parametric_reform_items.append((key, value))
            else:
                input_data[key] = value

        inline_reform = []
        if parameter_modifiers:

            class inline_reform_class(Reform):
                def apply(self):
                    for modifier in parameter_modifiers:
                        self.parameters = modifier(self.parameters)

            inline_reform = [inline_reform_class]

        reform_key = "=".join(
            f"{key}:{value}" for key, value in parametric_reform_items
        )
        return input_data, inline_reform, reform_key

    def print_computation_log(self, tracer):
        print("Computation log:")  # noqa T001
        tracer.print_computation_log()

    def generate_performance_graph(self, tracer):
        tracer.generate_performance_graph(".")

    def generate_performance_tables(self, tracer):
        tracer.generate_performance_tables(".")

    def generate_variable_graph(self, tracer):
        tracer.generate_variable_graph(
            self.test.get("name"), self._all_output_vars()
        )

    def _all_output_vars(self):
        """Return every leaf key (variable or period) of the test's output."""
        return self._get_leaf_keys(self.test["output"])

    def _get_leaf_keys(self, dictionary: dict):
        keys = []
        for key, value in dictionary.items():
            if type(value) is dict:
                keys.extend(self._get_leaf_keys(value))
            else:
                keys.append(key)
        return keys

    def check_output(self):
        """
        Compare every expected value under ``output`` with the simulation.

        Each output key is a variable name, an entity's singular name
        (mapping variables to values) or an entity's plural name (mapping
        instance ids to variable/value mappings).

        :raises VariableNotFoundError: if a key matches none of these.
        """
        output = self.test.get("output")
        if output is None:
            return
        period = self.test.get("period")
        for key, expected_value in output.items():
            if self.tax_benefit_system.get_variable(key):
                # Key is a variable.
                self.check_variable(key, expected_value, period)
            elif self.simulation.populations.get(key):
                # Key is an entity singular.
                for variable_name, value in expected_value.items():
                    self.check_variable(variable_name, value, period)
            else:
                population = self.simulation.get_population(plural=key)
                if population is None:
                    raise VariableNotFoundError(key, self.tax_benefit_system)
                # Key is an entity plural: resolve each instance id once.
                for instance_id, instance_values in expected_value.items():
                    entity_index = population.get_index(instance_id)
                    for variable_name, value in instance_values.items():
                        self.check_variable(
                            variable_name,
                            value,
                            period,
                            entity_index,
                        )

    def check_variable(
        self, variable_name, expected_value, period, entity_index=None
    ):
        """
        Assert one variable's computed value matches ``expected_value``.

        A dict ``expected_value`` maps periods to expected values and is
        checked period by period.
        """
        if self.should_ignore_variable(variable_name):
            return
        if isinstance(expected_value, dict):
            for (
                requested_period,
                expected_value_at_period,
            ) in expected_value.items():
                self.check_variable(
                    variable_name,
                    expected_value_at_period,
                    requested_period,
                    entity_index,
                )
            return

        actual_value = self.simulation.calculate(variable_name, period)

        if entity_index is not None:
            actual_value = actual_value[entity_index]
        return assert_near(
            actual_value,
            expected_value,
            absolute_error_margin=self.test.get("absolute_error_margin"),
            message=f"{variable_name}@{period}: ",
            relative_error_margin=self.test.get("relative_error_margin"),
        )

    def should_ignore_variable(self, variable_name):
        """Apply the ``only_variables`` / ``ignore_variables`` options."""
        only_variables = self.options.get("only_variables")
        ignore_variables = self.options.get("ignore_variables")
        variable_ignored = (
            ignore_variables is not None and variable_name in ignore_variables
        )
        variable_not_tested = (
            only_variables is not None and variable_name not in only_variables
        )

        return variable_ignored or variable_not_tested

    def repr_failure(self, excinfo):
        """Render known test failures as a short, file/test-scoped message."""
        if not isinstance(
            excinfo.value,
            (AssertionError, VariableNotFoundError, SituationParsingError),
        ):
            return super().repr_failure(excinfo)

        # A bare ``assert`` carries no args; fall back to str() so the report
        # is produced instead of raising IndexError while reporting.
        args = excinfo.value.args
        message = str(args[0]) if args else str(excinfo.value)
        if isinstance(excinfo.value, SituationParsingError):
            message = f"Could not parse situation described: {message}"

        return os.linesep.join(
            [
                f"{self.path}:",
                f"  Test '{self.name}':",
                textwrap.indent(message, "    "),
            ]
        )


class OpenFiscaPlugin:
    """Pytest plugin collecting ``.yaml`` / ``.yml`` files as YAML tests."""

    def __init__(self, tax_benefit_system, options):
        self.tax_benefit_system = tax_benefit_system
        self.options = options

    def pytest_collect_file(self, parent, file_path):
        """
        Called by pytest for all plugins.
        :return: The collector for test methods.
        """
        if file_path.suffix in [".yaml", ".yml"]:
            return YamlFile.from_parent(
                parent,
                path=file_path,
                tax_benefit_system=self.tax_benefit_system,
                options=self.options,
            )


def _get_tax_benefit_system(
    baseline,
    reforms,
    extensions,
    reform_key=None,
):
    """
    Return ``baseline`` with ``reforms`` then ``extensions`` applied, memoized.

    :param baseline: the tax-benefit system to derive from (cloned, not mutated).
    :param reforms: reform paths (str) or reform callables, applied in order.
    :param extensions: extensions to load; their order does not affect caching.
    :param reform_key: extra cache discriminator for non-string reforms.
    """
    if not isinstance(reforms, list):
        reforms = [reforms]
    if not isinstance(extensions, list):
        extensions = [extensions]

    # Keep reforms order in the key, ignore extensions order. Non-string
    # reforms contribute "" and must be told apart through ``reform_key``.
    # The tuple itself is the key (not its hash) so distinct configurations
    # can never collide in the cache.
    key = (
        id(baseline),
        ":".join(
            reform if isinstance(reform, str) else "" for reform in reforms
        ),
        reform_key,
        frozenset(extensions),
    )
    cached = _tax_benefit_system_cache.get(key)
    if cached is not None:
        return cached

    current_tax_benefit_system = baseline.clone()

    for reform_path in reforms:
        if isinstance(reform_path, str):
            current_tax_benefit_system = (
                current_tax_benefit_system.apply_reform(reform_path)
            )
        else:
            current_tax_benefit_system = reform_path(
                current_tax_benefit_system
            )
        current_tax_benefit_system._parameters_at_instant_cache = {}

    for extension in extensions:
        current_tax_benefit_system = current_tax_benefit_system.clone()
        current_tax_benefit_system.load_extension(extension)

    _tax_benefit_system_cache[key] = current_tax_benefit_system

    return current_tax_benefit_system


def _values_equal(left, right):
    """
    Return True if ``left`` and ``right`` are equal element-wise.

    Broadcasting applies; shapes that cannot be broadcast count as unequal
    instead of raising.
    """
    import numpy as np

    try:
        return bool(np.all(np.asarray(left) == np.asarray(right)))
    except ValueError:
        return False


def assert_near(
    value,
    target_value,
    absolute_error_margin=None,
    message="",
    relative_error_margin=None,
):
    """
    :param value: Value returned by the test
    :param target_value: Value that the test should return to pass
    :param absolute_error_margin: Absolute error margin authorized
    :param message: Error message to be displayed if the test fails
    :param relative_error_margin: Relative error margin authorized

    When neither margin is given, an absolute margin of ``1e-3`` is used.

    :raises AssertionError: if the values differ. It is raised explicitly,
        not via ``assert``, so checks still run under ``python -O``.

    Limit : This function cannot be used to assert near periods.
    """

    import numpy as np

    if absolute_error_margin is None and relative_error_margin is None:
        absolute_error_margin = 1e-3
    if not isinstance(value, np.ndarray):
        value = np.array(value)
    if isinstance(value, EnumArray):
        return assert_enum_equals(value, target_value, message)
    if np.issubdtype(value.dtype, np.datetime64):
        target_value = np.array(target_value, dtype=value.dtype)
        assert_datetime_equals(value, target_value, message)
        # Like the EnumArray branch, the dedicated helper owns the check;
        # casting dates to float32 below would be meaningless.
        return
    if isinstance(target_value, str):
        target_value = eval_expression(target_value)
    try:
        target_value = np.array(target_value).astype(np.float32)
        value = np.array(value).astype(np.float32)
    except ValueError:
        # Data type not translatable to floating point, assert complete
        # (element-wise) equality.
        if not _values_equal(value, target_value):
            raise AssertionError(
                "{}{} differs from {}".format(message, value, target_value)
            )
        return

    diff = abs(target_value - value)
    if absolute_error_margin is not None:
        if not (diff <= absolute_error_margin).all():
            raise AssertionError(
                "{}{} differs from {} with an absolute margin {} > {}".format(
                    message, value, target_value, diff, absolute_error_margin
                )
            )
    if relative_error_margin is not None:
        relative_margin = abs(relative_error_margin * target_value)
        if not (diff <= relative_margin).all():
            raise AssertionError(
                "{}{} differs from {} with a relative margin {} > {}".format(
                    message,
                    value,
                    target_value,
                    diff,
                    relative_margin,
                )
            )