# Source code for policyengine_core.taxscales.marginal_rate_tax_scale

from __future__ import annotations

import bisect
import itertools
import typing

import numpy

from policyengine_core import taxscales
from policyengine_core.taxscales.rate_tax_scale_like import RateTaxScaleLike

if typing.TYPE_CHECKING:
    # Type-checking-only alias for integer or float numpy data.
    # ``numpy.float_`` was removed in NumPy 2.0; ``numpy.float64`` is the same
    # type and exists in every NumPy version.
    NumericalArray = typing.Union[numpy.int_, numpy.float64]


class MarginalRateTaxScale(RateTaxScaleLike):
    """Tax scale whose rates apply marginally, bracket by bracket.

    Each rate applies only to the slice of the tax base lying between its
    threshold and the next one; the last rate applies to everything above the
    last threshold. The part of a tax base below the first threshold is not
    taxed.
    """

    def add_tax_scale(self, tax_scale: RateTaxScaleLike) -> None:
        """Add the brackets of ``tax_scale`` on top of this scale, in place.

        Every bracket of ``tax_scale`` is merged with :meth:`combine_bracket`.

        :param tax_scale: Scale whose brackets are added to this one.
        """
        # An empty scale contributes nothing (and has no last bracket).
        if len(tax_scale.thresholds) == 0:
            return

        for threshold_low, threshold_high, rate in zip(
            tax_scale.thresholds[:-1],
            tax_scale.thresholds[1:],
            tax_scale.rates,
        ):
            self.combine_bracket(rate, threshold_low, threshold_high)

        # The last bracket is open-ended, so it has no upper threshold.
        self.combine_bracket(tax_scale.rates[-1], tax_scale.thresholds[-1])

    def calc(
        self,
        tax_base: NumericalArray,
        factor: float = 1.0,
        round_base_decimals: typing.Optional[int] = None,
    ) -> numpy.ndarray:
        """
        Compute the tax amount for the given tax bases by applying a taxscale.

        :param ndarray tax_base: Array of the tax bases.
        :param float factor: Factor to apply to the thresholds of the taxscale.
        :param int round_base_decimals: Decimals to keep when rounding
            thresholds, per-bracket amounts and per-bracket taxes.

        :returns: Float array with tax amount for the given tax bases.

        For instance:

        >>> import numpy
        >>> tax_scale = MarginalRateTaxScale()
        >>> tax_scale.add_bracket(0, 0)
        >>> tax_scale.add_bracket(100, 0.1)
        >>> tax_base = numpy.array([0, 150])
        >>> tax_scale.calc(tax_base)
        array([0., 5.])
        """
        # One row per tax base, one column per bracket.
        bases = numpy.tile(tax_base, (len(self.thresholds), 1)).T
        factors = numpy.ones(len(tax_base)) * factor

        # Scaled thresholds, with +inf closing the last bracket. The epsilon
        # avoids computing 0 * numpy.inf (== numpy.nan) when factor is 0.
        # ``numpy.finfo(float)`` replaces the ``numpy.float_`` alias removed in
        # NumPy 2.0 (both denote float64).
        scaled_thresholds = numpy.outer(
            factors + numpy.finfo(float).eps,
            numpy.array(self.thresholds + [numpy.inf]),
        )
        if round_base_decimals is not None:
            scaled_thresholds = numpy.round(scaled_thresholds, round_base_decimals)

        # Portion of each tax base falling inside each bracket (never < 0).
        amounts_in_brackets = numpy.maximum(
            numpy.minimum(bases, scaled_thresholds[:, 1:])
            - scaled_thresholds[:, :-1],
            0,
        )

        if round_base_decimals is None:
            return numpy.dot(self.rates, amounts_in_brackets.T)

        rates = numpy.tile(self.rates, (len(tax_base), 1))
        rounded_amounts = numpy.round(amounts_in_brackets, round_base_decimals)
        return numpy.round(rates * rounded_amounts, round_base_decimals).sum(axis=1)

    def _rate_in_force_at(
        self,
        threshold: typing.Union[int, float],
    ) -> typing.Union[int, float]:
        """Return the marginal rate applying at ``threshold`` (0 below the scale)."""
        index = bisect.bisect_right(self.thresholds, threshold) - 1
        # Below the first threshold (or on an empty scale) no bracket applies.
        # Without this guard index -1 would silently pick the *last* rate, or
        # raise IndexError on an empty scale.
        return self.rates[index] if index >= 0 else 0

    def combine_bracket(
        self,
        rate: typing.Union[int, float],
        threshold_low: int = 0,
        threshold_high: typing.Union[int, bool] = False,
    ) -> None:
        """Add ``rate`` to the brackets from ``threshold_low`` to ``threshold_high``.

        Missing thresholds are first inserted with the rate already in force
        at that point, so the existing schedule is unchanged. ``rate`` is then
        added to every bracket from ``threshold_low`` up to (but excluding)
        ``threshold_high``. When ``threshold_high`` is falsy, it is added up to
        and including the last bracket.

        :param rate: Rate to add.
        :param threshold_low: Lower threshold of the combined range.
        :param threshold_high: Upper threshold of the combined range, or a
            falsy value for an open-ended range.
        """
        if threshold_low not in self.thresholds:
            self.add_bracket(threshold_low, self._rate_in_force_at(threshold_low))
        if threshold_high and threshold_high not in self.thresholds:
            self.add_bracket(threshold_high, self._rate_in_force_at(threshold_high))

        first = self.thresholds.index(threshold_low)
        if threshold_high:
            last = self.thresholds.index(threshold_high) - 1
        else:
            last = len(self.thresholds) - 1

        # add_bracket is relied on to add ``rate`` to the rate of an existing
        # threshold (every threshold in this range already exists).
        for position in range(first, last + 1):
            self.add_bracket(self.thresholds[position], rate)

    def marginal_rates(
        self,
        tax_base: NumericalArray,
        factor: float = 1.0,
        round_base_decimals: typing.Optional[int] = None,
    ) -> numpy.ndarray:
        """
        Compute the marginal tax rates relevant for the given tax bases.

        :param ndarray tax_base: Array of the tax bases.
        :param float factor: Factor to apply to the thresholds of a tax scale.
        :param int round_base_decimals: Decimals to keep when rounding thresholds.

        :returns: Float array with relevant marginal tax rate for the given tax bases.

        For instance:

        >>> import numpy
        >>> tax_scale = MarginalRateTaxScale()
        >>> tax_scale.add_bracket(0, 0)
        >>> tax_scale.add_bracket(100, 0.1)
        >>> tax_base = numpy.array([0, 150])
        >>> tax_scale.marginal_rates(tax_base)
        array([0. , 0.1])
        """
        bracket_indices = self.bracket_indices(
            tax_base,
            factor,
            round_base_decimals,
        )
        return numpy.array(self.rates)[bracket_indices]

    def rate_from_bracket_indice(
        self,
        bracket_indice: numpy.int_,
    ) -> numpy.ndarray:
        """
        Compute the relevant tax rates for the given bracket indices.

        :param: ndarray bracket_indice: Array of the bracket indices.

        :returns: Floating array with relevant tax rates for the given bracket indices.

        :raises IndexError: If an index is beyond the last bracket. Negative
            indices are not rejected and follow numpy indexing semantics.

        For instance:

        >>> import numpy
        >>> tax_scale = MarginalRateTaxScale()
        >>> tax_scale.add_bracket(0, 0)
        >>> tax_scale.add_bracket(200, 0.1)
        >>> tax_scale.add_bracket(500, 0.25)
        >>> tax_base = numpy.array([50, 1_000, 250])
        >>> bracket_indice = tax_scale.bracket_indices(tax_base)
        >>> tax_scale.rate_from_bracket_indice(bracket_indice)
        array([0.  , 0.25, 0.1 ])
        """
        bracket_indice = numpy.asarray(bracket_indice)
        # ``.max()`` raises on an empty array, so only check non-empty input.
        if bracket_indice.size and bracket_indice.max() > len(self.rates) - 1:
            raise IndexError(
                f"bracket_indice parameter ({bracket_indice}) "
                f"contains one or more bracket indice which is unavailable "
                f"inside current {self.__class__.__name__} :\n"
                f"{self}"
            )

        return numpy.array(self.rates)[bracket_indice]

    def rate_from_tax_base(
        self,
        tax_base: NumericalArray,
    ) -> numpy.ndarray:
        """
        Compute the relevant tax rates for the given tax bases.

        :param: ndarray tax_base: Array of the tax bases.

        :returns: Floating array with relevant tax rates for the given tax bases.

        For instance:

        >>> import numpy
        >>> tax_scale = MarginalRateTaxScale()
        >>> tax_scale.add_bracket(0, 0)
        >>> tax_scale.add_bracket(200, 0.1)
        >>> tax_scale.add_bracket(500, 0.25)
        >>> tax_base = numpy.array([1_000, 50, 450])
        >>> tax_scale.rate_from_tax_base(tax_base)
        array([0.25, 0.  , 0.1 ])
        """
        return self.rate_from_bracket_indice(self.bracket_indices(tax_base))

    def inverse(self) -> MarginalRateTaxScale:
        """
        Returns a new instance of MarginalRateTaxScale.

        Invert a taxscale:

            Assume tax_scale composed of bracket whose thresholds are expressed
            in terms of gross revenue.

            The inverse is another MarginalRateTaxScale whose thresholds are
            expressed in terms of net revenue.

            If net = gross_revenue - tax_scale.calc(gross_revenue)
            Then gross = tax_scale.inverse().calc(net)

        Assumes non-negative thresholds and rates different from 1 (a rate of
        1 raises ZeroDivisionError).
        """
        inverse = self.__class__(
            name=str(self.name) + "'",
            option=self.option,
            unit=self.unit,
        )

        # Below the first threshold no tax is due, so gross == net there.
        if len(self.thresholds) > 0 and self.thresholds[0] > 0:
            inverse.add_bracket(0, 1)

        # On the current linear segment,
        #   net = (1 - previous_rate) * gross + theta.
        # Initialised before the loop: a scale whose first threshold is not 0
        # would otherwise read them before assignment.
        previous_rate = 0
        theta = 0

        for threshold, rate in zip(self.thresholds, self.rates):
            # Net revenue reached at this gross-revenue threshold.
            net_threshold = (1 - previous_rate) * threshold + theta
            # 1 / (1 - rate) is the gross increase per unit of net increase.
            inverse.add_bracket(net_threshold, 1 / (1 - rate))
            theta = (rate - previous_rate) * threshold + theta
            previous_rate = rate

        return inverse

    def scale_tax_scales(self, factor: float) -> MarginalRateTaxScale:
        """Scale all the MarginalRateTaxScales in the node.

        Works on a copy, so this scale itself is left untouched; the result of
        ``multiply_thresholds`` on that copy is returned.
        """
        scaled_tax_scale = self.copy()
        return scaled_tax_scale.multiply_thresholds(factor)

    def to_average(self) -> taxscales.LinearAverageRateTaxScale:
        """Build the LinearAverageRateTaxScale matching this marginal scale.

        Every threshold after the first gets the average rate actually paid at
        that threshold (tax due divided by the threshold). An infinite
        threshold gets the last marginal rate, which is the limit of the
        average rate.
        """
        average_tax_scale = taxscales.LinearAverageRateTaxScale(
            name=self.name,
            option=self.option,
            unit=self.unit,
        )
        average_tax_scale.add_bracket(0, 0)

        if self.thresholds:
            tax_due = 0
            previous_threshold = self.thresholds[0]
            previous_rate = self.rates[0]

            for threshold, rate in itertools.islice(
                zip(self.thresholds, self.rates),
                1,
                None,
            ):
                tax_due += previous_rate * (threshold - previous_threshold)
                average_tax_scale.add_bracket(threshold, tax_due / threshold)
                previous_threshold = threshold
                previous_rate = rate

            # previous_rate now holds the last marginal rate. The loop variable
            # ``rate`` would be unbound when the scale has a single threshold.
            average_tax_scale.add_bracket(float("Inf"), previous_rate)

        return average_tax_scale