This notebook demonstrates how to use Microimpute’s MDN imputer to impute values using mixture density networks. MDN models the full conditional distribution of a target variable as a mixture of Gaussians, enabling it to capture complex, multi-modal relationships.
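In the standard mixture density network formulation, the network maps the predictors $x$ to the parameters of a conditional Gaussian mixture over the target $y$:

$$p(y \mid x) = \sum_{k=1}^{K} \pi_k(x)\,\mathcal{N}\!\big(y;\ \mu_k(x),\ \sigma_k^2(x)\big),$$

where $K$ is the number of mixture components (the num_gaussian parameter below) and the mixture weights $\pi_k(x)$ come from a softmax (scaled by softmax_temperature). Imputed values and quantiles are then obtained by sampling from this learned distribution.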
Variable type support¶
The MDN model automatically handles both numerical and categorical variables. For numerical targets, it applies a mixture density network that learns the parameters of a Gaussian mixture model. For categorical targets (strings, booleans, or numerically-encoded categorical variables), it switches to using a neural classifier. This automatic adaptation happens internally without requiring any manual configuration.
MDN class¶
class MDN(
layers: str = "128-64-32",
activation: str = "ReLU",
dropout: float = 0.0,
use_batch_norm: bool = False,
num_gaussian: int = 5,
softmax_temperature: float = 1.0,
n_samples: int = 100,
learning_rate: float = 1e-3,
max_epochs: int = 100,
early_stopping_patience: int = 10,
batch_size: int = 256,
model_dir: str = "./microimpute_models",
force_retrain: bool = False,
seed: Optional[int] = RANDOM_STATE,
log_level: Optional[str] = "WARNING"
)
| Parameter | Type | Default | Description |
|---|---|---|---|
| layers | str | “128-64-32” | Network architecture as hyphen-separated layer sizes |
| activation | str | “ReLU” | Activation function (ReLU, LeakyReLU, SELU, etc.) |
| dropout | float | 0.0 | Dropout probability for regularization |
| use_batch_norm | bool | False | Whether to use batch normalization |
| num_gaussian | int | 5 | Number of Gaussian components in the mixture |
| softmax_temperature | float | 1.0 | Temperature for mixture weight softmax |
| n_samples | int | 100 | Number of samples for MDN prediction |
| learning_rate | float | 1e-3 | Learning rate for Adam optimizer |
| max_epochs | int | 100 | Maximum training epochs |
| early_stopping_patience | int | 10 | Epochs to wait before early stopping |
| batch_size | int | 256 | Training batch size |
| model_dir | str | “./microimpute_models” | Directory for caching trained models |
| force_retrain | bool | False | If True, skip cache and always retrain |
| seed | int | 42 | Random seed for reproducibility |
| log_level | str | “WARNING” | Logging level (DEBUG, INFO, WARNING, ERROR) |
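As an illustration, a more customized configuration might look like the following sketch (all arguments are constructor parameters from the table above; the values shown are only examples):

# Hypothetical configuration illustrating the constructor parameters
mdn_custom = MDN(
    layers="256-128-64",        # three hidden layers
    activation="LeakyReLU",
    dropout=0.1,                # regularize with dropout
    use_batch_norm=True,
    num_gaussian=5,             # number of mixture components
    learning_rate=1e-3,
    max_epochs=100,
    early_stopping_patience=10,
)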
fit() method¶
def fit(
X_train: pd.DataFrame,
predictors: List[str],
imputed_variables: List[str],
weight_col: Optional[str] = None,
tune_hyperparameters: bool = False,
n_trials: int = 10,
cv_folds: int = 3
) -> Union[MDNResults, Tuple[MDNResults, Dict]]
| Parameter | Type | Default | Description |
|---|---|---|---|
| X_train | pd.DataFrame | - | Training data with predictors and target variables |
| predictors | List[str] | - | Column names to use as predictors |
| imputed_variables | List[str] | - | Column names of variables to impute |
| weight_col | str | None | Column name for sampling weights |
| tune_hyperparameters | bool | False | Enable Optuna-based hyperparameter tuning |
| n_trials | int | 10 | Number of Optuna trials for tuning |
| cv_folds | int | 3 | Number of cross-validation folds for tuning |
It returns an MDNResults object (or a tuple of the results object and the best hyperparameters when tuning is enabled).
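For example, the two return forms could be used as follows (a minimal sketch; the variable names are hypothetical and the training objects are assumed to be defined as in the rest of this notebook):

# Without tuning: returns an MDNResults object
results = MDN().fit(X_train, predictors, imputed_variables)

# With tuning: returns (MDNResults, dict of best hyperparameters)
results, best_params = MDN().fit(
    X_train, predictors, imputed_variables, tune_hyperparameters=True
)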
MDNResults.predict() method¶
def predict(
X_test: pd.DataFrame,
quantiles: Optional[List[float]] = None
) -> Dict[float, pd.DataFrame]
| Parameter | Type | Default | Description |
|---|---|---|---|
| X_test | pd.DataFrame | - | Data to impute (with predictors) |
| quantiles | List[float] | QUANTILES | Quantiles at which to return predictions |
It returns a dictionary mapping quantiles to DataFrames of imputed values.
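For instance, pulling specific quantiles out of the returned dictionary might look like this (a minimal sketch, assuming results is a fitted MDNResults object):

preds = results.predict(X_test, quantiles=[0.1, 0.5, 0.9])
median_imputations = preds[0.5]  # DataFrame of imputed values at the median
lower_bound = preds[0.1]         # DataFrame of imputed values at the 10th percentile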
Setup and data preparation¶
import warnings
warnings.filterwarnings("ignore")
import logging
logging.getLogger("pytorch_lightning").setLevel(logging.ERROR)
logging.getLogger("pytorch_tabular").setLevel(logging.ERROR)
logging.getLogger("lightning").setLevel(logging.ERROR)
import pandas as pd
import numpy as np
import plotly.express as px
import plotly.graph_objects as go
from sklearn.datasets import load_diabetes
pd.set_option("display.width", 600)
pd.set_option("display.max_columns", 10)
pd.set_option("display.expand_frame_repr", False)
from microimpute.utils.data import preprocess_data
from microimpute.evaluations import cross_validate_model
from microimpute.models import MDN
from microimpute.config import QUANTILES
from microimpute.visualizations import model_performance_results
# Load the diabetes dataset
diabetes = load_diabetes()
df = pd.DataFrame(diabetes.data, columns=diabetes.feature_names)
df.head()
# Define variables for the model
predictors = ["age", "sex", "bmi", "bp"]
imputed_variables = ["s1", "s4"]
# Create a subset with only needed columns
diabetes_df = df[predictors + imputed_variables]
# Split data into training and testing sets
X_train, X_test = preprocess_data(diabetes_df)
print(f"Training set size: {X_train.shape[0]} records")
print(f"Testing set size: {X_test.shape[0]} records")Training set size: 353 records
Testing set size: 89 records
Simulating missing data¶
For this example, we simulate missing data in our test set by removing the values we want to impute.
# Create a copy of the test set with missing values
X_test_missing = X_test.copy()
# Store the actual values for later comparison
actual_values = X_test_missing[imputed_variables].copy()
# Remove the values to be imputed
X_test_missing[imputed_variables] = np.nan
X_test_missing.head()
Training and using the MDN imputer¶
Now we train the MDN imputer and use it to impute the missing values in our test set. The MDN model automatically caches trained models based on a hash of the input data, speeding up repeated analyses.
# Initialize the MDN imputer
mdn_imputer = MDN(
layers="64-32",
num_gaussian=3,
max_epochs=30,
early_stopping_patience=5,
)
# Fit the model
fitted_mdn_imputer = mdn_imputer.fit(
X_train,
predictors,
imputed_variables,
tune_hyperparameters=False,  # Tuning can be enabled, but it carries a high computational cost
)
# Impute values in the test set
imputed_values = fitted_mdn_imputer.predict(X_test_missing, QUANTILES)
# Display the first few imputed values at the median
imputed_values[0.5].head()
Evaluating the imputation results¶
Now let’s compare the imputed values with the actual values to evaluate the performance of our imputer.
# Get quantiles and create prediction matrix
quantiles = list(imputed_values.keys())
# Convert imputed_values dict to array: (n_samples, n_quantiles)
pred_matrix = np.stack(
[imputed_values[q].values.flatten() for q in quantiles], axis=1
)
# Actual values flattened
actual = actual_values.values.flatten()
# Compute absolute error matrix
abs_error = np.abs(pred_matrix - actual[:, None])
# Find closest prediction for each sample
closest_indices = abs_error.argmin(axis=1)
closest_predictions = np.array(
[pred_matrix[i, idx] for i, idx in enumerate(closest_indices)]
)
# Create DataFrame for plotting
closest_df = pd.DataFrame({
"Actual": actual,
"ClosestPrediction": closest_predictions,
})
# Extract median predictions
median_predictions = imputed_values[0.5]
# Calculate plot bounds
min_val = min(actual_values.min().min(), median_predictions.min().min())
max_val = max(actual_values.max().max(), median_predictions.max().max())
# Create scatter plot
fig = px.scatter(
closest_df,
x="Actual",
y="ClosestPrediction",
opacity=0.7,
title="Comparison of actual vs. imputed values using MDN",
)
# Add diagonal line (perfect prediction)
fig.add_trace(
go.Scatter(
x=[min_val, max_val],
y=[min_val, max_val],
mode="lines",
line=dict(color="red", dash="dash"),
name="Perfect prediction",
)
)
fig.update_layout(
xaxis_title="Actual values",
yaxis_title="Imputed values",
width=750,
height=600,
template="plotly_white",
)
fig.show()
This scatter plot compares actual observed values with those imputed by the MDN model. Each point represents a data record, with the x-axis showing the true value and the y-axis showing the model’s prediction closest to the true value across all quantiles. The red dashed line represents the ideal 1:1 relationship. Points clustering around this line indicate that the MDN model effectively captures the underlying structure of the data.
Examining quantile predictions¶
MDN provides predictions at different quantiles by sampling from the learned mixture distribution, allowing us to capture the entire conditional distribution of the missing values.
# Compare predictions at different quantiles for the first 5 records
quantiles_to_show = [0.1, 0.25, 0.5, 0.75, 0.9]
comparison_df = pd.DataFrame(index=range(5))
# Add actual values
comparison_df["Actual"] = actual_values.iloc[:5, 0].values
# Add quantile predictions
for q in quantiles_to_show:
comparison_df[f"Q{int(q*100)}"] = imputed_values[q].iloc[:5, 0].values
comparison_df
Visualizing prediction intervals¶
By visualizing the prediction intervals we can better understand the uncertainty in our imputed values.
# Create prediction interval plot for first 10 records
n_records = 10
# Prepare data
records = list(range(n_records))
actuals = actual_values.iloc[:n_records, 0].values
medians = imputed_values[0.5].iloc[:n_records, 0].values
q10 = imputed_values[0.1].iloc[:n_records, 0].values
q90 = imputed_values[0.9].iloc[:n_records, 0].values
q30 = imputed_values[0.3].iloc[:n_records, 0].values
q70 = imputed_values[0.7].iloc[:n_records, 0].values
# Create figure
fig = go.Figure()
# Add 80% prediction interval (Q10-Q90)
for i in range(n_records):
fig.add_trace(
go.Scatter(
x=[i, i],
y=[q10[i], q90[i]],
mode="lines",
line=dict(width=10, color="rgba(173, 216, 230, 0.3)"),
hoverinfo="none",
showlegend=False,
)
)
# Add 40% prediction interval (Q30-Q70)
for i in range(n_records):
fig.add_trace(
go.Scatter(
x=[i, i],
y=[q30[i], q70[i]],
mode="lines",
line=dict(width=10, color="rgba(70, 130, 180, 0.5)"),
hoverinfo="none",
showlegend=False,
)
)
# Add actual values
fig.add_trace(
go.Scatter(
x=records,
y=actuals,
mode="markers",
marker=dict(color="black", size=8),
name="Actual",
)
)
# Add median predictions
fig.add_trace(
go.Scatter(
x=records,
y=medians,
mode="markers",
marker=dict(color="red", size=8),
name="Median (Q50)",
)
)
# Legend entries for intervals
fig.add_trace(
go.Scatter(
x=[-1, -1], y=[0, 0],
mode="lines",
line=dict(color="rgba(173, 216, 230, 0.3)", width=10),
name="80% PI (Q10-Q90)",
)
)
fig.add_trace(
go.Scatter(
x=[-1, -1], y=[0, 0],
mode="lines",
line=dict(color="rgba(70, 130, 180, 0.5)", width=10),
name="40% PI (Q30-Q70)",
)
)
fig.update_layout(
title="MDN imputation prediction intervals",
xaxis=dict(title="Data record index", showgrid=True),
yaxis=dict(title="Value (s1)", showgrid=True),
width=750,
height=600,
template="plotly_white",
legend=dict(yanchor="top", y=0.99, xanchor="right", x=0.99),
)
fig.show()
This plot visualizes the prediction intervals produced by the MDN model. Each vertical bar represents an 80% (light blue) or 40% (darker blue) prediction interval. Red dots mark the model’s median predictions (Q50), while black dots show the actual observed values. In most cases, the true values fall within the prediction intervals, indicating that the MDN model appropriately captures uncertainty in its imputation.
Assessing the method’s performance¶
To verify our model doesn’t overfit and ensure robust results, we perform cross-validation.
# Run cross-validation
mdn_results = cross_validate_model(
MDN, diabetes_df, predictors, imputed_variables,
model_hyperparams={"layers": "64-32", "num_gaussian": 3, "max_epochs": 30}
)
if "quantile_loss" in mdn_results:
print("Quantile loss results:")
print(mdn_results["quantile_loss"]["results"])[Parallel(n_jobs=-1)]: Using backend LokyBackend with 8 concurrent workers.
[Parallel(n_jobs=-1)]: Done 2 out of 5 | elapsed: 7.6s remaining: 11.5s
[Parallel(n_jobs=-1)]: Done 3 out of 5 | elapsed: 7.6s remaining: 5.1s
[Parallel(n_jobs=-1)]: Done 5 out of 5 | elapsed: 7.7s finished
Quantile loss results:
0.05 0.10 0.15 0.20 0.25 ... 0.75 0.80 0.85 0.90 0.95
train 0.052342 0.080491 0.100151 0.114462 0.124877 ... 0.116854 0.105562 0.090379 0.070140 0.043119
test 0.060333 0.088338 0.109102 0.123207 0.134303 ... 0.127896 0.114887 0.099006 0.077486 0.047284
[2 rows x 19 columns]
# Plot the results
if "quantile_loss" in mdn_results:
perf_results_viz = model_performance_results(
results=mdn_results["quantile_loss"]["results"],
model_name="MDN",
method_name="Cross-validation quantile loss average",
)
fig = perf_results_viz.plot(
title="MDN cross-validation performance",
)
fig.show()
Hyperparameter tuning¶
MDN supports automatic hyperparameter tuning using Optuna. This tunes the number of Gaussian components and the learning rate to optimize performance on your specific dataset. To tune hyperparameters, set tune_hyperparameters=True; fit() then returns a tuple whose second element is a dictionary with the best configuration found.
# Initialize an MDN imputer whose hyperparameters will be tuned
mdn_tuned = MDN()
fitted_tuned, best_parameters = mdn_tuned.fit(
X_train,
predictors,
imputed_variables,
tune_hyperparameters=True,
n_trials=10,
cv_folds=3,
)
Note that this will take significantly longer, since multiple networks are trained while searching for the best hyperparameters.
Categorical variable imputation¶
MDN automatically handles categorical variables through neural classification. Let’s evaluate its performance on categorical imputation tasks.
# Create a dataset with categorical variables
np.random.seed(42)
df_categorical = pd.DataFrame()
df_categorical['age'] = df['age']
df_categorical['sex'] = df['sex']
df_categorical['bmi'] = df['bmi']
df_categorical['bp'] = df['bp']
df_categorical['risk_level'] = pd.qcut(
df['s1'], q=3, labels=['low', 'medium', 'high']
).astype(str)
print("Categorical variable distribution:")
print(df_categorical['risk_level'].value_counts())
# Split the data
X_train_cat, X_test_cat = preprocess_data(df_categorical)
print(f"\nTraining set size: {X_train_cat.shape[0]} records")
print(f"Testing set size: {X_test_cat.shape[0]} records")Categorical variable distribution:
risk_level
low 148
high 148
medium 146
Name: count, dtype: int64
Training set size: 353 records
Testing set size: 89 records
# Fit MDN model for categorical imputation
predictors_cat = ["age", "sex", "bmi", "bp"]
imputed_variables_cat = ["risk_level"]
mdn_cat_imputer = MDN(layers="64-32", max_epochs=50)
fitted_mdn_cat = mdn_cat_imputer.fit(X_train_cat, predictors_cat, imputed_variables_cat)
print("MDN model fitted for categorical variable imputation")
# Create test set with missing categorical values
X_test_cat_missing = X_test_cat.copy()
actual_cat_values = X_test_cat_missing[imputed_variables_cat].copy()
X_test_cat_missing[imputed_variables_cat] = np.nan
# Impute the categorical values
imputed_cat_values = fitted_mdn_cat.predict(X_test_cat_missing, [0.5])
MDN model fitted for categorical variable imputation
# Evaluate categorical imputation accuracy
from sklearn.metrics import accuracy_score, confusion_matrix
predicted = imputed_cat_values[0.5]['risk_level'].values
actual = actual_cat_values['risk_level'].values
accuracy = accuracy_score(actual, predicted)
print(f"Categorical imputation accuracy: {accuracy:.2%}")
conf_matrix = pd.DataFrame(
confusion_matrix(actual, predicted, labels=['low', 'medium', 'high']),  # explicit label order so rows/columns match the names below
index=['Actual: low', 'Actual: medium', 'Actual: high'],
columns=['Predicted: low', 'Predicted: medium', 'Predicted: high']
)
print("\nConfusion matrix:")
print(conf_matrix)
Categorical imputation accuracy: 31.46%
Confusion matrix:
Predicted: low Predicted: medium Predicted: high
Actual: low 10 8 12
Actual: medium 14 11 4
Actual: high 7 16 7
# Run cross-validation for categorical variables
mdn_categorical_results = cross_validate_model(
MDN, df_categorical, predictors_cat, imputed_variables_cat,
model_hyperparams={"layers": "64-32", "max_epochs": 30}
)
print("Categorical imputation cross-validation results (log loss):")
print(f"Mean train log loss: {mdn_categorical_results['log_loss']['mean_train']:.4f}")
print(f"Mean test log loss: {mdn_categorical_results['log_loss']['mean_test']:.4f}")[Parallel(n_jobs=-1)]: Using backend LokyBackend with 8 concurrent workers.
[Parallel(n_jobs=-1)]: Done 2 out of 5 | elapsed: 0.8s remaining: 1.2s
[Parallel(n_jobs=-1)]: Done 3 out of 5 | elapsed: 0.9s remaining: 0.6s
[Parallel(n_jobs=-1)]: Done 5 out of 5 | elapsed: 0.9s finished
Categorical imputation cross-validation results (log loss):
Mean train log loss: 0.9622
Mean test log loss: 1.0936
# Plot categorical performance
cat_perf_results_viz = model_performance_results(
results=mdn_categorical_results,
model_name="MDN",
method_name="Cross-validation log loss average",
metric="log_loss",
)
fig = cat_perf_results_viz.plot(
title="MDN categorical imputation cross-validation performance",
)
fig.show()
Model caching¶
MDN automatically caches trained models based on a hash of the input data. When you fit the model on the same data again, it loads from cache rather than retraining, significantly speeding up repeated analyses. Use force_retrain=True to bypass the cache if needed.
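For example, forcing a fresh training run instead of loading from the cache might look like this (a minimal sketch using the constructor parameters documented above):

# Ignore any cached model for this data hash and retrain from scratch
mdn_fresh = MDN(model_dir="./microimpute_models", force_retrain=True)
fitted_fresh = mdn_fresh.fit(X_train, predictors, imputed_variables)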