Skip to article frontmatter
Skip to article content
Site not loading correctly?

This may be due to an incorrect BASE_URL configuration. See the MyST Documentation for reference.

Benchmarking methods

This notebook shows how to benchmark different imputation methods with Microimpute, using two example data sets.

With sklearn’s Diabetes data set

For details on this data set, refer to the scikit-learn diabetes data set documentation.

from typing import List, Type

import pandas as pd

from microimpute.comparisons import *
from microimpute.config import RANDOM_STATE
from microimpute.models import *
from microimpute.visualizations import method_comparison_results
from microimpute.evaluations import *
from sklearn.datasets import load_diabetes
from sklearn.model_selection import train_test_split

import warnings

warnings.filterwarnings("ignore")

# --- Step 1: load the diabetes features and hold out 20% for evaluation ---
diabetes = load_diabetes()
df = pd.DataFrame(data=diabetes.data, columns=diabetes.feature_names)
X_train, X_test = train_test_split(
    df,
    test_size=0.2,
    random_state=RANDOM_STATE,
)

# Model inputs, and the numerical variables we want to impute.
predictors = ["age", "sex", "bmi", "bp"]
imputed_variables = ["s1", "s4"]

# Ground-truth values of the imputed variables for the held-out rows.
Y_test: pd.DataFrame = X_test[imputed_variables]

# --- Step 2: fit every candidate imputer and impute the test rows ---
model_classes: List[Type[Imputer]] = [QRF, OLS, QuantReg, Matching]
method_imputations = get_imputations(
    model_classes,
    X_train,
    X_test,
    predictors,
    imputed_variables,
)

# --- Step 3: score each method against the held-out values ---
loss_comparison_df = compare_metrics(Y_test, method_imputations, imputed_variables)

# --- Step 4: plot only the quantile-loss rows of the comparison table ---
is_quantile_loss = loss_comparison_df["Metric"] == "quantile_loss"
quantile_loss_df = loss_comparison_df[is_quantile_loss]
comparison_viz = method_comparison_results(
    data=quantile_loss_df,
    metric="quantile_loss",
    data_format="long",
)
fig = comparison_viz.plot(
    title="Method Comparison on Diabetes Dataset (Numerical Variables)",
)
fig.show()

# Average the "mean_quantile_loss" rows per method and list lowest loss first.
print("Summary of quantile loss by method:")
is_mean_row = quantile_loss_df["Imputed Variable"] == "mean_quantile_loss"
summary = quantile_loss_df[is_mean_row].groupby("Method")["Loss"].mean()
print(summary.sort_values())

# --- Predictor sensitivity, evaluated with QRF on the full data set ---
leave_one_out_results = leave_one_out_analysis(df, predictors, imputed_variables, QRF)
predictor_inclusion_results = progressive_predictor_inclusion(
    df, predictors, imputed_variables, QRF
)

print(f"Optimal subset: {predictor_inclusion_results['optimal_subset']}")
print(f"Optimal loss: {predictor_inclusion_results['optimal_loss']}")

# Step-by-step table: which predictor was added at each step and its loss.
print(predictor_inclusion_results["results_df"])
Loading...
Summary of quantile loss by method:
Method
OLS         0.012658
QuantReg    0.012670
QRF         0.016240
Matching    0.022570
Name: Loss, dtype: float64
Loading...
Loading...
Optimal subset: ['bp', 'age', 'bmi', 'sex']
Optimal loss: 0.016242569699135692
   step predictor_added  predictors_included  avg_quantile_loss  avg_log_loss  \
0     1              bp                 [bp]           0.017815             0   
1     2             age            [bp, age]           0.017225             0   
2     3             bmi       [bp, age, bmi]           0.016578             0   
3     4             sex  [bp, age, bmi, sex]           0.016243             0   

   cumulative_improvement  marginal_improvement  
0                0.000000              0.000000  
1                0.000590              0.000590  
2                0.001236              0.000647  
3                0.001572              0.000336  

With the US Federal Reserve Board’s Survey of Consumer Finances

For details on this data set, refer to the Federal Reserve's Survey of Consumer Finances documentation.

# On the SCF Dataset

from typing import List, Type, Optional, Union

import io
import logging
import pandas as pd
from pydantic import validate_call
import requests
import zipfile

from microimpute.comparisons import *
from microimpute.config import RANDOM_STATE, VALIDATE_CONFIG, VALID_YEARS
from microimpute.models import *
from microimpute.utils.data import preprocess_data
from microimpute.evaluations import *

import warnings

# Suppress all warnings for the rest of the session (presumably to keep the
# notebook output uncluttered).
warnings.filterwarnings("ignore")

# Logger used by load_scf to report download and parsing progress/errors.
logger = logging.getLogger(__name__)

# 1. Prepare data
@validate_call(config=VALIDATE_CONFIG)
def load_scf(
    years: Optional[Union[int, List[int]]] = None,
    columns: Optional[List[str]] = None,
) -> pd.DataFrame:
    """Load Survey of Consumer Finances summary microdata.

    For each requested year, downloads the Federal Reserve's summary extract
    zip, reads the first Stata (``.dta``) file inside it and adds a ``year``
    column. Several years are combined with ``pd.concat``, so each year keeps
    its own original row index.

    Args:
        years: A single year or a list of years. ``None`` loads every year in
            ``VALID_YEARS``.
        columns: Columns to read from the Stata file. ``None`` reads them all.

    Returns:
        DataFrame containing the requested columns plus a ``year`` column.

    Raises:
        ValueError: If a year is not in ``VALID_YEARS``, ``years`` is an empty
            list, or a downloaded zip contains no Stata file.
        requests.exceptions.RequestException: On a network or HTTP error while
            downloading.
        Exception: Errors raised by ``zipfile``/``pandas`` while reading the
            download are logged and re-raised unchanged.
    """

    def scf_url(year: int) -> str:
        """Return the URL of the SCF summary microdata zip file for a year."""
        if year not in VALID_YEARS:
            message = f"Invalid SCF year: {year}. Valid years are {VALID_YEARS}"
            logger.error(message)
            raise ValueError(message)
        return f"https://www.federalreserve.gov/econres/files/scfp{year}s.zip"

    def load_year(year: int) -> pd.DataFrame:
        """Download one year's zip and return its Stata data plus ``year``."""
        logger.debug(f"Downloading SCF data for year {year}")
        url = scf_url(year)
        try:
            response = requests.get(url, timeout=60)
            response.raise_for_status()  # Turn HTTP 4xx/5xx into an exception
        except requests.exceptions.RequestException as e:
            logger.error(
                f"Network error downloading SCF data for year {year}: {e}"
            )
            raise

        # ``with`` guarantees the ZipFile handle is closed even on errors.
        with zipfile.ZipFile(io.BytesIO(response.content)) as archive:
            dta_files: List[str] = [
                name for name in archive.namelist() if name.endswith(".dta")
            ]
            if not dta_files:
                message = f"No Stata files found in zip for year {year}"
                logger.error(message)
                raise ValueError(message)

            try:
                logger.debug(f"Reading Stata file: {dta_files[0]}")
                with archive.open(dta_files[0]) as f:
                    df = pd.read_stata(io.BytesIO(f.read()), columns=columns)
                logger.debug(f"Read DataFrame with shape {df.shape}")
            except Exception as e:
                logger.error(f"Error reading Stata file for year {year}: {e}")
                raise

        df["year"] = year
        return df

    logger.info(f"Loading SCF data with years={years}")

    # Normalise ``years`` into a non-empty list of ints.
    if years is None:
        years = VALID_YEARS
        logger.warning(f"Using default years: {years}")
    if isinstance(years, int):
        years = [years]
    if not years:
        message = "At least one SCF year must be requested"
        logger.error(message)
        raise ValueError(message)

    all_data: List[pd.DataFrame] = []
    for year in years:
        logger.info(f"Processing data for year {year}")
        try:
            df = load_year(year)
        except Exception as e:
            logger.error(f"Error processing year {year}: {e}")
            raise
        logger.info(
            f"Successfully processed data for year {year}, shape: {df.shape}"
        )
        all_data.append(df)

    logger.debug(f"Combining data from {len(all_data)} years")
    if len(all_data) == 1:
        logger.info(
            f"Returning data for single year, shape: {all_data[0].shape}"
        )
        return all_data[0]

    result = pd.concat(all_data)
    logger.info(
        f"Combined data from {len(years)} years, final shape: {result.shape}"
    )
    return result

# Imported here as well so this cell also runs on its own; otherwise the name
# only comes from the diabetes example's imports.
from microimpute.visualizations import method_comparison_results

scf_data = load_scf(2022)
PREDICTORS: List[str] = [
    "hhsex",  # sex of head of household
    "age",  # age of respondent
    "married",  # marital status of respondent
    # "kids" (number of children in household) is currently not used.
    "race",  # race of respondent
    "income",  # total annual income of household
    "wageinc",  # income from wages and salaries
    "bussefarminc",  # income from business, self-employment or farm
    "intdivinc",  # income from interest and dividends
    "ssretinc",  # income from social security and retirement accounts
    "lf",  # labor force status
]
IMPUTED_VARIABLES: List[str] = ["networth"]

# Evaluate predictors: report mutual information between each predictor and
# networth.
predictors_evaluation = compute_predictor_correlations(
    scf_data, PREDICTORS, IMPUTED_VARIABLES
)
print("\nMutual information with networth:")
print(predictors_evaluation["predictor_target_mi"])

X_train, X_test = preprocess_data(
    data=scf_data,
    full_data=False,
    normalize=False,
)

# Shrink down the data by keeping a 1% sample of each split.
X_train = X_train.sample(frac=0.01, random_state=RANDOM_STATE)
X_test = X_test.sample(frac=0.01, random_state=RANDOM_STATE)

Y_test: pd.DataFrame = X_test[IMPUTED_VARIABLES]

# 2. Run imputation methods
model_classes: List[Type[Imputer]] = [QRF, OLS, QuantReg, Matching]
method_imputations = get_imputations(
    model_classes, X_train, X_test, PREDICTORS, IMPUTED_VARIABLES
)

# 3. Compare imputation methods using unified metrics function
loss_comparison_df = compare_metrics(
    Y_test, method_imputations, IMPUTED_VARIABLES
)

# 4. Plot results - filter for quantile loss metrics only
quantile_loss_df = loss_comparison_df[
    loss_comparison_df["Metric"] == "quantile_loss"
]
comparison_viz = method_comparison_results(
    data=quantile_loss_df,
    metric="quantile_loss",
    data_format="long",
)
fig = comparison_viz.plot(
    title="Method Comparison on SCF Dataset",
    show_mean=True,
)
fig.show()

# Compare distribution similarity between the full SCF networth column and each
# method's imputations stored under key 0.5 (presumably the median quantile).
# The result is computed before printing: embedding this multi-line call, with
# its nested double quotes, inside a double-quoted f-string only parses on
# Python 3.12+ (PEP 701).
for model, imputations in method_imputations.items():
    similarity = compare_distributions(
        donor_data=pd.DataFrame(scf_data["networth"]),
        receiver_data=pd.DataFrame(imputations[0.5]["networth"]),
        imputed_variables=IMPUTED_VARIABLES,
    )
    print(f"Model: {model}, distribution similarity: \n{similarity}")

Mutual information with networth:
              networth
hhsex         0.014095
age           0.119553
married       0.016892
race          0.027996
income        0.156195
wageinc       0.121915
bussefarminc  0.080536
intdivinc     0.085286
ssretinc      0.059902
lf            0.010602
Loading...
Model: QRF, distribution similarity: 
   Variable                Metric      Distance
0  networth  wasserstein_distance  1.200929e+07
Model: OLS, distribution similarity: 
   Variable                Metric      Distance
0  networth  wasserstein_distance  2.274842e+07
Model: QuantReg, distribution similarity: 
   Variable                Metric      Distance
0  networth  wasserstein_distance  2.818652e+07
Model: Matching, distribution similarity: 
   Variable                Metric      Distance
0  networth  wasserstein_distance  1.798087e+07