This notebook provides a guide to benchmarking different imputation methods with Microimpute, working through two example data sets.
With sklearn’s Diabetes data set
For details on the data set, see the scikit-learn documentation for load_diabetes.
from typing import List, Type
import pandas as pd
from microimpute.comparisons import *
from microimpute.config import RANDOM_STATE
from microimpute.models import *
from microimpute.visualizations import method_comparison_results
from microimpute.evaluations import *
from sklearn.datasets import load_diabetes
from sklearn.model_selection import train_test_split
import warnings
warnings.filterwarnings("ignore")
# 1. Prepare data
diabetes = load_diabetes()
df = pd.DataFrame(diabetes.data, columns=diabetes.feature_names)
X_train, X_test = train_test_split(
df, test_size=0.2, random_state=RANDOM_STATE
)
predictors = ["age", "sex", "bmi", "bp"]
imputed_variables = ["s1", "s4"] # Numerical variables
Y_test: pd.DataFrame = X_test[imputed_variables]
# 2. Run imputation methods
model_classes: List[Type[Imputer]] = [QRF, OLS, QuantReg, Matching]
method_imputations = get_imputations(
model_classes, X_train, X_test, predictors, imputed_variables
)
# 3. Compare imputation methods using unified metrics function
# The function automatically detects that these are numerical variables and uses quantile loss
loss_comparison_df = compare_metrics(
Y_test, method_imputations, imputed_variables
)
# 4. Plot results - filter for quantile loss metrics only
quantile_loss_df = loss_comparison_df[loss_comparison_df['Metric'] == 'quantile_loss']
comparison_viz = method_comparison_results(
data=quantile_loss_df,
metric="quantile_loss",
data_format="long",
)
fig = comparison_viz.plot(
title="Method Comparison on Diabetes Dataset (Numerical Variables)",
)
fig.show()
# Display summary statistics
print("Summary of quantile loss by method:")
summary = quantile_loss_df[quantile_loss_df['Imputed Variable'] == 'mean_quantile_loss'].groupby('Method')['Loss'].mean()
print(summary.sort_values())
# Evaluate sensitivity to predictors
leave_one_out_results = leave_one_out_analysis(
df,
predictors,
imputed_variables,
QRF
)
predictor_inclusion_results = progressive_predictor_inclusion(
df,
predictors,
imputed_variables,
QRF
)
print(f"Optimal subset: {predictor_inclusion_results['optimal_subset']}")
print(f"Optimal loss: {predictor_inclusion_results['optimal_loss']}")
# For step-by-step details:
print(predictor_inclusion_results['results_df'])
Summary of quantile loss by method:
Method
OLS         0.012658
QuantReg    0.012670
QRF         0.016240
Matching    0.022570
Name: Loss, dtype: float64
Optimal subset: ['bp', 'age', 'bmi', 'sex']
Optimal loss: 0.016242569699135692
   step predictor_added  predictors_included  avg_quantile_loss  avg_log_loss  cumulative_improvement  marginal_improvement
0     1              bp                 [bp]           0.017815             0                0.000000              0.000000
1     2             age            [bp, age]           0.017225             0                0.000590              0.000590
2     3             bmi       [bp, age, bmi]           0.016578             0                0.001236              0.000647
3     4             sex  [bp, age, bmi, sex]           0.016243             0                0.001572              0.000336
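For reference, the quantile (pinball) loss that compare_metrics reports for these numerical variables can be written out by hand. The sketch below uses plain NumPy and is independent of Microimpute's own implementation; the exact quantile grid and averaging used internally may differ, and the lookup of the QRF median imputation assumes the nested dictionary returned by get_imputations is keyed by method name and then by quantile, as in the SCF example below.
import numpy as np
def pinball_loss(y_true, y_pred, q):
    """Average pinball (quantile) loss at quantile level q."""
    diff = np.asarray(y_true, dtype=float) - np.asarray(y_pred, dtype=float)
    return float(np.mean(np.maximum(q * diff, (q - 1) * diff)))
# Example: pinball loss of the QRF median (q=0.5) imputations for s1
qrf_median_s1 = method_imputations["QRF"][0.5]["s1"]
print(pinball_loss(Y_test["s1"], qrf_median_s1, q=0.5))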
# On the SCF Dataset
from typing import List, Type, Optional, Union
import io
import logging
import pandas as pd
from pydantic import validate_call
import requests
import zipfile
from microimpute.comparisons import *
from microimpute.config import RANDOM_STATE, VALIDATE_CONFIG, VALID_YEARS
from microimpute.models import *
from microimpute.utils.data import preprocess_data
from microimpute.visualizations import method_comparison_results
from microimpute.evaluations import *
import warnings
warnings.filterwarnings("ignore")
logger = logging.getLogger(__name__)
# 1. Prepare data
@validate_call(config=VALIDATE_CONFIG)
def load_scf(
years: Optional[Union[int, List[int]]] = None,
columns: Optional[List[str]] = None,
) -> pd.DataFrame:
"""Load Survey of Consumer Finances data for specified years and columns.
Args:
years: Year or list of years to load data for.
columns: List of column names to load.
Returns:
DataFrame containing the requested data.
Raises:
ValueError: If no Stata files are found in the downloaded zip
or invalid parameters
RuntimeError: If there's a network error or a problem processing
the downloaded data
"""
def scf_url(year: int) -> str:
"""Return the URL of the SCF summary microdata zip file for a year."""
if year not in VALID_YEARS:
logger.error(
f"Invalid SCF year: {year}. Valid years are {VALID_YEARS}"
)
            raise ValueError(
                f"Invalid SCF year: {year}. Valid years are {VALID_YEARS}"
            )
url = f"https://www.federalreserve.gov/econres/files/scfp{year}s.zip"
return url
logger.info(f"Loading SCF data with years={years}")
try:
# Identify years for download
if years is None:
years = VALID_YEARS
logger.warning(f"Using default years: {years}")
if isinstance(years, int):
years = [years]
all_data: List[pd.DataFrame] = []
for year in years:
logger.info(f"Processing data for year {year}")
try:
# Download zip file
logger.debug(f"Downloading SCF data for year {year}")
url = scf_url(year)
try:
response = requests.get(url, timeout=60)
response.raise_for_status() # Raise an error for bad responses
except requests.exceptions.RequestException as e:
logger.error(
f"Network error downloading SCF data for year {year}: {str(e)}"
)
raise
# Process zip file
z = zipfile.ZipFile(io.BytesIO(response.content))
# Find the .dta file in the zip
dta_files: List[str] = [
f for f in z.namelist() if f.endswith(".dta")
]
if not dta_files:
logger.error(
f"No Stata files found in zip for year {year}"
)
                    raise ValueError(
                        f"No Stata files found in zip for year {year}"
                    )
# Read the Stata file
try:
logger.debug(f"Reading Stata file: {dta_files[0]}")
with z.open(dta_files[0]) as f:
df = pd.read_stata(
io.BytesIO(f.read()), columns=columns
)
logger.debug(
f"Read DataFrame with shape {df.shape}"
)
except Exception as e:
logger.error(
f"Error reading Stata file for year {year}: {str(e)}"
)
raise
# Add year column
df["year"] = year
logger.info(
f"Successfully processed data for year {year}, shape: {df.shape}"
)
all_data.append(df)
except Exception as e:
logger.error(f"Error processing year {year}: {str(e)}")
raise
# Combine all years
logger.debug(f"Combining data from {len(all_data)} years")
if len(all_data) > 1:
result = pd.concat(all_data)
logger.info(
f"Combined data from {len(years)} years, final shape: {result.shape}"
)
return result
else:
logger.info(
f"Returning data for single year, shape: {all_data[0].shape}"
)
return all_data[0]
except Exception as e:
logger.error(f"Error in _load: {str(e)}")
raise
scf_data = load_scf(2022)
PREDICTORS: List[str] = [
"hhsex", # sex of head of household
"age", # age of respondent
"married", # marital status of respondent
# "kids", # number of children in household
"race", # race of respondent
"income", # total annual income of household
"wageinc", # income from wages and salaries
"bussefarminc", # income from business, self-employment or farm
"intdivinc", # income from interest and dividends
"ssretinc", # income from social security and retirement accounts
"lf", # labor force status
]
IMPUTED_VARIABLES: List[str] = ["networth"]
# Evaluate predictors
predictors_evaluation = compute_predictor_correlations(scf_data, PREDICTORS, IMPUTED_VARIABLES)
print("\nMutual information with networth:")
print(predictors_evaluation["predictor_target_mi"])
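# Preprocess the SCF data and split it into train and test sets (normalization disabled)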
X_train, X_test = preprocess_data(
data=scf_data, full_data=False, normalize=False,
)
# Shrink down the data by sampling
X_train = X_train.sample(frac=0.01, random_state=RANDOM_STATE)
X_test = X_test.sample(frac=0.01, random_state=RANDOM_STATE)
Y_test: pd.DataFrame = X_test[IMPUTED_VARIABLES]
# 2. Run imputation methods
model_classes: List[Type[Imputer]] = [QRF, OLS, QuantReg, Matching]
method_imputations = get_imputations(
model_classes, X_train, X_test, PREDICTORS, IMPUTED_VARIABLES
)
# 3. Compare imputation methods using unified metrics function
loss_comparison_df = compare_metrics(
Y_test, method_imputations, IMPUTED_VARIABLES
)
# 4. Plot results - filter for quantile loss metrics only
quantile_loss_df = loss_comparison_df[loss_comparison_df['Metric'] == 'quantile_loss']
comparison_viz = method_comparison_results(
data=quantile_loss_df,
metric="quantile_loss",
data_format="long",
)
fig = comparison_viz.plot(
title="Method Comparison on SCF Dataset",
show_mean=True,
)
fig.show()
# Compare distribution similarity
for model, imputations in method_imputations.items():
    similarity = compare_distributions(
        donor_data=pd.DataFrame(scf_data["networth"]),
        receiver_data=pd.DataFrame(imputations[0.5]["networth"]),
        imputed_variables=IMPUTED_VARIABLES,
    )
    print(f"Model: {model}, distribution similarity: \n{similarity}")
Mutual information with networth:
              networth
hhsex         0.014095
age           0.119553
married       0.016892
race          0.027996
income        0.156195
wageinc       0.121915
bussefarminc  0.080536
intdivinc     0.085286
ssretinc      0.059902
lf            0.010602
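As a rough cross-check of the mutual information scores above, a similar predictor-target estimate can be computed directly with scikit-learn. This is only a sketch: the values will not match exactly, since compute_predictor_correlations may use different estimator settings or preprocessing, and it assumes the selected SCF columns are numeric and complete.
from sklearn.feature_selection import mutual_info_regression
# Estimate mutual information between each predictor and networth
mi = mutual_info_regression(
    scf_data[PREDICTORS], scf_data["networth"], random_state=RANDOM_STATE
)
print(pd.Series(mi, index=PREDICTORS).sort_values(ascending=False))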
Loading...
Model: QRF, distribution similarity:
   Variable                Metric      Distance
0  networth  wasserstein_distance  1.200929e+07
Model: OLS, distribution similarity:
   Variable                Metric      Distance
0  networth  wasserstein_distance  2.274842e+07
Model: QuantReg, distribution similarity:
   Variable                Metric      Distance
0  networth  wasserstein_distance  2.818652e+07
Model: Matching, distribution similarity:
   Variable                Metric      Distance
0  networth  wasserstein_distance  1.798087e+07
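The distances reported above are one-dimensional Wasserstein distances between the observed and imputed networth distributions, so they can be approximated directly with SciPy. A minimal sketch, again assuming the imputation dictionary is keyed by method name and quantile; the numbers should be close to, though not necessarily identical with, the compare_distributions output, depending on any weighting or preprocessing that function applies.
from scipy.stats import wasserstein_distance
observed = scf_data["networth"].dropna()
for model, imputations in method_imputations.items():
    # Median (q=0.5) imputations for each method
    imputed = imputations[0.5]["networth"].dropna()
    print(model, wasserstein_distance(observed, imputed))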