Source code for policyengine_core.enums.enum_array

from __future__ import annotations

import typing
from typing import Any, NoReturn, Optional, Type

import numpy

if typing.TYPE_CHECKING:
    from policyengine_core.enums import Enum


[docs]class EnumArray(numpy.ndarray): """ Numpy array subclass representing an array of enum items. EnumArrays are encoded as ``int`` arrays to improve performance """ # Subclassing ndarray is a little tricky. # To read more about the two following methods, see: # https://docs.scipy.org/doc/numpy-1.13.0/user/basics.subclassing.html#slightly-more-realistic-example-attribute-added-to-existing-array. def __new__( cls, input_array: numpy.int_, possible_values: Optional[Type[Enum]] = None, ) -> EnumArray: obj = numpy.asarray(input_array).view(cls) obj.possible_values = possible_values return obj # See previous comment def __array_finalize__(self, obj: Optional[numpy.int_]) -> None: if obj is None: return self.possible_values = getattr(obj, "possible_values", None) def __eq__(self, other: Any) -> bool: # When comparing to an item of self.possible_values, use the item index # to speed up the comparison. if other.__class__.__name__ is self.possible_values.__name__: # Use view(ndarray) so that the result is a classic ndarray, not an # EnumArray. return self.view(numpy.ndarray) == other.index return self.view(numpy.ndarray) == other def __ne__(self, other: Any) -> bool: return numpy.logical_not(self == other) def _forbidden_operation(self, other: Any) -> NoReturn: raise TypeError( "Forbidden operation. The only operations allowed on EnumArrays " "are '==' and '!='.", ) __add__ = _forbidden_operation __mul__ = _forbidden_operation __lt__ = _forbidden_operation __le__ = _forbidden_operation __gt__ = _forbidden_operation __ge__ = _forbidden_operation __and__ = _forbidden_operation __or__ = _forbidden_operation
[docs] def decode(self) -> numpy.object_: """ Return the array of enum items corresponding to self. For instance: >>> enum_array = household('housing_occupancy_status', period) >>> enum_array[0] >>> 2 # Encoded value >>> enum_array.decode()[0] <HousingOccupancyStatus.free_lodger: 'Free lodger'> Decoded value: enum item """ return numpy.select( [self == item.index for item in self.possible_values], list(self.possible_values), )
[docs] def decode_to_str(self) -> numpy.str_: """ Return the array of string identifiers corresponding to self. For instance: >>> enum_array = household('housing_occupancy_status', period) >>> enum_array[0] >>> 2 # Encoded value >>> enum_array.decode_to_str()[0] 'free_lodger' # String identifier """ return numpy.select( [self == item.index for item in self.possible_values], [item.name for item in self.possible_values], )
def __repr__(self) -> str: return f"{self.__class__.__name__}({str(self.decode())})" def __str__(self) -> str: return str(self.decode_to_str())