from __future__ import annotations
import abc
import bisect
import os
import typing
import numpy
from policyengine_core import tools
from policyengine_core.errors import EmptyArgumentError
from policyengine_core.taxscales.tax_scale_like import TaxScaleLike
if typing.TYPE_CHECKING:
    # Alias used only inside annotations; it is never bound at runtime, which
    # is fine because annotations are lazy (``from __future__ import
    # annotations``). ``numpy.float64`` is spelled out because the
    # ``numpy.float_`` alias was removed in NumPy 2.0.
    NumericalArray = typing.Union[numpy.int_, numpy.float64]
class RateTaxScaleLike(TaxScaleLike, abc.ABC):
    """
    Base class for various types of rate-based tax scales: marginal rate,
    linear average rate...

    A scale is stored as two parallel lists: ``self.thresholds`` and
    ``self.rates``, where ``rates[i]`` is the rate attached to
    ``thresholds[i]``. ``self.rates`` is created here; ``self.thresholds`` is
    presumably initialised by ``TaxScaleLike.__init__`` (not visible from
    this file -- confirm).
    """

    rates: typing.List

    def __init__(
        self,
        name: typing.Optional[str] = None,
        option: typing.Any = None,
        unit: typing.Any = None,
    ) -> None:
        """
        Create an empty scale.

        :param name: Forwarded unchanged to the parent constructor.
        :param option: Forwarded unchanged to the parent constructor.
        :param unit: Forwarded unchanged to the parent constructor.
        """
        super().__init__(name, option, unit)
        self.rates = []

    def __repr__(self) -> str:
        """Return one ``threshold``/``rate`` entry per bracket, indented."""
        brackets = [
            f"- threshold: {threshold}{os.linesep} rate: {rate}"
            for (threshold, rate) in zip(self.thresholds, self.rates)
        ]
        return tools.indent(os.linesep.join(brackets))

    def add_bracket(
        self,
        threshold: typing.Union[int, float],
        rate: typing.Union[int, float],
    ) -> None:
        """
        Add a bracket to the scale.

        If ``threshold`` is already present, ``rate`` is *added* to the rate
        of the existing bracket. Otherwise a new bracket is inserted at the
        position returned by ``bisect.bisect_left`` (the sorted position,
        provided the thresholds are currently in ascending order).

        :param threshold: Lower bound of the bracket.
        :param rate: Rate of the bracket.
        """
        try:
            index = self.thresholds.index(threshold)
        except ValueError:
            # Unknown threshold: create the bracket in both parallel lists.
            index = bisect.bisect_left(self.thresholds, threshold)
            self.thresholds.insert(index, threshold)
            self.rates.insert(index, rate)
        else:
            # Known threshold: accumulate onto the existing rate.
            self.rates[index] += rate

    def multiply_rates(
        self,
        factor: float,
        inplace: bool = True,
        new_name: typing.Optional[str] = None,
    ) -> RateTaxScaleLike:
        """
        Multiply every rate of the scale by ``factor``.

        :param factor: Multiplier applied to each rate.
        :param inplace: When true, modify and return this scale; otherwise
            build and return a new scale of the same class.
        :param new_name: Name of the new scale; falls back to this scale's
            name when falsy. Must be ``None`` when ``inplace`` is true.
        :returns: The modified scale (``self``) or the newly built one.
        :raises AssertionError: If ``new_name`` is given with
            ``inplace=True``.
        """
        if inplace:
            assert (
                new_name is None
            ), "new_name is only meaningful when inplace=False"

            for i, rate in enumerate(self.rates):
                self.rates[i] = rate * factor

            return self

        new_tax_scale = self.__class__(
            new_name or self.name,
            option=self.option,
            unit=self.unit,
        )

        # zip() stops at the shorter list, so only complete brackets are
        # copied.
        for threshold, rate in zip(self.thresholds, self.rates):
            new_tax_scale.thresholds.append(threshold)
            new_tax_scale.rates.append(rate * factor)

        return new_tax_scale

    def multiply_thresholds(
        self,
        factor: float,
        decimals: typing.Optional[int] = None,
        inplace: bool = True,
        new_name: typing.Optional[str] = None,
    ) -> RateTaxScaleLike:
        """
        Multiply every threshold of the scale by ``factor``.

        :param factor: Multiplier applied to each threshold.
        :param decimals: When not ``None``, each scaled threshold is rounded
            with ``numpy.around`` to this number of decimals.
        :param inplace: When true, modify and return this scale; otherwise
            build and return a new scale of the same class.
        :param new_name: Name of the new scale; falls back to this scale's
            name when falsy. Must be ``None`` when ``inplace`` is true.
        :returns: The modified scale (``self``) or the newly built one.
        :raises AssertionError: If ``new_name`` is given with
            ``inplace=True``.
        """
        if inplace:
            assert (
                new_name is None
            ), "new_name is only meaningful when inplace=False"

        # Compute the scaled thresholds once; both branches below use them.
        if decimals is not None:
            scaled_thresholds = [
                numpy.around(threshold * factor, decimals=decimals)
                for threshold in self.thresholds
            ]
        else:
            scaled_thresholds = [
                threshold * factor for threshold in self.thresholds
            ]

        if inplace:
            # Assign element by element so the existing list object is kept.
            for i, scaled in enumerate(scaled_thresholds):
                self.thresholds[i] = scaled

            return self

        new_tax_scale = self.__class__(
            new_name or self.name,
            option=self.option,
            unit=self.unit,
        )

        # zip() stops at the shorter list, so only complete brackets are
        # copied.
        for scaled, rate in zip(scaled_thresholds, self.rates):
            new_tax_scale.thresholds.append(scaled)
            new_tax_scale.rates.append(rate)

        return new_tax_scale

    def bracket_indices(
        self,
        tax_base: NumericalArray,
        factor: float = 1.0,
        round_decimals: typing.Optional[int] = None,
    ) -> numpy.int_:
        """
        Compute the relevant bracket indices for the given tax bases.

        For each tax base, the index is the number of (scaled) thresholds
        that are lower than or equal to it, minus one. A tax base below
        every threshold therefore gets index ``-1``.

        :param ndarray tax_base: Array of the tax bases. A scalar is treated
            as a one-element array.
        :param float factor: Factor to apply to the thresholds.
        :param int round_decimals: Decimals to keep when rounding thresholds.

        :returns: Integer array with relevant bracket indices for the given
            tax bases.

        :raises EmptyArgumentError: If the scale has no thresholds or if
            ``tax_base`` is empty.

        For instance:

        >>> import numpy
        >>> tax_scale = LinearAverageRateTaxScale()
        >>> tax_scale.add_bracket(0, 0)
        >>> tax_scale.add_bracket(100, 0.1)
        >>> tax_base = numpy.array([0, 150])
        >>> tax_scale.bracket_indices(tax_base)
        array([0, 1])
        """
        if not numpy.size(numpy.array(self.thresholds)):
            raise EmptyArgumentError(
                self.__class__.__name__,
                "bracket_indices",
                "self.thresholds",
                self.thresholds,
            )

        if not numpy.size(numpy.asarray(tax_base)):
            raise EmptyArgumentError(
                self.__class__.__name__,
                "bracket_indices",
                "tax_base",
                tax_base,
            )

        # Normalise scalars / 0-d arrays so that len() below is defined.
        tax_base = numpy.atleast_1d(tax_base)

        # One row per tax base, one column per threshold.
        base1 = numpy.tile(tax_base, (len(self.thresholds), 1)).T
        factors = numpy.ones(len(tax_base)) * factor

        # The machine epsilon is added to the factor so that a zero factor
        # combined with an infinite threshold does not produce
        # ``0 * inf = nan``.
        #
        # ``numpy.float64`` is spelled out because the ``numpy.float_``
        # alias was removed in NumPy 2.0.
        #
        # NOTE(review): because thresholds are scaled by ``factor + eps``, a
        # tax base exactly equal to a non-zero threshold may compare as just
        # below it unless ``round_decimals`` is given -- confirm this is
        # intended.
        thresholds1 = numpy.outer(
            factors + numpy.finfo(numpy.float64).eps,
            numpy.array(self.thresholds),
        )

        if round_decimals is not None:
            # ``numpy.round_`` was removed in NumPy 2.0; ``numpy.around`` is
            # the same function and is what multiply_thresholds uses.
            thresholds1 = numpy.around(thresholds1, round_decimals)

        return (base1 - thresholds1 >= 0).sum(axis=1) - 1

    def threshold_from_tax_base(
        self,
        tax_base: NumericalArray,
    ) -> NumericalArray:
        """
        Compute the relevant thresholds for the given tax bases.

        :param ndarray tax_base: Array of the tax bases.

        :returns: Array with the relevant threshold for each of the given
            tax bases.

        For instance:

        >>> import numpy
        >>> from policyengine_core import taxscales
        >>> tax_scale = taxscales.MarginalRateTaxScale()
        >>> tax_scale.add_bracket(0, 0)
        >>> tax_scale.add_bracket(200, 0.1)
        >>> tax_scale.add_bracket(500, 0.25)
        >>> tax_base = numpy.array([450, 1_150, 10])
        >>> tax_scale.threshold_from_tax_base(tax_base)
        array([200, 500, 0])
        """
        # NOTE(review): bracket_indices yields -1 for a tax base below every
        # threshold, which selects the *last* threshold here -- confirm that
        # callers never pass such values.
        return numpy.array(self.thresholds)[self.bracket_indices(tax_base)]

    def to_dict(self) -> dict:
        """Return the scale as a ``{str(threshold): rate}`` mapping."""
        return {
            str(threshold): self.rates[index]
            for index, threshold in enumerate(self.thresholds)
        }