Skip to article frontmatter | Skip to article content
Site not loading correctly?

This may be due to an incorrect BASE_URL configuration. See the MyST Documentation for reference.

Examples

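Complete, working scripts that demonstrate common workflows. Each script can be run directly with `python examples/<filename>.py`. The US budgetary impact example downloads its dataset on first run. The UK examples expect a pre-built dataset in `./data`, created with `create_datasets()` from `policyengine.tax_benefit_models.uk`. The US income distribution example expects its dataset in `examples/data`.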

US budgetary impact

The canonical workflow for comparing baseline and reform simulations, using both `economic_impact_analysis()` and `ChangeAggregate`.

us_budgetary_impact.py
"""Example: US budgetary impact comparison between baseline and reform.

Demonstrates the canonical policyengine.py workflow:
1. Ensure datasets exist (download + compute or load from cache)
2. Define a parametric reform
3. Run baseline and reform simulations
4. Use economic_impact_analysis() for the full analysis
5. Use ChangeAggregate for targeted single-metric queries

Run: python examples/us_budgetary_impact.py
"""

import datetime

from policyengine.core import Parameter, ParameterValue, Policy, Simulation
from policyengine.outputs.change_aggregate import (
    ChangeAggregate,
    ChangeAggregateType,
)
from policyengine.tax_benefit_models.us import (
    economic_impact_analysis,
    ensure_datasets,
    us_latest,
)


def main():
    """Run the end-to-end US budgetary impact example for 2026.

    Steps:
      1. Fetch (or reuse) the enhanced CPS dataset for the target year.
      2. Build a reform setting the single-filer standard deduction to
         $30,950 for the whole year.
      3. Create baseline and reform simulations on the same dataset.
      4. Print headline numbers from ``ChangeAggregate`` (tax change,
         winners, losers), then the full ``economic_impact_analysis()``
         report (programmes, deciles, poverty, inequality).
    """
    year = 2026

    # ── Step 1: Get dataset (downloads from HuggingFace on first run) ──
    print("Ensuring datasets are available...")
    source = "hf://policyengine/policyengine-us-data/enhanced_cps_2024.h5"
    available = ensure_datasets(
        datasets=[source],
        years=[year],
        data_folder="./data",
    )
    dataset = available[f"enhanced_cps_2024_{year}"]
    print(f"  Loaded: {dataset}")

    # ── Step 2: Define a reform ──
    # Replace the single-filer standard deduction for the full calendar year.
    first_day = datetime.date(year, 1, 1)
    last_day = datetime.date(year, 12, 31)
    single_deduction = Parameter(
        name="gov.irs.deductions.standard.amount.SINGLE",
        tax_benefit_model_version=us_latest,
    )
    reform = Policy(
        name="Double standard deduction (single)",
        parameter_values=[
            ParameterValue(
                parameter=single_deduction,
                start_date=first_day,
                end_date=last_day,
                value=30_950,
            ),
        ],
    )

    # ── Step 3: Create simulations ──
    baseline_sim = Simulation(dataset=dataset, tax_benefit_model_version=us_latest)
    reform_sim = Simulation(
        dataset=dataset,
        tax_benefit_model_version=us_latest,
        policy=reform,
    )

    # ── Step 4a: Quick budgetary number via ChangeAggregate ──
    # ChangeAggregate needs simulations that have already been run.
    print("\nRunning simulations...")
    for sim in (baseline_sim, reform_sim):
        sim.run()

    def compare(variable, aggregate_type, **thresholds):
        # Build a baseline-vs-reform aggregate; ``thresholds`` are passed
        # through unchanged (e.g. change_geq / change_leq).
        return ChangeAggregate(
            baseline_simulation=baseline_sim,
            reform_simulation=reform_sim,
            variable=variable,
            aggregate_type=aggregate_type,
            **thresholds,
        )

    revenue = compare("household_tax", ChangeAggregateType.SUM)
    revenue.run()
    print("\nQuick budgetary result:")
    print(f"  Tax revenue change: ${revenue.result / 1e9:.2f}B")

    # Households whose net income changes by at least $1 in either direction.
    gainers = compare(
        "household_net_income", ChangeAggregateType.COUNT, change_geq=1
    )
    decliners = compare(
        "household_net_income", ChangeAggregateType.COUNT, change_leq=-1
    )
    gainers.run()
    decliners.run()
    print(f"  Winners: {gainers.result / 1e6:.2f}M households")
    print(f"  Losers: {decliners.result / 1e6:.2f}M households")

    # ── Step 4b: Full analysis via economic_impact_analysis ──
    # The simulations were already run above, so the analysis reuses them.
    print("\nRunning full economic impact analysis...")
    report = economic_impact_analysis(baseline_sim, reform_sim)

    print("\n=== Program-by-Program Impact ===")
    for program in report.program_statistics.outputs:
        print(
            f"  {program.program_name:30s}  "
            f"baseline=${program.baseline_total / 1e9:8.1f}B  "
            f"reform=${program.reform_total / 1e9:8.1f}B  "
            f"change=${program.change / 1e9:+8.1f}B"
        )

    print("\n=== Decile Impacts ===")
    for row in report.decile_impacts.outputs:
        print(
            f"  Decile {row.decile:2d}:  "
            f"avg change=${row.absolute_change:+8.0f}  "
            f"relative={row.relative_change:+.2%}"
        )

    print("\n=== Poverty ===")
    paired_poverty = zip(
        report.baseline_poverty.outputs,
        report.reform_poverty.outputs,
        strict=True,
    )
    for before, after in paired_poverty:
        print(
            f"  {before.metric:30s}  "
            f"baseline={before.rate:.4f}  "
            f"reform={after.rate:.4f}  "
            f"change={after.rate - before.rate:+.4f}"
        )

    print("\n=== Inequality ===")
    base_ineq = report.baseline_inequality
    reform_ineq = report.reform_inequality
    print(f"  Gini:           baseline={base_ineq.gini:.4f}  reform={reform_ineq.gini:.4f}")
    print(
        f"  Top 10% share:  baseline={base_ineq.top_10_share:.4f}  reform={reform_ineq.top_10_share:.4f}"
    )
    print(
        f"  Top 1% share:   baseline={base_ineq.top_1_share:.4f}  reform={reform_ineq.top_1_share:.4f}"
    )


if __name__ == "__main__":
    main()

UK policy reform analysis

Applying a parametric reform (a zero personal allowance), comparing baseline and reform simulations with `ChangeAggregate`, analysing winners and losers overall and losers by income decile, and visualising the results with Plotly.

policy_change_uk.py
"""Example: Analyse policy change impacts using ChangeAggregate with parametric reforms.

This script demonstrates:
1. Loading representative household microdata
2. Applying parametric reforms (e.g., setting personal allowance to zero)
3. Running baseline and reform simulations
4. Using ChangeAggregate to analyse winners, losers, and impact sizes by income decile
5. Using quantile filters for decile-based analysis
6. Visualising results with Plotly

Run: python examples/policy_change_uk.py
"""

import datetime
from pathlib import Path

import plotly.graph_objects as go
from plotly.subplots import make_subplots

from policyengine.core import Parameter, ParameterValue, Policy, Simulation
from policyengine.outputs.change_aggregate import (
    ChangeAggregate,
    ChangeAggregateType,
)
from policyengine.tax_benefit_models.uk import (
    PolicyEngineUKDataset,
    uk_latest,
)


def load_representative_data(year: int = 2026) -> PolicyEngineUKDataset:
    """Return the enhanced FRS microdata for ``year``, loaded into memory.

    The file is read from ``./data/enhanced_frs_2023_24_year_{year}.h5``,
    relative to the current working directory.

    Raises:
        FileNotFoundError: if that file has not been generated yet.
    """
    h5_file = Path(f"./data/enhanced_frs_2023_24_year_{year}.h5")
    if not h5_file.exists():
        hint = "Run create_datasets() from policyengine.tax_benefit_models.uk first."
        raise FileNotFoundError(f"Dataset not found at {h5_file}. " + hint)

    microdata = PolicyEngineUKDataset(
        name=f"Enhanced FRS {year}",
        description=f"Representative household microdata for {year}",
        filepath=str(h5_file),
        year=year,
    )
    microdata.load()
    return microdata


def create_personal_allowance_reform(year: int) -> Policy:
    """Build a ``Policy`` that sets the income tax personal allowance to £0.

    The override covers 1 January to 31 December of ``year``.
    """
    param_path = "gov.hmrc.income_tax.allowances.personal_allowance.amount"
    personal_allowance = Parameter(
        id=f"{uk_latest.id}-{param_path}",
        name=param_path,
        tax_benefit_model_version=uk_latest,
        description="Personal allowance for income tax",
        data_type=float,
    )

    zero_allowance = ParameterValue(
        parameter=personal_allowance,
        start_date=datetime.date(year, 1, 1),
        end_date=datetime.date(year, 12, 31),
        value=0,
    )
    return Policy(
        name="Zero personal allowance",
        description="Sets personal allowance to £0",
        parameter_values=[zero_allowance],
    )


def run_baseline_simulation(dataset: PolicyEngineUKDataset) -> Simulation:
    """Build a current-law simulation over ``dataset``, run it and return it."""
    baseline = Simulation(dataset=dataset, tax_benefit_model_version=uk_latest)
    baseline.run()
    return baseline


def run_reform_simulation(dataset: PolicyEngineUKDataset, policy: Policy) -> Simulation:
    """Build a simulation of ``policy`` over ``dataset``, run it and return it."""
    reformed = Simulation(
        dataset=dataset,
        tax_benefit_model_version=uk_latest,
        policy=policy,
    )
    reformed.run()
    return reformed


def analyse_overall_impact(baseline_sim: Simulation, reform_sim: Simulation) -> dict:
    """Summarise household winners, losers and the aggregate cash impact.

    Returns:
        dict with household counts in millions (``winners``: net income
        change >= 1, ``losers``: change <= -1, ``no_change``: change == 0)
        and aggregate changes in billions (``total_change`` for
        ``household_net_income``, ``tax_revenue_change`` for
        ``household_tax``).

    NOTE(review): assuming change_geq / change_leq / change_eq are inclusive
    >=, <=, == thresholds, households whose change is non-zero but strictly
    between -1 and 1 fall into none of the three count groups.
    """

    def run_change(variable, aggregate_type, **bounds):
        # Construct, run and return the result of one baseline-vs-reform
        # aggregate; ``bounds`` are forwarded as-is.
        aggregate = ChangeAggregate(
            baseline_simulation=baseline_sim,
            reform_simulation=reform_sim,
            variable=variable,
            aggregate_type=aggregate_type,
            **bounds,
        )
        aggregate.run()
        return aggregate.result

    count = ChangeAggregateType.COUNT
    gained = run_change("household_net_income", count, change_geq=1)
    lost = run_change("household_net_income", count, change_leq=-1)
    unchanged = run_change("household_net_income", count, change_eq=0)
    net_income_delta = run_change("household_net_income", ChangeAggregateType.SUM)
    tax_delta = run_change("household_tax", ChangeAggregateType.SUM)

    millions, billions = 1e6, 1e9
    return {
        "winners": gained / millions,
        "losers": lost / millions,
        "no_change": unchanged / millions,
        "total_change": net_income_delta / billions,
        "tax_revenue_change": tax_delta / billions,
    }


def analyse_impact_by_income_decile(
    baseline_sim: Simulation, reform_sim: Simulation
) -> dict:
    """Break the reform's impact down by ``household_net_income`` decile.

    Returns:
        dict of parallel lists over deciles 1-10: ``labels``
        ("Decile N"), ``losers`` (households with a net income change
        <= -1, in millions) and ``avg_loss`` (mean change in
        ``household_net_income`` across all households in the decile;
        negative values are losses).
    """
    labels = []
    losers_millions = []
    mean_changes = []

    for decile in range(1, 11):
        # Shared arguments selecting one of ten household_net_income quantiles.
        in_decile = {
            "filter_variable": "household_net_income",
            "quantile": 10,
            "quantile_eq": decile,
        }

        losers_agg = ChangeAggregate(
            baseline_simulation=baseline_sim,
            reform_simulation=reform_sim,
            variable="household_net_income",
            aggregate_type=ChangeAggregateType.COUNT,
            change_leq=-1,
            **in_decile,
        )
        losers_agg.run()

        mean_agg = ChangeAggregate(
            baseline_simulation=baseline_sim,
            reform_simulation=reform_sim,
            variable="household_net_income",
            aggregate_type=ChangeAggregateType.MEAN,
            **in_decile,
        )
        mean_agg.run()

        labels.append(f"Decile {decile}")
        losers_millions.append(losers_agg.result / 1e6)
        mean_changes.append(mean_agg.result)

    return {
        "labels": labels,
        "losers": losers_millions,
        "avg_loss": mean_changes,
    }


def visualise_results(overall: dict, by_decile: dict, reform_name: str) -> None:
    """Show a three-panel Plotly bar chart of the reform's impact.

    Args:
        overall: Dict with ``winners``, ``no_change`` and ``losers``
            household counts (millions), as produced by
            ``analyse_overall_impact``.
        by_decile: Dict with parallel ``labels``, ``losers`` (millions) and
            ``avg_loss`` lists, as produced by
            ``analyse_impact_by_income_decile``.
        reform_name: Included in the figure title.
    """
    fig = make_subplots(
        rows=1,
        cols=3,
        subplot_titles=(
            "Winners vs losers (millions)",
            "Losers by income decile (millions)",
            # ``avg_loss`` holds the mean change in net income, which is
            # negative for losses, so it is labelled as a change. Calling it a
            # "loss" would make negative bars read as a double negative.
            "Average net income change by decile (£)",
        ),
        specs=[[{"type": "bar"}, {"type": "bar"}, {"type": "bar"}]],
    )

    fig.add_trace(
        go.Bar(
            x=["Winners", "No change", "Losers"],
            y=[
                overall["winners"],
                overall["no_change"],
                overall["losers"],
            ],
            marker_color=["green", "gray", "red"],
        ),
        row=1,
        col=1,
    )

    fig.add_trace(
        go.Bar(
            x=by_decile["labels"],
            y=by_decile["losers"],
            marker_color="lightcoral",
        ),
        row=1,
        col=2,
    )

    fig.add_trace(
        go.Bar(
            x=by_decile["labels"],
            y=by_decile["avg_loss"],
            marker_color="orange",
        ),
        row=1,
        col=3,
    )

    fig.update_xaxes(title_text="Category", row=1, col=1)
    fig.update_xaxes(title_text="Income decile", row=1, col=2)
    fig.update_xaxes(title_text="Income decile", row=1, col=3)

    fig.update_layout(
        title_text=f"Policy change impact analysis: {reform_name}",
        showlegend=False,
        height=400,
    )

    fig.show()


def print_summary(overall: dict, decile: dict, reform_name: str) -> None:
    """Print summary statistics."""
    print("=" * 60)
    print(f"Policy change impact summary: {reform_name}")
    print("=" * 60)
    print("\nOverall impact:")
    print(f"  Winners: {overall['winners']:.2f}m households")
    print(f"  Losers: {overall['losers']:.2f}m households")
    print(f"  No change: {overall['no_change']:.2f}m households")
    print("\nFinancial impact:")
    print(f"  Net income change: £{overall['total_change']:.2f}bn (negative = loss)")
    print(f"  Tax revenue change: £{overall['tax_revenue_change']:.2f}bn")
    print("\nImpact by income decile:")
    for i, label in enumerate(decile["labels"]):
        print(
            f"  {label}: {decile['losers'][i]:.2f}m losers, avg change £{decile['avg_loss'][i]:.0f}"
        )
    print("=" * 60)


def main():
    """Run the zero-personal-allowance example end to end for 2026.

    Loads the dataset, builds the reform, runs baseline and reform
    simulations, prints the summary and finally shows the Plotly figure.
    """
    year = 2026

    print("Loading representative household data...")
    microdata = load_representative_data(year=year)

    print("Creating policy reform (zero personal allowance)...")
    policy = create_personal_allowance_reform(year=year)

    print("Running baseline simulation...")
    baseline = run_baseline_simulation(microdata)

    print("Running reform simulation...")
    reformed = run_reform_simulation(microdata, policy)

    print("Analysing overall impact...")
    headline = analyse_overall_impact(baseline, reformed)

    print("Analysing impact by income decile...")
    per_decile = analyse_impact_by_income_decile(baseline, reformed)

    print_summary(headline, per_decile, policy.name)

    print("\nGenerating visualisations...")
    visualise_results(headline, per_decile, policy.name)


if __name__ == "__main__":
    main()

UK income bands

Calculating net income and tax by income decile using representative microdata and Aggregate with quantile filters.

income_bands_uk.py
"""Example: Calculate net income and tax by income decile using representative microdata.

This script demonstrates:
1. Using representative household microdata
2. Running a full microsimulation to calculate income tax and net income
3. Using Aggregate to calculate statistics within income deciles using quantile filters
4. Visualising results with Plotly

Run: python examples/income_bands_uk.py
"""

from pathlib import Path

import plotly.graph_objects as go
from plotly.subplots import make_subplots

from policyengine.core import Simulation
from policyengine.outputs.aggregate import Aggregate, AggregateType
from policyengine.tax_benefit_models.uk import (
    PolicyEngineUKDataset,
    uk_latest,
)


def load_representative_data(year: int = 2026) -> PolicyEngineUKDataset:
    """Return the enhanced FRS microdata for ``year``, loaded into memory.

    Reads ``./data/enhanced_frs_2023_24_year_{year}.h5`` relative to the
    current working directory.

    Raises:
        FileNotFoundError: if that file has not been generated yet.
    """
    h5_file = Path(f"./data/enhanced_frs_2023_24_year_{year}.h5")
    if not h5_file.exists():
        hint = "Run create_datasets() from policyengine.tax_benefit_models.uk first."
        raise FileNotFoundError(f"Dataset not found at {h5_file}. " + hint)

    microdata = PolicyEngineUKDataset(
        name=f"Enhanced FRS {year}",
        description=f"Representative household microdata for {year}",
        filepath=str(h5_file),
        year=year,
    )
    microdata.load()
    return microdata


def run_simulation(dataset: PolicyEngineUKDataset) -> Simulation:
    """Build a current-law simulation over ``dataset``, run it and return it."""
    sim = Simulation(dataset=dataset, tax_benefit_model_version=uk_latest)
    sim.run()
    return sim


def calculate_income_decile_statistics(simulation: Simulation) -> dict:
    """Aggregate net income, tax and household counts per income decile.

    Deciles are the ten quantiles of ``household_net_income``.

    Returns:
        dict of parallel lists: ``deciles`` ("Decile N" labels),
        ``net_incomes`` and ``taxes`` (sums divided by 1e9, i.e. £bn) and
        ``counts`` (COUNT aggregate divided by 1e6, i.e. millions).
    """

    def in_decile(variable, aggregate_type, decile):
        # Run one Aggregate restricted to the given net-income decile.
        agg = Aggregate(
            simulation=simulation,
            variable=variable,
            aggregate_type=aggregate_type,
            filter_variable="household_net_income",
            quantile=10,
            quantile_eq=decile,
        )
        agg.run()
        return agg.result

    stats = {"deciles": [], "net_incomes": [], "taxes": [], "counts": []}
    for decile in range(1, 11):
        net_income = in_decile("household_net_income", AggregateType.SUM, decile)
        tax = in_decile("household_tax", AggregateType.SUM, decile)
        households = in_decile("household_net_income", AggregateType.COUNT, decile)

        stats["deciles"].append(f"Decile {decile}")
        stats["net_incomes"].append(net_income / 1e9)
        stats["taxes"].append(tax / 1e9)
        stats["counts"].append(households / 1e6)

    return stats


def visualise_results(results: dict) -> None:
    """Show net income, tax and household counts by decile as three bar panels.

    ``results`` is the dict returned by ``calculate_income_decile_statistics``.
    """
    # (subplot title, results key, bar colour), one entry per column.
    panels = (
        ("Net income by decile (£bn)", "net_incomes", "lightblue"),
        ("Tax by decile (£bn)", "taxes", "lightcoral"),
        ("Households by decile (millions)", "counts", "lightgreen"),
    )

    fig = make_subplots(
        rows=1,
        cols=len(panels),
        subplot_titles=tuple(title for title, _, _ in panels),
        specs=[[{"type": "bar"} for _ in panels]],
    )

    for column, (_, key, colour) in enumerate(panels, start=1):
        fig.add_trace(
            go.Bar(x=results["deciles"], y=results[key], marker_color=colour),
            row=1,
            col=column,
        )

    for column in range(1, len(panels) + 1):
        fig.update_xaxes(title_text="Income decile", row=1, col=column)

    fig.update_layout(
        title_text="Household income and tax distribution",
        showlegend=False,
        height=400,
    )
    fig.show()


def main():
    """Run the UK income-band example for 2026 and print a short summary.

    The printed effective tax rate is total tax divided by (total net
    income + total tax).
    """
    print("Loading representative household data...")
    microdata = load_representative_data(year=2026)

    print("Running microsimulation...")
    sim = run_simulation(microdata)

    print("Calculating statistics by income decile...")
    stats = calculate_income_decile_statistics(sim)

    print("\nResults summary:")
    net_income_bn = sum(stats["net_incomes"])
    tax_bn = sum(stats["taxes"])
    households_m = sum(stats["counts"])

    print(f"Total net income: £{net_income_bn:.1f}bn")
    print(f"Total tax: £{tax_bn:.1f}bn")
    print(f"Total households: {households_m:.1f}m")
    effective_rate = tax_bn / (net_income_bn + tax_bn) * 100
    print(f"Average effective tax rate: {effective_rate:.1f}%")

    print("\nGenerating visualisations...")
    visualise_results(stats)


if __name__ == "__main__":
    main()

US income distribution

Loading enhanced CPS microdata, running a full microsimulation, and calculating statistics within income deciles.

income_distribution_us.py
"""Example: Plot US household income distribution using enhanced CPS microdata.

This script demonstrates:
1. Loading enhanced CPS representative household microdata
2. Running a full microsimulation to calculate household income and tax
3. Using Aggregate to calculate statistics within income deciles
4. Visualising the income distribution across the United States

Run: python examples/income_distribution_us.py
"""

import time
from pathlib import Path

import plotly.graph_objects as go
from plotly.subplots import make_subplots

from policyengine.core import Simulation
from policyengine.outputs.aggregate import Aggregate, AggregateType
from policyengine.tax_benefit_models.us import (
    PolicyEngineUSDataset,
    us_latest,
)
from policyengine.utils.plotting import COLORS, format_fig


def load_representative_data(year: int = 2024) -> PolicyEngineUSDataset:
    """Return the enhanced CPS microdata for ``year``, loaded into memory.

    Unlike the UK examples, the file is resolved relative to this script:
    ``<script dir>/data/enhanced_cps_2024_year_{year}.h5``.

    Raises:
        FileNotFoundError: if that file has not been generated yet.
    """
    data_dir = Path(__file__).parent / "data"
    h5_file = data_dir / f"enhanced_cps_2024_year_{year}.h5"
    if not h5_file.exists():
        hint = "Run create_datasets() from policyengine.tax_benefit_models.us first."
        raise FileNotFoundError(f"Dataset not found at {h5_file}. " + hint)

    microdata = PolicyEngineUSDataset(
        name=f"Enhanced CPS {year}",
        description=f"Representative household microdata for {year}",
        filepath=str(h5_file),
        year=year,
    )
    microdata.load()
    return microdata


def run_simulation(dataset: PolicyEngineUSDataset) -> Simulation:
    """Build a current-law simulation over ``dataset``, run it and return it."""
    sim = Simulation(dataset=dataset, tax_benefit_model_version=us_latest)
    sim.run()
    return sim


def calculate_income_decile_statistics(simulation: Simulation) -> dict:
    """Calculate total income, tax, and benefits by income deciles.

    Every statistic comes from an ``Aggregate`` restricted to one of ten
    quantiles of ``household_net_income`` (``quantile=10``,
    ``quantile_eq`` in 1..10). Progress and wall-clock timings are printed
    as each stage finishes.

    Args:
        simulation: The simulation to aggregate over. ``main`` passes one
            that has already been ``run()``.

    Returns:
        dict of parallel per-decile lists:
          - ``deciles``: labels ``"D1"`` .. ``"D10"``
          - ``market_incomes``, ``taxes``, ``benefits``, ``net_incomes``:
            SUMs of the corresponding ``household_*`` variable / 1e9
          - ``counts``: SUM of ``household_weight`` / 1e6
          - ``benefit_programs_by_decile``: programme variable name ->
            list of ten per-decile SUMs / 1e9
    """
    start_time = time.time()
    deciles = [f"D{i}" for i in range(1, 11)]
    market_incomes = []
    taxes = []
    benefits = []
    net_incomes = []
    counts = []

    # Calculate household-level aggregates by decile
    # Five separate SUM aggregates per decile; ``agg`` is rebound each time.
    print("Calculating main statistics by decile...")
    main_stats_start = time.time()
    for decile_num in range(1, 11):
        decile_start = time.time()

        # Market income
        # Only the first decile's construction and run() are timed separately.
        pre_create = time.time()
        agg = Aggregate(
            simulation=simulation,
            variable="household_market_income",
            aggregate_type=AggregateType.SUM,
            filter_variable="household_net_income",
            quantile=10,
            quantile_eq=decile_num,
        )
        if decile_num == 1:
            print(f"    First Aggregate created ({time.time() - pre_create:.2f}s)")
        pre_run = time.time()
        agg.run()
        if decile_num == 1:
            print(f"    First Aggregate.run() complete ({time.time() - pre_run:.2f}s)")
        market_incomes.append(agg.result / 1e9)

        # Tax
        agg = Aggregate(
            simulation=simulation,
            variable="household_tax",
            aggregate_type=AggregateType.SUM,
            filter_variable="household_net_income",
            quantile=10,
            quantile_eq=decile_num,
        )
        agg.run()
        taxes.append(agg.result / 1e9)

        # Total benefits
        agg = Aggregate(
            simulation=simulation,
            variable="household_benefits",
            aggregate_type=AggregateType.SUM,
            filter_variable="household_net_income",
            quantile=10,
            quantile_eq=decile_num,
        )
        agg.run()
        benefits.append(agg.result / 1e9)

        # Net income
        agg = Aggregate(
            simulation=simulation,
            variable="household_net_income",
            aggregate_type=AggregateType.SUM,
            filter_variable="household_net_income",
            quantile=10,
            quantile_eq=decile_num,
        )
        agg.run()
        net_incomes.append(agg.result / 1e9)

        # Household count: sum of household_weight, reported in millions.
        agg = Aggregate(
            simulation=simulation,
            variable="household_weight",
            aggregate_type=AggregateType.SUM,
            filter_variable="household_net_income",
            quantile=10,
            quantile_eq=decile_num,
        )
        agg.run()
        counts.append(agg.result / 1e6)

        print(f"  D{decile_num} complete ({time.time() - decile_start:.2f}s)")

    print(f"Main statistics complete ({time.time() - main_stats_start:.2f}s)")

    # Calculate individual benefit programs by decile
    # Maps programme variable name -> list of ten per-decile totals.
    benefit_programs_by_decile = {}

    # Person-level benefits (mapped to household for decile filtering)
    # entity="household" requests household-level totals so the household
    # net-income decile filter applies.
    print("Calculating person-level benefit programs...")
    person_benefits_start = time.time()
    # True until the very first programme aggregate has been timed.
    first_prog = True
    for prog in [
        "ssi",
        "social_security",
        "medicaid",
        "unemployment_compensation",
    ]:
        prog_start = time.time()
        prog_by_decile = []
        for decile_num in range(1, 11):
            # pre_create / pre_run are only bound (and only read) under this
            # same first-aggregate condition.
            if first_prog and decile_num == 1:
                pre_create = time.time()
            # debug_timing is switched on for the first aggregate only;
            # presumably it enables Aggregate's internal timing output —
            # confirm against the Aggregate API.
            agg = Aggregate(
                simulation=simulation,
                variable=prog,
                entity="household",
                aggregate_type=AggregateType.SUM,
                filter_variable="household_net_income",
                quantile=10,
                quantile_eq=decile_num,
                debug_timing=first_prog and decile_num == 1,
            )
            if first_prog and decile_num == 1:
                print(
                    f"    First benefit Aggregate created ({time.time() - pre_create:.2f}s)"
                )
                pre_run = time.time()
            agg.run()
            if first_prog and decile_num == 1:
                print(
                    f"    First benefit Aggregate.run() complete ({time.time() - pre_run:.2f}s)"
                )
                first_prog = False
            prog_by_decile.append(agg.result / 1e9)
        benefit_programs_by_decile[prog] = prog_by_decile
        print(f"  {prog} complete ({time.time() - prog_start:.2f}s)")

    print(
        f"Person-level benefits complete ({time.time() - person_benefits_start:.2f}s)"
    )

    # SPM unit benefits (mapped to household for decile filtering)
    print("Calculating SPM unit benefit programs...")
    spm_benefits_start = time.time()
    for prog in ["snap", "tanf"]:
        prog_start = time.time()
        prog_by_decile = []
        for decile_num in range(1, 11):
            agg = Aggregate(
                simulation=simulation,
                variable=prog,
                entity="household",
                aggregate_type=AggregateType.SUM,
                filter_variable="household_net_income",
                quantile=10,
                quantile_eq=decile_num,
            )
            agg.run()
            prog_by_decile.append(agg.result / 1e9)
        benefit_programs_by_decile[prog] = prog_by_decile
        print(f"  {prog} complete ({time.time() - prog_start:.2f}s)")

    print(f"SPM benefits complete ({time.time() - spm_benefits_start:.2f}s)")

    # Tax unit benefits (mapped to household for decile filtering)
    print("Calculating tax unit benefit programs...")
    tax_benefits_start = time.time()
    for prog in ["eitc", "ctc"]:
        prog_start = time.time()
        prog_by_decile = []
        for decile_num in range(1, 11):
            agg = Aggregate(
                simulation=simulation,
                variable=prog,
                entity="household",
                aggregate_type=AggregateType.SUM,
                filter_variable="household_net_income",
                quantile=10,
                quantile_eq=decile_num,
            )
            agg.run()
            prog_by_decile.append(agg.result / 1e9)
        benefit_programs_by_decile[prog] = prog_by_decile
        print(f"  {prog} complete ({time.time() - prog_start:.2f}s)")

    print(f"Tax benefits complete ({time.time() - tax_benefits_start:.2f}s)")
    print(f"\nTotal statistics calculation time: {time.time() - start_time:.2f}s")

    return {
        "deciles": deciles,
        "market_incomes": market_incomes,
        "taxes": taxes,
        "benefits": benefits,
        "net_incomes": net_incomes,
        "counts": counts,
        "benefit_programs_by_decile": benefit_programs_by_decile,
    }


def visualise_results(results: dict) -> None:
    """Render a 2x2 overview of the US income distribution by decile.

    Layout:
      - Top row: market income and tax by decile ($bn).
      - Bottom left: benefit programmes stacked by decile ($bn). This is the
        only panel with a legend. Programmes absent from
        ``results["benefit_programs_by_decile"]`` are skipped.
      - Bottom right: households by decile (millions).
    """
    deciles = results["deciles"]

    fig = make_subplots(
        rows=2,
        cols=2,
        subplot_titles=(
            "Market income by decile ($bn)",
            "Tax by decile ($bn)",
            "Benefits by program and decile ($bn)",
            "Households by decile (millions)",
        ),
        specs=[[{"type": "bar"} for _ in range(2)] for _ in range(2)],
    )

    def add_single_series(values, color, label, row, col):
        # One unstacked, legend-less bar series in the given subplot cell.
        fig.add_trace(
            go.Bar(
                x=deciles,
                y=values,
                marker_color=color,
                name=label,
                showlegend=False,
            ),
            row=row,
            col=col,
        )

    add_single_series(results["market_incomes"], COLORS["primary"], "Market income", 1, 1)
    add_single_series(results["taxes"], COLORS["error"], "Tax", 1, 2)

    # (legend label, results key, colour); list order sets stack and legend order.
    stacked_programs = [
        ("Social Security", "social_security", "#026AA2"),
        ("Medicaid", "medicaid", "#319795"),
        ("SNAP", "snap", "#22C55E"),
        ("EITC", "eitc", "#FEC601"),
        ("CTC", "ctc", "#1890FF"),
        ("SSI", "ssi", "#EF4444"),
        ("TANF", "tanf", "#667085"),
        ("Unemployment", "unemployment_compensation", "#101828"),
    ]
    for label, key, color in stacked_programs:
        if key not in results["benefit_programs_by_decile"]:
            continue
        fig.add_trace(
            go.Bar(
                x=deciles,
                y=results["benefit_programs_by_decile"][key],
                name=label,
                marker_color=color,
                legendgroup="benefits",
                showlegend=True,
            ),
            row=2,
            col=1,
        )

    add_single_series(results["counts"], COLORS["info"], "Households", 2, 2)

    for row in (1, 2):
        for col in (1, 2):
            fig.update_xaxes(title_text="Income decile", row=row, col=col)

    # Apply PolicyEngine house style, then place the legend beside the
    # stacked benefits panel.
    format_fig(
        fig,
        title="US household income distribution (Enhanced CPS 2024)",
        show_legend=True,
        height=800,
        width=1400,
    )
    fig.update_layout(
        barmode="stack",
        legend=dict(
            orientation="v",
            yanchor="top",
            y=0.45,
            xanchor="left",
            x=0.52,
            bgcolor="white",
            bordercolor="#E5E7EB",
            borderwidth=1,
        ),
    )

    fig.show()


def main():
    """Run the Enhanced CPS income-distribution example end to end.

    Loads the representative 2024 dataset, runs a microsimulation, computes
    per-decile statistics, prints headline totals plus a per-decile breakdown
    for every benefit programme present in the results, and finally hands the
    results to ``visualise_results`` for plotting.
    """
    print("Loading enhanced CPS representative household data...")
    dataset = load_representative_data(year=2024)

    person_count = len(dataset.data.person)
    household_count = len(dataset.data.household)
    print(f"Dataset loaded: {person_count:,} people, {household_count:,} households")

    print("Running microsimulation...")
    simulation = run_simulation(dataset)

    print("Calculating statistics by income decile...")
    results = calculate_income_decile_statistics(simulation)

    print("\nResults summary:")
    # Summed in the same order as the labels printed below.
    market, tax, benefits, net, households = (
        sum(results[column])
        for column in ("market_incomes", "taxes", "benefits", "net_incomes", "counts")
    )
    print(f"Total market income: ${market:.1f}bn")
    print(f"Total tax: ${tax:.1f}bn")
    print(f"Total benefits: ${benefits:.1f}bn")
    print(f"Total net income: ${net:.1f}bn")
    print(f"Total households: {households:.1f}m")
    print(f"Average effective tax rate: {tax / market * 100:.1f}%")

    print("\nBenefit programs by decile:")
    # Result key -> display label; programmes missing from the results are skipped.
    programme_labels = {
        "social_security": "Social Security",
        "medicaid": "Medicaid",
        "snap": "SNAP",
        "eitc": "EITC",
        "ctc": "CTC",
        "ssi": "SSI",
        "tanf": "TANF",
        "unemployment_compensation": "Unemployment",
    }
    by_decile = results["benefit_programs_by_decile"]
    for key, label in programme_labels.items():
        if key not in by_decile:
            continue
        values = by_decile[key]
        print(f"\n  {label} (total: ${sum(values):.1f}bn):")
        for position, decile in enumerate(results["deciles"]):
            print(f"    {decile}: ${values[position]:.1f}bn")

    print("\nGenerating visualisations...")
    visualise_results(results)


if __name__ == "__main__":
    main()

UK employment income variation

Creating a custom dataset with varied employment income, running a single simulation, and visualising benefit phase-outs.

employment_income_variation_uk.py
"""Example: Vary employment income and plot HBAI household net income.

This script demonstrates:
1. Creating a custom dataset with a single household template
2. Varying employment income from £0 to £100k
3. Running a single simulation for all variations
4. Using Aggregate with filters to extract results by employment income
5. Visualising the relationship between employment income and net income

IMPORTANT NOTES FOR CUSTOM DATASETS:
- Always set would_claim_* variables to True, otherwise benefits won't be claimed
  even if the household is eligible (they default to random/False)
- Always set disability variables explicitly (is_disabled_for_benefits, uc_limited_capability_for_WRA)
  to prevent random UC spikes from LCWRA element (£5,241/year extra if randomly assigned)
- Must include join keys: person_benunit_id, person_household_id in person data
- Required household fields: region, council_tax, rent, tenure_type
- Person-level variables are mapped to household level using weights

Run: python examples/employment_income_variation_uk.py
"""

import tempfile
from pathlib import Path

import pandas as pd
import plotly.graph_objects as go
from microdf import MicroDataFrame

from policyengine.core import Simulation
from policyengine.outputs.aggregate import Aggregate, AggregateType
from policyengine.tax_benefit_models.uk import (
    PolicyEngineUKDataset,
    UKYearData,
    uk_latest,
)
from policyengine.utils.plotting import COLORS, format_fig


def create_dataset_with_varied_employment_income(
    employment_incomes: list[float], year: int = 2026
) -> PolicyEngineUKDataset:
    """Create a dataset with one household template, varied by employment income.

    Every household is a single adult (age 35) with two children (ages 8 and
    5). The three people form one benefit unit. The household lives in London
    and rents privately at £850/month. Only the adult's employment income
    differs between households; the children have none.

    Household ``i`` and benefit unit ``i`` both correspond to
    ``employment_incomes[i]``. All weights are 1.

    Args:
        employment_incomes: Annual employment income (£) for the adult in each
            household. One household is created per entry.
        year: Year stored on the dataset.

    Returns:
        A ``PolicyEngineUKDataset``. Its ``filepath`` points into a freshly
        created temporary directory, which this function does not remove.
    """
    n_households = len(employment_incomes)

    # Template household members as (age, is_adult). Only the adult receives
    # the varied employment income.
    template_members = ((35, True), (8, False), (5, False))

    # One (household index, age, employment income) row per person, grouped
    # by household: adult first, then the two children.
    person_rows = [
        (hh_idx, age, income if is_adult else 0.0)
        for hh_idx, income in enumerate(employment_incomes)
        for age, is_adult in template_members
    ]
    n_people = len(person_rows)
    person_unit_ids = [hh_idx for hh_idx, _, _ in person_rows]

    person_data = {
        "person_id": list(range(n_people)),
        # Join keys: a person's benefit unit and household both use the
        # household index.
        "person_benunit_id": list(person_unit_ids),
        "person_household_id": list(person_unit_ids),
        "age": [age for _, age, _ in person_rows],
        "employment_income": [income for _, _, income in person_rows],
        "person_weight": [1.0] * n_people,
        # Set the disability flags explicitly so the UC LCWRA element is never
        # assigned at random.
        "is_disabled_for_benefits": [False] * n_people,
        "uc_limited_capability_for_WRA": [False] * n_people,
    }

    # Create benunit data - one per household
    benunit_data = {
        "benunit_id": list(range(n_households)),
        "benunit_weight": [1.0] * n_households,
        # Would claim variables - MUST set to True or benefits won't be claimed!
        "would_claim_uc": [True] * n_households,
        "would_claim_WTC": [True] * n_households,
        "would_claim_CTC": [True] * n_households,
        "would_claim_IS": [True] * n_households,
        "would_claim_pc": [True] * n_households,
        "would_claim_child_benefit": [True] * n_households,
        "would_claim_housing_benefit": [True] * n_households,
    }

    # Create household data - one per employment income level
    median_annual_rent = 850 * 12  # £850/month = £10,200/year (median UK rent)
    household_data = {
        "household_id": list(range(n_households)),
        "household_weight": [1.0] * n_households,
        "region": ["LONDON"] * n_households,  # Required by policyengine-uk
        "council_tax": [0.0] * n_households,  # Simplified - no council tax
        "rent": [median_annual_rent] * n_households,  # Median UK rent
        "tenure_type": ["RENT_PRIVATELY"] * n_households,  # Required for uprating
    }

    # Wrap each entity table with its weight column.
    person_df = MicroDataFrame(pd.DataFrame(person_data), weights="person_weight")
    benunit_df = MicroDataFrame(pd.DataFrame(benunit_data), weights="benunit_weight")
    household_df = MicroDataFrame(
        pd.DataFrame(household_data), weights="household_weight"
    )

    # The dataset needs a file path. mkdtemp creates a fresh directory that is
    # intentionally left in place, because the dataset refers to it later.
    tmpdir = tempfile.mkdtemp()
    filepath = str(Path(tmpdir) / "employment_income_variation.h5")

    return PolicyEngineUKDataset(
        name="Employment income variation",
        description="Single adult household with varying employment income",
        filepath=filepath,
        year=year,
        data=UKYearData(
            person=person_df,
            benunit=benunit_df,
            household=household_df,
        ),
    )


def run_simulation(dataset: PolicyEngineUKDataset) -> Simulation:
    """Build one UK simulation over every household in *dataset* and run it.

    All employment-income variants live in the same dataset, so a single
    ``run()`` covers them all. Returns the simulation after it has run.
    """
    sim = Simulation(dataset=dataset, tax_benefit_model_version=uk_latest)
    sim.run()
    return sim


def _filtered_mean(simulation: Simulation, variable: str, entity: str, entity_id: int):
    """Return the MEAN Aggregate of *variable* over *entity* rows whose ``<entity>_id`` equals *entity_id*."""
    agg = Aggregate(
        simulation=simulation,
        variable=variable,
        aggregate_type=AggregateType.MEAN,
        filter_variable=f"{entity}_id",
        filter_variable_eq=entity_id,
        entity=entity,
    )
    agg.run()
    return agg.result


def extract_results_by_employment_income(
    simulation: Simulation, employment_incomes: list[float]
) -> dict:
    """Extract HBAI household net income and components for each employment income level.

    Uses Aggregate with filters to extract specific households. Household
    ``i`` (and benefit unit ``i``) is assumed to be the one built for
    ``employment_incomes[i]``, as ``create_dataset_with_varied_employment_income``
    does.

    Args:
        simulation: A simulation that has already been run.
        employment_incomes: The income levels used to build the dataset, in
            household-ID order.

    Returns:
        A dict with one list per metric, each aligned with
        ``employment_incomes``:
        ``employment_income``, ``hbai_household_net_income``,
        ``household_benefits``, ``household_tax``, ``employment_income_hh``,
        and one entry per benefit-unit programme (``universal_credit``,
        ``child_benefit``, ``working_tax_credit``, ``child_tax_credit``,
        ``pension_credit``, ``income_support``).
    """
    benunit_benefits = (
        "universal_credit",
        "child_benefit",
        "working_tax_credit",
        "child_tax_credit",
        "pension_credit",
        "income_support",
    )

    hbai_net_income = []
    household_benefits = []
    household_tax = []
    employment_income_hh = []
    per_benefit = {name: [] for name in benunit_benefits}

    # NOTE(review): each Aggregate presumably filters the full output table,
    # so this issues 9 queries per household. That is fine for the ~50 rows
    # used here, but reading the output tables directly would scale better.
    for hh_idx, emp_income in enumerate(employment_incomes):
        hbai_net_income.append(
            _filtered_mean(simulation, "hbai_household_net_income", "household", hh_idx)
        )
        household_benefits.append(
            _filtered_mean(simulation, "household_benefits", "household", hh_idx)
        )

        # Benefit units map 1:1 onto households in this dataset, so benefit
        # unit hh_idx belongs to household hh_idx.
        for name in benunit_benefits:
            per_benefit[name].append(
                _filtered_mean(simulation, name, "benunit", hh_idx)
            )

        household_tax.append(
            _filtered_mean(simulation, "household_tax", "household", hh_idx)
        )

        # One household per input income, so household-level employment
        # income is simply the input value.
        employment_income_hh.append(emp_income)

    return {
        "employment_income": employment_incomes,
        "hbai_household_net_income": hbai_net_income,
        "household_benefits": household_benefits,
        "household_tax": household_tax,
        "employment_income_hh": employment_income_hh,
        **per_benefit,
    }


def visualise_results(results: dict) -> None:
    """Plot how UK household net income is built up as employment income rises.

    The chart is a stacked area chart: one ``go.Scatter`` trace per
    component, all sharing ``stackgroup="one"``, with employment income on
    the x-axis. The bottom layer is employment income minus
    ``household_tax``; each layer above it is one benefit's amount. The
    figure is styled with ``format_fig`` and then shown.
    """
    fig = go.Figure()

    earned = results["employment_income_hh"]
    taxes = results["household_tax"]
    net_employment = [gross - tax for gross, tax in zip(earned, taxes)]

    # Layer label -> (y values, colour). The dict keeps insertion order, so
    # layers stack bottom-up in this order.
    layers = {
        "Net employment income": (net_employment, COLORS["primary"]),
        "Universal Credit": (results["universal_credit"], COLORS["blue_secondary"]),
        "Working Tax Credit": (results["working_tax_credit"], COLORS["info"]),
        "Child Tax Credit": (results["child_tax_credit"], COLORS["success"]),
        "Child Benefit": (results["child_benefit"], COLORS["warning"]),
        "Pension Credit": (results["pension_credit"], COLORS["gray"]),
        "Income Support": (results["income_support"], COLORS["gray_dark"]),
    }

    x_values = results["employment_income"]
    for label, (y_values, colour) in layers.items():
        trace = go.Scatter(
            x=x_values,
            y=y_values,
            name=label,
            mode="lines",
            line={"width": 0.5, "color": colour},
            stackgroup="one",
            fillcolor=colour,
        )
        fig.add_trace(trace)

    format_fig(
        fig,
        title="Household net income composition by employment income",
        xaxis_title="Employment income (£)",
        yaxis_title="Net income (£)",
        show_legend=True,
        height=700,
        width=1200,
    )
    fig.show()


def main():
    """Sweep UK employment income from £0 to £100k and chart the result.

    Builds one household per income level, with finer steps at lower incomes
    where the relationship is more interesting. Runs a single simulation over
    all of them, prints HBAI net income at a few sample incomes and shows the
    stacked income-composition chart.
    """
    # (start, stop, step) bands: £1k steps below £20k, £2.5k steps to £50k,
    # £5k steps up to and including £100k.
    income_bands = (
        (0, 20000, 1000),
        (20000, 50000, 2500),
        (50000, 100001, 5000),
    )
    employment_incomes = [
        income
        for start, stop, step in income_bands
        for income in range(start, stop, step)
    ]

    print(
        f"Creating dataset with {len(employment_incomes)} employment income variations..."
    )
    dataset = create_dataset_with_varied_employment_income(employment_incomes)

    print("Running simulation (single run for all variations)...")
    simulation = run_simulation(dataset)

    print("Extracting results using aggregate filters...")
    results = extract_results_by_employment_income(simulation, employment_incomes)

    print("\nSample results:")
    for emp_inc in (0, 25000, 50000, 100000):
        if emp_inc not in employment_incomes:
            continue
        idx = employment_incomes.index(emp_inc)
        print(
            f"  Employment income £{emp_inc:,}: HBAI net income £{results['hbai_household_net_income'][idx]:,.0f}"
        )

    print("\nGenerating visualisation...")
    visualise_results(results)


if __name__ == "__main__":
    main()

US employment income variation

Same approach as the UK version, varying employment income from $0 to $200k and plotting household net income.

employment_income_variation_us.py
"""Example: Vary employment income and plot household net income (US).

This script demonstrates:
1. Creating a custom dataset with a single household template
2. Varying employment income from $0 to $200k
3. Running a single simulation for all variations
4. Reading results by employment income directly from the simulation's output tables
5. Visualising the relationship between employment income and net income

Run: python examples/employment_income_variation_us.py
"""

import tempfile
from pathlib import Path

import pandas as pd
import plotly.graph_objects as go
from microdf import MicroDataFrame

from policyengine.core import Simulation
from policyengine.tax_benefit_models.us import (
    PolicyEngineUSDataset,
    USYearData,
    us_latest,
)
from policyengine.utils.plotting import COLORS, format_fig


def create_dataset_with_varied_employment_income(
    employment_incomes: list[float], year: int = 2024
) -> PolicyEngineUSDataset:
    """Create a dataset with one household template, varied by employment income.

    Every household is a single adult (age 35) with two children (ages 8 and
    5), located in California. Only the adult's employment income differs
    between households; the children have none.

    Each household forms exactly one marital unit, family, SPM unit and tax
    unit. All of these use ID ``i`` for ``employment_incomes[i]``. Every
    weight is 1000.

    Args:
        employment_incomes: Annual employment income ($) for the adult in each
            household. One household is created per entry.
        year: Year stored on the dataset.

    Returns:
        A ``PolicyEngineUSDataset``. Its ``filepath`` points into a freshly
        created temporary directory, which this function does not remove.
    """
    n_households = len(employment_incomes)

    # Template household members as (age, is_adult). Only the adult receives
    # the varied employment income.
    template_members = ((35, True), (8, False), (5, False))

    # One (household index, age, employment income) row per person, grouped
    # by household: adult first, then the two children.
    person_rows = [
        (hh_idx, age, income if is_adult else 0.0)
        for hh_idx, income in enumerate(employment_incomes)
        for age, is_adult in template_members
    ]
    n_people = len(person_rows)
    person_unit_ids = [hh_idx for hh_idx, _, _ in person_rows]

    person_data = {"person_id": list(range(n_people))}
    # Every group entity maps 1:1 onto the household, so each person's group
    # IDs are all the household index.
    # NOTE(review): the children share the adult's marital unit here.
    # policyengine-us may expect unmarried dependants to have their own
    # marital units; confirm this does not affect the results.
    for id_column in (
        "household_id",
        "marital_unit_id",
        "family_id",
        "spm_unit_id",
        "tax_unit_id",
    ):
        person_data[id_column] = list(person_unit_ids)
    person_data["age"] = [age for _, age, _ in person_rows]
    person_data["employment_income"] = [income for _, _, income in person_rows]
    person_data["person_weight"] = [1000.0] * n_people

    household_data = {
        "household_id": list(range(n_households)),
        "state_name": ["CA"] * n_households,  # California
        "household_weight": [1000.0] * n_households,
    }

    person_df = MicroDataFrame(pd.DataFrame(person_data), weights="person_weight")
    household_df = MicroDataFrame(
        pd.DataFrame(household_data), weights="household_weight"
    )

    # The remaining group entities share one shape: an ID column and a
    # weight column, one row per household.
    group_frames = {
        entity: MicroDataFrame(
            pd.DataFrame(
                {
                    f"{entity}_id": list(range(n_households)),
                    f"{entity}_weight": [1000.0] * n_households,
                }
            ),
            weights=f"{entity}_weight",
        )
        for entity in ("marital_unit", "family", "spm_unit", "tax_unit")
    }

    # The dataset needs a file path. mkdtemp creates a fresh directory that is
    # intentionally left in place, because the dataset refers to it later.
    tmpdir = tempfile.mkdtemp()
    filepath = str(Path(tmpdir) / "employment_income_variation_us.h5")

    return PolicyEngineUSDataset(
        name="Employment income variation (US)",
        description="Single adult household with 2 children, varying employment income",
        filepath=filepath,
        year=year,
        data=USYearData(
            person=person_df,
            household=household_df,
            **group_frames,
        ),
    )


def run_simulation(dataset: PolicyEngineUSDataset) -> Simulation:
    """Build one US simulation over every household in *dataset* and run it.

    All employment-income variants live in the same dataset, so a single
    ``run()`` covers them all. Returns the simulation after it has run.
    """
    sim = Simulation(dataset=dataset, tax_benefit_model_version=us_latest)
    sim.run()
    return sim


def extract_results_by_employment_income(
    simulation: Simulation, employment_incomes: list[float]
) -> dict:
    """Extract household net income and components for each employment income level.

    Reads the simulation's output tables directly by row index. This works
    because the dataset has exactly one household, SPM unit and tax unit per
    income level. It assumes the output tables keep the input row order, so
    row ``i`` corresponds to ``employment_incomes[i]``.

    Args:
        simulation: A simulation that has already been run.
        employment_incomes: The income levels used to build the dataset, in
            household order.

    Returns:
        A dict of lists aligned with ``employment_incomes``:
        ``employment_income`` and ``employment_income_hh`` (both the input
        list itself), ``household_net_income``, ``household_benefits``,
        ``household_tax``, ``snap``, ``tanf``, ``eitc`` and ``ctc``.

    Raises:
        ValueError: If the number of output households differs from the number
            of income levels. In that case the row-index mapping cannot be
            trusted.
    """
    output = simulation.output_dataset.data
    household_df = pd.DataFrame(output.household)
    spm_unit_df = pd.DataFrame(output.spm_unit)
    tax_unit_df = pd.DataFrame(output.tax_unit)

    # Fail loudly rather than silently plotting misaligned rows.
    if len(household_df) != len(employment_incomes):
        raise ValueError(
            f"Expected {len(employment_incomes)} households in the simulation "
            f"output, got {len(household_df)}"
        )

    return {
        "employment_income": employment_incomes,
        "household_net_income": household_df["household_net_income"].tolist(),
        "household_benefits": household_df["household_benefits"].tolist(),
        "household_tax": household_df["household_tax"].tolist(),
        # One household per input income, so household-level employment
        # income is the input list itself.
        "employment_income_hh": employment_incomes,
        "snap": spm_unit_df["snap"].tolist(),
        "tanf": spm_unit_df["tanf"].tolist(),
        "eitc": tax_unit_df["eitc"].tolist(),
        "ctc": tax_unit_df["ctc"].tolist(),
    }


def visualise_results(results: dict) -> None:
    """Plot how US household net income is built up as employment income rises.

    The chart is a stacked area chart: one ``go.Scatter`` trace per
    component, all sharing ``stackgroup="one"``, with employment income on
    the x-axis. The bottom layer is employment income minus
    ``household_tax``; the layers above it are SNAP, TANF, EITC and CTC. The
    figure is styled with ``format_fig`` and then shown.
    """
    fig = go.Figure()

    earned = results["employment_income_hh"]
    taxes = results["household_tax"]
    net_employment = [gross - tax for gross, tax in zip(earned, taxes)]

    # NOTE(review): if household_tax already nets off refundable credits,
    # stacking EITC and CTC on top of net employment income may double-count
    # them. Confirm against the model's variable definitions.
    # Layer label -> (y values, colour). The dict keeps insertion order, so
    # layers stack bottom-up in this order.
    layers = {
        "Net employment income": (net_employment, COLORS["primary"]),
        "SNAP": (results["snap"], COLORS["blue_secondary"]),
        "TANF": (results["tanf"], COLORS["info"]),
        "EITC": (results["eitc"], COLORS["success"]),
        "CTC": (results["ctc"], COLORS["warning"]),
    }

    x_values = results["employment_income"]
    for label, (y_values, colour) in layers.items():
        trace = go.Scatter(
            x=x_values,
            y=y_values,
            name=label,
            mode="lines",
            line={"width": 0.5, "color": colour},
            stackgroup="one",
            fillcolor=colour,
        )
        fig.add_trace(trace)

    format_fig(
        fig,
        title="Household net income composition by employment income",
        xaxis_title="Employment income ($)",
        yaxis_title="Net income ($)",
        show_legend=True,
        height=700,
        width=1200,
    )
    fig.show()


def main():
    """Sweep US employment income from $0 to $200k and chart the result.

    Builds one household per income level, with finer steps at lower incomes
    where the relationship is more interesting. Runs a single simulation over
    all of them, prints household net income at a few sample incomes and
    shows the stacked income-composition chart.
    """
    employment_incomes = (
        list(range(0, 40000, 2000))  # $0 to $40k in $2k steps
        + list(range(40000, 100000, 5000))  # $40k to $100k in $5k steps
        + list(range(100000, 200001, 10000))  # $100k to $200k in $10k steps
    )

    print(
        f"Creating dataset with {len(employment_incomes)} employment income variations..."
    )
    dataset = create_dataset_with_varied_employment_income(employment_incomes)

    print("Running simulation (single run for all variations)...")
    simulation = run_simulation(dataset)

    # The US extraction reads the output tables directly (no Aggregate), so
    # the progress message says so.
    print("Extracting results from simulation output tables...")
    results = extract_results_by_employment_income(simulation, employment_incomes)

    print("\nSample results:")
    for emp_inc in (0, 50000, 100000, 200000):
        if emp_inc not in employment_incomes:
            continue
        idx = employment_incomes.index(emp_inc)
        print(
            f"  Employment income ${emp_inc:,}: household net income ${results['household_net_income'][idx]:,.0f}"
        )

    print("\nGenerating visualisation...")
    visualise_results(results)


if __name__ == "__main__":
    main()

Household impact calculation

Using calculate_household_impact() to compute taxes and benefits for individual custom households (both UK and US).

household_impact_example.py
"""Example: Calculate household tax and benefit impacts.

This script demonstrates using calculate_household_impact for both UK and US
to compute taxes and benefits for custom households.

Run: python examples/household_impact_example.py
"""

from policyengine.tax_benefit_models.uk import (
    UKHouseholdInput,
)
from policyengine.tax_benefit_models.uk import (
    calculate_household_impact as calculate_uk_impact,
)
from policyengine.tax_benefit_models.us import (
    USHouseholdInput,
)
from policyengine.tax_benefit_models.us import (
    calculate_household_impact as calculate_us_impact,
)


def uk_example():
    """Print tax and benefit results for two illustrative UK households (2026).

    1. A single adult earning £50,000.
    2. A couple with two children (ages 8 and 5), with £30,000 of earnings,
       renting at £12,000/year in the North West, and claiming UC and Child
       Benefit.
    """
    rule = "=" * 60
    print(rule)
    print("UK HOUSEHOLD IMPACT")
    print(rule)

    # Case 1: single earner, no benefit-claim flags set.
    single_earner = UKHouseholdInput(
        people=[{"age": 35, "employment_income": 50_000}],
        year=2026,
    )
    single = calculate_uk_impact(single_earner)

    print("\nSingle adult, £50k income:")
    print(f"  Net income: £{single.household['hbai_household_net_income']:,.0f}")
    print(f"  Income tax: £{single.person[0]['income_tax']:,.0f}")
    print(f"  National Insurance: £{single.person[0]['national_insurance']:,.0f}")
    print(f"  Total tax: £{single.household['household_tax']:,.0f}")

    # Case 2: one earner, one non-earning adult and two children, renting.
    # The claim flags are set so UC and Child Benefit are actually taken up.
    renting_family = UKHouseholdInput(
        people=[
            {"age": 35, "employment_income": 30_000},
            {"age": 33},
            {"age": 8},
            {"age": 5},
        ],
        benunit={
            "would_claim_uc": True,
            "would_claim_child_benefit": True,
        },
        household={
            "rent": 12_000,  # £1k/month
            "region": "NORTH_WEST",
        },
        year=2026,
    )
    family = calculate_uk_impact(renting_family)

    print("\nFamily (2 adults, 2 children), £30k income, renting:")
    print(f"  Net income: £{family.household['hbai_household_net_income']:,.0f}")
    print(f"  Income tax: £{family.person[0]['income_tax']:,.0f}")
    print(f"  Child benefit: £{family.benunit[0]['child_benefit']:,.0f}")
    print(f"  Universal credit: £{family.benunit[0]['universal_credit']:,.0f}")
    print(f"  Total benefits: £{family.household['household_benefits']:,.0f}")


def us_example():
    """Print tax and benefit results for two illustrative US households (2024).

    1. A single filer in California earning $50,000.
    2. A married couple filing jointly in Texas with two dependent children
       (ages 8 and 5) and $40,000 of earnings.
    """
    rule = "=" * 60
    print("\n" + rule)
    print("US HOUSEHOLD IMPACT")
    print(rule)

    # Case 1: single filer, California.
    single_filer = USHouseholdInput(
        people=[{"age": 35, "employment_income": 50_000, "is_tax_unit_head": True}],
        tax_unit={"filing_status": "SINGLE"},
        household={"state_code_str": "CA"},
        year=2024,
    )
    single = calculate_us_impact(single_filer)

    print("\nSingle adult, $50k income (California):")
    print(f"  Net income: ${single.household['household_net_income']:,.0f}")
    print(f"  Income tax: ${single.tax_unit[0]['income_tax']:,.0f}")
    print(f"  Payroll tax: ${single.tax_unit[0]['employee_payroll_tax']:,.0f}")

    # Case 2: joint filers with two dependants, Texas.
    joint_family = USHouseholdInput(
        people=[
            {"age": 35, "employment_income": 40_000, "is_tax_unit_head": True},
            {"age": 33, "is_tax_unit_spouse": True},
            {"age": 8, "is_tax_unit_dependent": True},
            {"age": 5, "is_tax_unit_dependent": True},
        ],
        tax_unit={"filing_status": "JOINT"},
        household={"state_code_str": "TX"},
        year=2024,
    )
    family = calculate_us_impact(joint_family)

    print("\nMarried couple with 2 children, $40k income (Texas):")
    print(f"  Net income: ${family.household['household_net_income']:,.0f}")
    print(f"  Federal income tax: ${family.tax_unit[0]['income_tax']:,.0f}")
    print(f"  EITC: ${family.tax_unit[0]['eitc']:,.0f}")
    print(f"  Child tax credit: ${family.tax_unit[0]['ctc']:,.0f}")
    print(f"  SNAP: ${family.spm_unit[0]['snap']:,.0f}")


def main():
    """Run the UK example, then the US example, then print a closing banner."""
    for run_example in (uk_example, us_example):
        run_example()
    print("\n" + "=" * 60)
    print("Done!")


if __name__ == "__main__":
    main()

Simulation performance

Benchmarking how simulation.run() scales with dataset size.

speedtest_us_simulation.py
"""Speedtest: US simulation performance with different dataset sizes.

This script tests how simulation.run() performance scales with dataset size
by running simulations on random subsets of households.
"""

import time
from pathlib import Path

import pandas as pd
from microdf import MicroDataFrame

from policyengine.core import Simulation
from policyengine.tax_benefit_models.us import (
    PolicyEngineUSDataset,
    USYearData,
    us_latest,
)


def create_subset_dataset(
    original_dataset: PolicyEngineUSDataset, n_households: int
) -> PolicyEngineUSDataset:
    """Return a reindexed random sample of ``n_households`` households.

    Households are drawn with ``random_state=n_households``, so each subset
    size gets its own reproducible sample. Only the people living in the drawn
    households are kept, together with the marital units, families, SPM units
    and tax units those people belong to. Every entity ID (including
    ``person_id``) is then renumbered to ``0..k-1`` in ascending order of the
    original IDs, and each table is sorted by its new ID. Weight columns are
    carried over unchanged.

    Args:
        original_dataset: Loaded dataset to sample from.
        n_households: Number of households to draw.

    Returns:
        A new ``PolicyEngineUSDataset`` holding the subset.
    """
    source = original_dataset.data
    group_entities = ("marital_unit", "family", "spm_unit", "tax_unit")

    def contiguous_ids(ids):
        # Old ID -> its position in sorted order (0, 1, 2, ...).
        return {old: new for new, old in enumerate(sorted(ids))}

    def by_id(frame, id_column):
        return frame.sort_values(id_column).reset_index(drop=True)

    people = pd.DataFrame(source.person).copy()

    # The person table links to each entity through either
    # ``person_<entity>_id`` or plain ``<entity>_id``; use whichever exists.
    link_columns = {}
    for entity in ("household",) + group_entities:
        prefixed = f"person_{entity}_id"
        link_columns[entity] = (
            prefixed if prefixed in people.columns else f"{entity}_id"
        )

    # Seeding with the size gives a different (but repeatable) draw per size.
    households = (
        pd.DataFrame(source.household)
        .copy()
        .sample(n=n_households, random_state=n_households)
        .copy()
    )
    kept_household_ids = set(households["household_id"])
    people = people[people[link_columns["household"]].isin(kept_household_ids)].copy()

    # Keep only group units that still have at least one sampled member.
    id_maps = {"household": contiguous_ids(kept_household_ids)}
    group_tables = {}
    for entity in group_entities:
        member_ids = set(people[link_columns[entity]].unique())
        full_table = pd.DataFrame(getattr(source, entity)).copy()
        group_tables[entity] = full_table[
            full_table[f"{entity}_id"].isin(member_ids)
        ].copy()
        id_maps[entity] = contiguous_ids(member_ids)
    person_id_map = contiguous_ids(people["person_id"])

    # Rewrite every primary key and every person-table foreign key.
    households["household_id"] = households["household_id"].map(id_maps["household"])
    people["person_id"] = people["person_id"].map(person_id_map)
    for entity, column in link_columns.items():
        people[column] = people[column].map(id_maps[entity])
    for entity, table in group_tables.items():
        key = f"{entity}_id"
        table[key] = table[key].map(id_maps[entity])

    households = by_id(households, "household_id")
    people = by_id(people, "person_id")
    group_tables = {
        entity: by_id(table, f"{entity}_id") for entity, table in group_tables.items()
    }

    return PolicyEngineUSDataset(
        name=f"Subset {n_households} households",
        description=f"Random subset of {n_households} households",
        filepath=f"./data/subset_{n_households}_households.h5",
        year=original_dataset.year,
        data=USYearData(
            person=MicroDataFrame(people, weights="person_weight"),
            household=MicroDataFrame(households, weights="household_weight"),
            marital_unit=MicroDataFrame(
                group_tables["marital_unit"], weights="marital_unit_weight"
            ),
            family=MicroDataFrame(group_tables["family"], weights="family_weight"),
            spm_unit=MicroDataFrame(
                group_tables["spm_unit"], weights="spm_unit_weight"
            ),
            tax_unit=MicroDataFrame(
                group_tables["tax_unit"], weights="tax_unit_weight"
            ),
        ),
    )


def speedtest_simulation(dataset: PolicyEngineUSDataset) -> float:
    """Simulate ``dataset`` with ``us_latest`` and return the run time.

    Only ``simulation.run()`` is timed; constructing the ``Simulation`` object
    is excluded from the measurement.

    Args:
        dataset: Dataset to simulate (either the full dataset or a subset).

    Returns:
        Elapsed time of ``simulation.run()`` in seconds.
    """
    simulation = Simulation(
        dataset=dataset,
        tax_benefit_model_version=us_latest,
    )

    # perf_counter is monotonic and high-resolution, so the measured interval
    # cannot be distorted by system clock adjustments the way time.time() can.
    start_time = time.perf_counter()
    simulation.run()
    end_time = time.perf_counter()

    return end_time - start_time


def _print_results_table(results: list) -> None:
    """Print one row per benchmark run: households, people, duration, HH/sec."""
    print("\n" + "=" * 60)
    print("SPEEDTEST RESULTS")
    print("=" * 60)
    print(f"{'Households':<12} {'People':<10} {'Duration':<12} {'HH/sec':<10}")
    print("-" * 60)

    for result in results:
        print(
            f"{result['households']:<12,} {result['people']:<10,} "
            f"{result['duration_seconds']:<12.2f} {result['households_per_second']:<10.1f}"
        )


def _print_scaling_analysis(results: list) -> None:
    """Compare the first and last runs and classify how runtime scales."""
    print("\n" + "=" * 60)
    print("SCALING ANALYSIS")
    print("=" * 60)

    # A ratio needs at least two data points.
    if len(results) < 2:
        return

    first = results[0]
    last = results[-1]

    size_ratio = last["households"] / first["households"]
    time_ratio = last["duration_seconds"] / first["duration_seconds"]

    print(f"Dataset size increased {size_ratio:.1f}x")
    print(f"Simulation time increased {time_ratio:.1f}x")

    if time_ratio < size_ratio * 1.2:
        print("Scaling: approximately linear or better")
    elif time_ratio < size_ratio * 2:
        print("Scaling: slightly worse than linear")
    else:
        print("Scaling: significantly worse than linear")


def main():
    """Benchmark ``simulation.run()`` on growing subsets of the enhanced CPS.

    Loads the enhanced CPS 2024 dataset from ``data/`` next to this script.
    It times a simulation on random subsets of increasing size, ending with
    the full dataset. Finally it prints a results table and a simple scaling
    comparison between the smallest and largest runs.

    Raises:
        FileNotFoundError: If the dataset file has not been created yet.
    """
    print("Loading full enhanced CPS dataset...")
    dataset_path = Path(__file__).parent / "data" / "enhanced_cps_2024_year_2024.h5"

    if not dataset_path.exists():
        raise FileNotFoundError(
            f"Dataset not found at {dataset_path}. "
            "Run create_datasets() from policyengine.tax_benefit_models.us first."
        )

    full_dataset = PolicyEngineUSDataset(
        name="Enhanced CPS 2024",
        description="Full enhanced CPS dataset",
        filepath=str(dataset_path),
        year=2024,
    )
    full_dataset.load()

    total_households = len(full_dataset.data.household)
    print(f"Full dataset: {total_households:,} households")

    # Subset sizes strictly smaller than the dataset, always followed by the
    # full dataset itself. The last run is therefore a whole-dataset run
    # regardless of how many households this particular build contains.
    subset_sizes = [100, 500, 1000, 2500, 5000, 10000]
    test_sizes = [n for n in subset_sizes if n < total_households]
    if total_households > 0:
        test_sizes.append(total_households)

    results = []

    for n_households in test_sizes:
        print(f"\nTesting {n_households:,} households...")

        # The full-size run reuses the loaded dataset rather than resampling.
        if n_households == total_households:
            subset = full_dataset
        else:
            subset = create_subset_dataset(full_dataset, n_households)

        n_people = len(subset.data.person)
        print(f"  {n_people:,} people in subset")

        duration = speedtest_simulation(subset)
        print(f"  Simulation completed in {duration:.2f}s")

        results.append(
            {
                "households": n_households,
                "people": n_people,
                "duration_seconds": duration,
                "households_per_second": n_households / duration,
            }
        )

    _print_results_table(results)
    _print_scaling_analysis(results)


if __name__ == "__main__":
    main()