Weight calibration
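The survey_enhance.reweight module contains two classes for calibrating survey weights: LossCategory, which defines a loss architecture, and CalibratedWeights, which minimises that loss for a given survey by adjusting its weights, starting from initial_weights.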
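This page does not spell out what a loss "for a given survey" measures. A common choice in survey calibration is the gap between weighted survey totals and official targets. The sketch below assumes each comparison is a (name, target, per-household values) triple, loosely modelled on the List[Tuple[str, float, torch.Tensor]] returned by LossCategory.get_comparisons, and scores it with squared relative error. The meaning of the triple, the error metric and the helper names are all assumptions, and plain lists stand in for tensors.

```python
from typing import List, Sequence, Tuple

# Hypothetical comparison: (target name, official target value, per-household values).
Comparison = Tuple[str, float, Sequence[float]]


def weighted_total(weights: Sequence[float], values: Sequence[float]) -> float:
    """Survey estimate of an aggregate: sum of household weight times household value."""
    return sum(w * v for w, v in zip(weights, values))


def relative_squared_error_loss(weights: Sequence[float], comparisons: List[Comparison]) -> float:
    """Sum over comparisons of ((estimate - target) / target) ** 2 (assumed metric)."""
    loss = 0.0
    for _name, target, values in comparisons:
        estimate = weighted_total(weights, values)
        loss += ((estimate - target) / target) ** 2
    return loss


# Three illustrative households with targets for total population and total income.
comparisons = [
    ("population", 300.0, [1.0, 1.0, 1.0]),
    ("income", 6_000.0, [10.0, 20.0, 30.0]),
]
print(relative_squared_error_loss([100.0, 100.0, 50.0], comparisons))
```

Relative rather than absolute errors keep targets of very different magnitudes (head counts versus monetary totals) on a comparable scale.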
- class survey_enhance.reweight.CalibratedWeights(initial_weights: numpy.ndarray, dataset: policyengine_core.data.dataset.Dataset, loss_type: Type[survey_enhance.reweight.LossCategory], calibration_parameters: policyengine_core.parameters.parameter_node.ParameterNode)
Bases: object
- calibrate(time_instant: str, epochs: int = 1000, min_loss: Optional[float] = None, learning_rate: float = 0.1, validation_split: float = 0.0, validation_blacklist: Optional[List[str]] = None, rotate_holdout_sets: bool = False, log_dir: Optional[str] = None, tensorboard_log_dir: Optional[str] = None, log_frequency: int = 15, verbose: bool = False) → numpy.ndarray
- calibration_parameters: policyengine_core.parameters.parameter_node.ParameterNode
- dataset: policyengine_core.data.dataset.Dataset
- initial_weights: numpy.ndarray
- loss_type: Type[survey_enhance.reweight.LossCategory]
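calibrate has no docstring here. As a rough guide to what parameters such as epochs, learning_rate and min_loss conventionally control, here is a minimal gradient-descent sketch in plain Python. It optimises log-weights so that weights stay positive. That parametrisation, the relative-error loss, the early-stopping rule and the name calibrate_sketch are assumptions for illustration, not a description of how CalibratedWeights.calibrate is implemented.

```python
import math
from typing import List, Optional, Sequence, Tuple

# Hypothetical comparison: (name, target value, per-household values).
Comparison = Tuple[str, float, Sequence[float]]


def loss_and_grad(log_w: List[float], comparisons: List[Comparison]) -> Tuple[float, List[float]]:
    """Sum of squared relative errors and its gradient with respect to the log-weights."""
    weights = [math.exp(x) for x in log_w]
    loss = 0.0
    grad = [0.0] * len(log_w)
    for _name, target, values in comparisons:
        estimate = sum(w * v for w, v in zip(weights, values))
        rel_err = (estimate - target) / target
        loss += rel_err ** 2
        for i, (w, v) in enumerate(zip(weights, values)):
            # d(rel_err^2)/d(log w_i) = 2 * rel_err * v_i * w_i / target
            grad[i] += 2.0 * rel_err * v * w / target
    return loss, grad


def calibrate_sketch(
    initial_weights: Sequence[float],
    comparisons: List[Comparison],
    epochs: int = 1000,
    learning_rate: float = 0.1,
    min_loss: Optional[float] = None,
) -> List[float]:
    """Plain gradient descent on log-weights; stops early once loss <= min_loss."""
    log_w = [math.log(w) for w in initial_weights]
    for _epoch in range(epochs):
        loss, grad = loss_and_grad(log_w, comparisons)
        if min_loss is not None and loss <= min_loss:
            break  # good enough: return the current weights
        log_w = [x - learning_rate * g for x, g in zip(log_w, grad)]
    return [math.exp(x) for x in log_w]


comparisons = [
    ("population", 300.0, [1.0, 1.0, 1.0]),
    ("income", 6_000.0, [10.0, 20.0, 30.0]),
]
new_weights = calibrate_sketch([100.0, 100.0, 50.0], comparisons, epochs=500)
print([round(w, 1) for w in new_weights])
```

Optimising log-weights is one common way to keep calibrated weights positive; whether CalibratedWeights does the same is not stated on this page.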
- class survey_enhance.reweight.LossCategory(dataset: policyengine_core.data.dataset.Dataset, calibration_parameters_at_instant: policyengine_core.parameters.parameter_node_at_instant.ParameterNodeAtInstant, instant: Optional[str] = None, calibration_parameters: Optional[policyengine_core.parameters.parameter_node.ParameterNode] = None, weight: Optional[float] = None, ancestor: Optional[survey_enhance.reweight.LossCategory] = None, static_dataset: Optional[bool] = None, comparison_white_list: Optional[List[str]] = None, comparison_black_list: Optional[List[str]] = None, name: Optional[str] = None, normalise: Optional[bool] = None, diagnostic: Optional[bool] = None)
Bases: torch.nn.modules.module.Module
A loss category is essentially a loss function with added conveniences: it can be decomposed into weighted, normalised subcategories, and it logs its values.
- computation_tree(household_weights: torch.Tensor, dataset: policyengine_core.data.dataset.Dataset, filter_non_one: bool = True) → dict
- create_holdout_sets(dataset: policyengine_core.data.dataset.Dataset, num_sets: int, num_weights: int, exclude_by_name: Optional[str] = None) → List[Tuple[policyengine_core.data.dataset.Dataset, policyengine_core.data.dataset.Dataset]]
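create_holdout_sets, together with the validation_split, validation_blacklist and rotate_holdout_sets arguments of calibrate, suggests that some targets can be held out of training to check for overfitting. Below is a minimal sketch of rotating holdout sets over target names, assuming a round-robin k-fold split in which each fold is held out once and protected names are never held out. The function rotating_holdout_sets and its never_hold_out parameter are hypothetical; the real method returns pairs of Dataset objects rather than name lists.

```python
from typing import Iterable, List, Sequence, Tuple


def rotating_holdout_sets(
    names: Sequence[str],
    num_sets: int,
    never_hold_out: Iterable[str] = (),
) -> List[Tuple[List[str], List[str]]]:
    """Split target names into num_sets folds; return one (train, holdout) pair per fold.

    Names in never_hold_out always stay in the training part (hypothetical rule).
    """
    protected = set(never_hold_out)
    eligible = [n for n in names if n not in protected]
    # Round-robin assignment of eligible names to folds.
    folds = [eligible[i::num_sets] for i in range(num_sets)]
    pairs = []
    for holdout in folds:
        held = set(holdout)
        train = [n for n in names if n not in held]
        pairs.append((train, holdout))
    return pairs


for train, holdout in rotating_holdout_sets(["a", "b", "c", "d", "e"], num_sets=2, never_hold_out=["e"]):
    print(train, holdout)
```

Rotating the holdout fold means every eligible target is eventually used for validation, rather than relying on a single fixed split.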
- diagnostic: bool = False
Whether to log the full tree of losses.
- diagnostic_tree: Dict[str, float] = None
The tree of losses.
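diagnostic_tree is typed Dict[str, float], so a nested tree of category losses presumably has to be stored flat. One plausible keying scheme, assumed here and not taken from the library, joins category names into dotted paths. The (value, children) node layout, the function name and the numbers are made up for illustration.

```python
from typing import Dict, Tuple

# Hypothetical node: (loss value of this category, {child name: child node}).
Node = Tuple[float, Dict[str, "Node"]]


def flatten_loss_tree(name: str, node: Node) -> Dict[str, float]:
    """Map every category path (e.g. "total.income.employment") to its loss value."""
    value, children = node
    flat = {name: value}
    for child_name, child in children.items():
        flat.update(flatten_loss_tree(f"{name}.{child_name}", child))
    return flat


tree: Node = (
    0.9,
    {
        "demographics": (0.1, {}),
        "income": (0.8, {"employment": (0.5, {}), "pensions": (0.3, {})}),
    },
)
for path, loss in flatten_loss_tree("total", tree).items():
    print(f"{path}: {loss}")
```

A flat mapping of this kind is easy to pass to a logger such as TensorBoard, where each path becomes its own scalar series.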
- evaluate(household_weights: torch.Tensor, dataset: policyengine_core.data.dataset.Dataset, initial_run: bool = False) → torch.Tensor
- forward(household_weights: torch.Tensor, dataset: policyengine_core.data.dataset.Dataset, initial_run: bool = False) → torch.Tensor
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
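This note is inherited from PyTorch's torch.nn.Module. In practice it means a loss category should be evaluated as loss(household_weights, dataset) rather than loss.forward(household_weights, dataset), so that hooks registered with methods such as register_forward_hook actually fire. The stripped-down imitation below is plain Python written for illustration, not PyTorch's own code, and TinyModule and SquaredError are hypothetical names:

```python
from typing import Any, Callable, List

Hook = Callable[[Any, tuple, Any], None]


class TinyModule:
    """Minimal imitation of how torch.nn.Module dispatches forward hooks."""

    def __init__(self) -> None:
        self._forward_hooks: List[Hook] = []

    def register_forward_hook(self, hook: Hook) -> None:
        self._forward_hooks.append(hook)

    def forward(self, *args: Any) -> Any:
        raise NotImplementedError  # subclasses override this

    def __call__(self, *args: Any) -> Any:
        output = self.forward(*args)
        for hook in self._forward_hooks:
            hook(self, args, output)  # hooks only run on this path
        return output


class SquaredError(TinyModule):
    def forward(self, estimate: float, target: float) -> float:
        return (estimate - target) ** 2


calls = []
loss = SquaredError()
loss.register_forward_hook(lambda module, args, output: calls.append(output))
loss(3.0, 1.0)          # runs forward, then the hook
loss.forward(3.0, 1.0)  # runs forward only; the hook is skipped
print(calls)  # prints [4.0]
```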
- get_comparisons(dataset: policyengine_core.data.dataset.Dataset) → List[Tuple[str, float, torch.Tensor]]
- normalise: bool = True
Whether to normalise the starting loss value to 1.
- static_dataset = False
Whether the dataset is static, i.e. whether it stays unchanged between epochs.
- subcategories: List[Type[LossCategory]] = []
The subcategories of this loss category.
- weight: float = 1.0
The weight of this loss category in the total loss.
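The weight, normalise and subcategories attributes describe how a category combines its children. Here is a minimal sketch of one plausible combination rule, assumed rather than taken from the library: when normalise is true, a category divides its loss by the value recorded on the initial run, so it starts at 1; a parent then sums its children's losses, each multiplied by that child's weight. The class LossNode is hypothetical and uses plain floats instead of tensors. Its initial_run flag mirrors the argument of the same name on evaluate and forward, but what the library does with it is an assumption. The real subcategories attribute also lists LossCategory subclasses (types) rather than instances.

```python
from typing import Callable, List, Optional, Sequence


class LossNode:
    """Hypothetical loss category: a leaf function or a weighted sum of children."""

    def __init__(
        self,
        name: str,
        weight: float = 1.0,
        normalise: bool = True,
        subcategories: Optional[List["LossNode"]] = None,
        leaf_fn: Optional[Callable[[Sequence[float]], float]] = None,
    ) -> None:
        self.name = name
        self.weight = weight
        self.normalise = normalise
        self.subcategories = subcategories or []
        self.leaf_fn = leaf_fn
        self.initial_loss: Optional[float] = None

    def raw_loss(self, weights: Sequence[float], initial_run: bool) -> float:
        if self.leaf_fn is not None:
            return self.leaf_fn(weights)
        # Each child's weight is applied here, by its parent.
        return sum(child.weight * child.evaluate(weights, initial_run) for child in self.subcategories)

    def evaluate(self, weights: Sequence[float], initial_run: bool = False) -> float:
        loss = self.raw_loss(weights, initial_run)
        if initial_run:
            self.initial_loss = loss  # remember the starting value
        if self.normalise and self.initial_loss:  # skip if unset or zero
            loss /= self.initial_loss  # the starting value becomes 1
        return loss


population = LossNode("population", weight=2.0, leaf_fn=lambda w: (sum(w) - 300.0) ** 2)
income = LossNode("income", leaf_fn=lambda w: (10 * w[0] + 20 * w[1] + 30 * w[2] - 6000.0) ** 2)
total = LossNode("total", normalise=False, subcategories=[population, income])

print(total.evaluate([100.0, 100.0, 50.0], initial_run=True))  # 2 * 1.0 + 1 * 1.0 = 3.0
```

Normalising each child to 1 at the start means a category's weight, rather than the raw scale of its targets, decides how much it contributes to the total.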