Mixture Density Network (MDN) imputation

This notebook demonstrates how to use Microimpute’s MDN imputer to impute values using mixture density networks. MDN models the full conditional distribution of a target variable as a mixture of Gaussians, enabling it to capture complex, multi-modal relationships.
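
As a rough illustration of the idea (not Microimpute's internal code), the density an MDN learns for a target y given predictors x is a weighted sum of Gaussian components whose weights, means, and standard deviations are produced by the network. The snippet below evaluates such a mixture density for a single record using hypothetical component parameters:

import numpy as np

# Hypothetical MDN outputs for one record: a 3-component mixture
weights = np.array([0.5, 0.3, 0.2])   # softmax-normalized mixture weights
means = np.array([-1.0, 0.5, 2.0])    # component means
stds = np.array([0.4, 0.8, 0.3])      # component standard deviations

def mixture_density(y, weights, means, stds):
    """Evaluate p(y | x) = sum_k w_k * N(y; mu_k, sigma_k^2)."""
    components = np.exp(-0.5 * ((y - means) / stds) ** 2) / (stds * np.sqrt(2 * np.pi))
    return np.sum(weights * components)

print(mixture_density(0.0, weights, means, stds))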

Variable type support

The MDN model automatically handles both numerical and categorical variables. For numerical targets, it applies a mixture density network that learns the parameters of a Gaussian mixture model. For categorical targets (strings, booleans, or numerically-encoded categorical variables), it switches to using a neural classifier. This automatic adaptation happens internally without requiring any manual configuration.

MDN class

class MDN(
    layers: str = "128-64-32",
    activation: str = "ReLU",
    dropout: float = 0.0,
    use_batch_norm: bool = False,
    num_gaussian: int = 5,
    softmax_temperature: float = 1.0,
    n_samples: int = 100,
    learning_rate: float = 1e-3,
    max_epochs: int = 100,
    early_stopping_patience: int = 10,
    batch_size: int = 256,
    model_dir: str = "./microimpute_models",
    force_retrain: bool = False,
    seed: Optional[int] = RANDOM_STATE,
    log_level: Optional[str] = "WARNING"
)
| Parameter | Type | Default | Description |
| --- | --- | --- | --- |
| layers | str | "128-64-32" | Network architecture as hyphen-separated layer sizes |
| activation | str | "ReLU" | Activation function (ReLU, LeakyReLU, SELU, etc.) |
| dropout | float | 0.0 | Dropout probability for regularization |
| use_batch_norm | bool | False | Whether to use batch normalization |
| num_gaussian | int | 5 | Number of Gaussian components in the mixture |
| softmax_temperature | float | 1.0 | Temperature for the mixture-weight softmax |
| n_samples | int | 100 | Number of samples drawn for MDN prediction |
| learning_rate | float | 1e-3 | Learning rate for the Adam optimizer |
| max_epochs | int | 100 | Maximum number of training epochs |
| early_stopping_patience | int | 10 | Epochs to wait before early stopping |
| batch_size | int | 256 | Training batch size |
| model_dir | str | "./microimpute_models" | Directory for caching trained models |
| force_retrain | bool | False | If True, skip the cache and always retrain |
| seed | int | 42 | Random seed for reproducibility |
| log_level | str | "WARNING" | Logging level (DEBUG, INFO, WARNING, ERROR) |
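
For example, a deeper network with regularization could be configured as follows (the values here are illustrative, not tuned recommendations):

from microimpute.models import MDN

mdn_custom = MDN(
    layers="256-128-64",        # three hidden layers
    activation="LeakyReLU",
    dropout=0.1,
    use_batch_norm=True,
    num_gaussian=8,             # richer mixture for multi-modal targets
    max_epochs=200,
    early_stopping_patience=20,
)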

fit() method

def fit(
    X_train: pd.DataFrame,
    predictors: List[str],
    imputed_variables: List[str],
    weight_col: Optional[str] = None,
    tune_hyperparameters: bool = False,
    n_trials: int = 10,
    cv_folds: int = 3
) -> Union[MDNResults, Tuple[MDNResults, Dict]]
| Parameter | Type | Default | Description |
| --- | --- | --- | --- |
| X_train | pd.DataFrame | - | Training data containing the predictor and target variables |
| predictors | List[str] | - | Column names to use as predictors |
| imputed_variables | List[str] | - | Column names of the variables to impute |
| weight_col | str | None | Column name for sampling weights |
| tune_hyperparameters | bool | False | Enable Optuna-based hyperparameter tuning |
| n_trials | int | 10 | Number of Optuna trials for tuning |
| cv_folds | int | 3 | Number of cross-validation folds for tuning |

It returns an MDNResults object (or a tuple of the results object and the best hyperparameters if tuning is enabled).
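
A minimal sketch of the two return shapes:

imputer = MDN()

# Without tuning: a single MDNResults object
fitted = imputer.fit(X_train, predictors, imputed_variables)

# With tuning: a (MDNResults, best-hyperparameters dict) tuple
fitted, best_params = imputer.fit(
    X_train, predictors, imputed_variables, tune_hyperparameters=True
)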

MDNResults.predict() method

def predict(
    X_test: pd.DataFrame,
    quantiles: Optional[List[float]] = None
) -> Dict[float, pd.DataFrame]
| Parameter | Type | Default | Description |
| --- | --- | --- | --- |
| X_test | pd.DataFrame | - | Data to impute (containing the predictors) |
| quantiles | List[float] | QUANTILES | Quantiles at which to return predictions |

It returns a dictionary mapping quantiles to DataFrames of imputed values.
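
For example, predictions at a few quantiles can be pulled out of the returned dictionary like this (a sketch; fitted is an MDNResults object from the snippet above):

preds = fitted.predict(X_test, quantiles=[0.1, 0.5, 0.9])
median_imputations = preds[0.5]         # DataFrame of imputed values at the median
lower, upper = preds[0.1], preds[0.9]   # bounds of an 80% prediction interval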

Setup and data preparation

import warnings
warnings.filterwarnings("ignore")

import logging
logging.getLogger("pytorch_lightning").setLevel(logging.ERROR)
logging.getLogger("pytorch_tabular").setLevel(logging.ERROR)
logging.getLogger("lightning").setLevel(logging.ERROR)

import pandas as pd
import numpy as np
import plotly.express as px
import plotly.graph_objects as go
from sklearn.datasets import load_diabetes

pd.set_option("display.width", 600)
pd.set_option("display.max_columns", 10)
pd.set_option("display.expand_frame_repr", False)

from microimpute.utils.data import preprocess_data
from microimpute.evaluations import cross_validate_model
from microimpute.models import MDN
from microimpute.config import QUANTILES
from microimpute.visualizations import model_performance_results
# Load the diabetes dataset
diabetes = load_diabetes()
df = pd.DataFrame(diabetes.data, columns=diabetes.feature_names)

df.head()
# Define variables for the model
predictors = ["age", "sex", "bmi", "bp"]
imputed_variables = ["s1", "s4"]

# Create a subset with only needed columns
diabetes_df = df[predictors + imputed_variables]

# Split data into training and testing sets
X_train, X_test = preprocess_data(diabetes_df)

print(f"Training set size: {X_train.shape[0]} records")
print(f"Testing set size: {X_test.shape[0]} records")
Training set size: 353 records
Testing set size: 89 records

Simulating missing data

For this example, we simulate missing data in our test set by removing the values we want to impute.

# Create a copy of the test set with missing values
X_test_missing = X_test.copy()

# Store the actual values for later comparison
actual_values = X_test_missing[imputed_variables].copy()

# Remove the values to be imputed
X_test_missing[imputed_variables] = np.nan

X_test_missing.head()

Training and using the MDN imputer

Now we train the MDN imputer and use it to impute the missing values in our test set. The MDN model automatically caches trained models based on a hash of the input data, speeding up repeated analyses.

# Initialize the MDN imputer
mdn_imputer = MDN(
    layers="64-32",
    num_gaussian=3,
    max_epochs=30,
    early_stopping_patience=5,
)

# Fit the model
fitted_mdn_imputer = mdn_imputer.fit(
    X_train,
    predictors,
    imputed_variables,
    tune_hyperparameters=False,  # Tuning can be enabled here, but it adds substantial computational cost
)
# Impute values in the test set
imputed_values = fitted_mdn_imputer.predict(X_test_missing, QUANTILES)

# Display the first few imputed values at the median
imputed_values[0.5].head()

Evaluating the imputation results

Now let’s compare the imputed values with the actual values to evaluate the performance of our imputer.

# Get quantiles and create prediction matrix
quantiles = list(imputed_values.keys())

# Convert imputed_values dict to array: (n_samples, n_quantiles)
pred_matrix = np.stack(
    [imputed_values[q].values.flatten() for q in quantiles], axis=1
)

# Actual values flattened
actual = actual_values.values.flatten()

# Compute absolute error matrix
abs_error = np.abs(pred_matrix - actual[:, None])

# Find closest prediction for each sample
closest_indices = abs_error.argmin(axis=1)
closest_predictions = np.array(
    [pred_matrix[i, idx] for i, idx in enumerate(closest_indices)]
)

# Create DataFrame for plotting
closest_df = pd.DataFrame({
    "Actual": actual,
    "ClosestPrediction": closest_predictions,
})

# Extract median predictions
median_predictions = imputed_values[0.5]

# Calculate plot bounds
min_val = min(actual_values.min().min(), median_predictions.min().min())
max_val = max(actual_values.max().max(), median_predictions.max().max())

# Create scatter plot
fig = px.scatter(
    closest_df,
    x="Actual",
    y="ClosestPrediction",
    opacity=0.7,
    title="Comparison of actual vs. imputed values using MDN",
)

# Add diagonal line (perfect prediction)
fig.add_trace(
    go.Scatter(
        x=[min_val, max_val],
        y=[min_val, max_val],
        mode="lines",
        line=dict(color="red", dash="dash"),
        name="Perfect prediction",
    )
)

fig.update_layout(
    xaxis_title="Actual values",
    yaxis_title="Imputed values",
    width=750,
    height=600,
    template="plotly_white",
)

fig.show()

This scatter plot compares actual observed values with those imputed by the MDN model. Each point represents a data record, with the x-axis showing the true value and the y-axis showing the model’s prediction closest to the true value across all quantiles. The red dashed line represents the ideal 1:1 relationship. Points clustering around this line indicate that the MDN model effectively captures the underlying structure of the data.
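
For a simple numeric summary alongside the plot, we could also measure the error of the median predictions against the held-out true values (this reuses the objects defined above):

# RMSE of the median (Q50) predictions, pooled across both imputed variables
median_flat = imputed_values[0.5].values.flatten()
rmse = np.sqrt(np.mean((median_flat - actual) ** 2))
print(f"Median-prediction RMSE: {rmse:.4f}")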

Examining quantile predictions

MDN provides predictions at different quantiles by sampling from the learned mixture distribution, allowing us to capture the entire conditional distribution of the missing values.

# Compare predictions at different quantiles for the first 5 records
quantiles_to_show = [0.1, 0.25, 0.5, 0.75, 0.9]
comparison_df = pd.DataFrame(index=range(5))

# Add actual values
comparison_df["Actual"] = actual_values.iloc[:5, 0].values

# Add quantile predictions
for q in quantiles_to_show:
    comparison_df[f"Q{int(q*100)}"] = imputed_values[q].iloc[:5, 0].values

comparison_df

Visualizing prediction intervals

By visualizing the prediction intervals we can better understand the uncertainty in our imputed values.

# Create prediction interval plot for first 10 records
n_records = 10

# Prepare data
records = list(range(n_records))
actuals = actual_values.iloc[:n_records, 0].values
medians = imputed_values[0.5].iloc[:n_records, 0].values
q10 = imputed_values[0.1].iloc[:n_records, 0].values
q90 = imputed_values[0.9].iloc[:n_records, 0].values
q30 = imputed_values[0.3].iloc[:n_records, 0].values
q70 = imputed_values[0.7].iloc[:n_records, 0].values

# Create figure
fig = go.Figure()

# Add 80% prediction interval (Q10-Q90)
for i in range(n_records):
    fig.add_trace(
        go.Scatter(
            x=[i, i],
            y=[q10[i], q90[i]],
            mode="lines",
            line=dict(width=10, color="rgba(173, 216, 230, 0.3)"),
            hoverinfo="none",
            showlegend=False,
        )
    )

# Add 40% prediction interval (Q30-Q70)
for i in range(n_records):
    fig.add_trace(
        go.Scatter(
            x=[i, i],
            y=[q30[i], q70[i]],
            mode="lines",
            line=dict(width=10, color="rgba(70, 130, 180, 0.5)"),
            hoverinfo="none",
            showlegend=False,
        )
    )

# Add actual values
fig.add_trace(
    go.Scatter(
        x=records,
        y=actuals,
        mode="markers",
        marker=dict(color="black", size=8),
        name="Actual",
    )
)

# Add median predictions
fig.add_trace(
    go.Scatter(
        x=records,
        y=medians,
        mode="markers",
        marker=dict(color="red", size=8),
        name="Median (Q50)",
    )
)

# Legend entries for intervals
fig.add_trace(
    go.Scatter(
        x=[-1, -1], y=[0, 0],
        mode="lines",
        line=dict(color="rgba(173, 216, 230, 0.3)", width=10),
        name="80% PI (Q10-Q90)",
    )
)
fig.add_trace(
    go.Scatter(
        x=[-1, -1], y=[0, 0],
        mode="lines",
        line=dict(color="rgba(70, 130, 180, 0.5)", width=10),
        name="40% PI (Q30-Q70)",
    )
)

fig.update_layout(
    title="MDN imputation prediction intervals",
    xaxis=dict(title="Data record index", showgrid=True),
    yaxis=dict(title="Value (s1)", showgrid=True),
    width=750,
    height=600,
    template="plotly_white",
    legend=dict(yanchor="top", y=0.99, xanchor="right", x=0.99),
)

fig.show()

This plot visualizes the prediction intervals produced by the MDN model. Each vertical bar represents an 80% (light blue) or 40% (dark blue) prediction interval. Red dots mark the model’s median predictions (Q50), while black dots show the actual observed values. In most cases, the true values fall within the prediction intervals, indicating that the MDN model appropriately captures uncertainty in its imputation.
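
To quantify this beyond the first 10 records, one could check the empirical coverage of the 80% interval across the whole test set (a quick sketch using the quantile predictions computed above):

# Fraction of true values falling inside the Q10-Q90 interval, pooled across both variables
q10_all = imputed_values[0.1].values.flatten()
q90_all = imputed_values[0.9].values.flatten()
coverage = np.mean((actual >= q10_all) & (actual <= q90_all))
print(f"Empirical 80% interval coverage: {coverage:.1%}")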

Assessing the method’s performance

To verify our model doesn’t overfit and ensure robust results, we perform cross-validation.

# Run cross-validation
mdn_results = cross_validate_model(
    MDN, diabetes_df, predictors, imputed_variables,
    model_hyperparams={"layers": "64-32", "num_gaussian": 3, "max_epochs": 30}
)

if "quantile_loss" in mdn_results:
    print("Quantile loss results:")
    print(mdn_results["quantile_loss"]["results"])
[Parallel(n_jobs=-1)]: Using backend LokyBackend with 8 concurrent workers.
[Parallel(n_jobs=-1)]: Done   2 out of   5 | elapsed:    7.6s remaining:   11.5s
[Parallel(n_jobs=-1)]: Done   3 out of   5 | elapsed:    7.6s remaining:    5.1s
[Parallel(n_jobs=-1)]: Done   5 out of   5 | elapsed:    7.7s finished
Quantile loss results:
           0.05      0.10      0.15      0.20      0.25  ...      0.75      0.80      0.85      0.90      0.95
train  0.052342  0.080491  0.100151  0.114462  0.124877  ...  0.116854  0.105562  0.090379  0.070140  0.043119
test   0.060333  0.088338  0.109102  0.123207  0.134303  ...  0.127896  0.114887  0.099006  0.077486  0.047284

[2 rows x 19 columns]
# Plot the results
if "quantile_loss" in mdn_results:
    perf_results_viz = model_performance_results(
        results=mdn_results["quantile_loss"]["results"],
        model_name="MDN",
        method_name="Cross-validation quantile loss average",
    )
    fig = perf_results_viz.plot(
        title="MDN cross-validation performance",
    )
    fig.show()

Hyperparameter tuning

MDN supports automatic hyperparameter tuning using Optuna. This tunes the number of Gaussian components and learning rate to optimize performance on your specific dataset. To enable it, pass tune_hyperparameters=True to fit(); the best configuration found is returned as the second element of the resulting tuple.

mdn_tuned = MDN()
fitted_tuned, best_parameters = mdn_tuned.fit(
    X_train,
    predictors,
    imputed_variables,
    tune_hyperparameters=True,
    n_trials=10,
    cv_folds=3,
)

Note that this can take significant time, since multiple networks are trained while searching for the best hyperparameters.
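
The best configuration can then be reused, for example by passing the tuned values back into a fresh MDN instance (a sketch, assuming the returned dictionary's keys match the constructor arguments):

# Re-instantiate an imputer with the tuned hyperparameters
mdn_best = MDN(**best_parameters)
fitted_best = mdn_best.fit(X_train, predictors, imputed_variables)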

Categorical variable imputation

MDN automatically handles categorical variables through neural classification. Let’s evaluate its performance on categorical imputation tasks.

# Create a dataset with categorical variables
np.random.seed(42)

df_categorical = pd.DataFrame()
df_categorical['age'] = df['age']
df_categorical['sex'] = df['sex']
df_categorical['bmi'] = df['bmi']
df_categorical['bp'] = df['bp']
df_categorical['risk_level'] = pd.qcut(
    df['s1'], q=3, labels=['low', 'medium', 'high']
).astype(str)

print("Categorical variable distribution:")
print(df_categorical['risk_level'].value_counts())

# Split the data
X_train_cat, X_test_cat = preprocess_data(df_categorical)

print(f"\nTraining set size: {X_train_cat.shape[0]} records")
print(f"Testing set size: {X_test_cat.shape[0]} records")
Categorical variable distribution:
risk_level
low       148
high      148
medium    146
Name: count, dtype: int64

Training set size: 353 records
Testing set size: 89 records
# Fit MDN model for categorical imputation
predictors_cat = ["age", "sex", "bmi", "bp"]
imputed_variables_cat = ["risk_level"]

mdn_cat_imputer = MDN(layers="64-32", max_epochs=50)
fitted_mdn_cat = mdn_cat_imputer.fit(X_train_cat, predictors_cat, imputed_variables_cat)

print("MDN model fitted for categorical variable imputation")

# Create test set with missing categorical values
X_test_cat_missing = X_test_cat.copy()
actual_cat_values = X_test_cat_missing[imputed_variables_cat].copy()
X_test_cat_missing[imputed_variables_cat] = np.nan

# Impute the categorical values
imputed_cat_values = fitted_mdn_cat.predict(X_test_cat_missing, [0.5])
MDN model fitted for categorical variable imputation
# Evaluate categorical imputation accuracy
from sklearn.metrics import accuracy_score, confusion_matrix

predicted = imputed_cat_values[0.5]['risk_level'].values
actual = actual_cat_values['risk_level'].values

accuracy = accuracy_score(actual, predicted)
print(f"Categorical imputation accuracy: {accuracy:.2%}")

conf_matrix = pd.DataFrame(
    confusion_matrix(actual, predicted, labels=['low', 'medium', 'high']),
    index=['Actual: low', 'Actual: medium', 'Actual: high'],
    columns=['Predicted: low', 'Predicted: medium', 'Predicted: high']
)
print("\nConfusion matrix:")
print(conf_matrix)
Categorical imputation accuracy: 31.46%

Confusion matrix:
                Predicted: low  Predicted: medium  Predicted: high
Actual: low                 10                  8               12
Actual: medium              14                 11                4
Actual: high                 7                 16                7
# Run cross-validation for categorical variables
mdn_categorical_results = cross_validate_model(
    MDN, df_categorical, predictors_cat, imputed_variables_cat,
    model_hyperparams={"layers": "64-32", "max_epochs": 30}
)

print("Categorical imputation cross-validation results (log loss):")
print(f"Mean train log loss: {mdn_categorical_results['log_loss']['mean_train']:.4f}")
print(f"Mean test log loss: {mdn_categorical_results['log_loss']['mean_test']:.4f}")
[Parallel(n_jobs=-1)]: Using backend LokyBackend with 8 concurrent workers.
[Parallel(n_jobs=-1)]: Done   2 out of   5 | elapsed:    0.8s remaining:    1.2s
[Parallel(n_jobs=-1)]: Done   3 out of   5 | elapsed:    0.9s remaining:    0.6s
[Parallel(n_jobs=-1)]: Done   5 out of   5 | elapsed:    0.9s finished
Categorical imputation cross-validation results (log loss):
Mean train log loss: 0.9622
Mean test log loss: 1.0936
# Plot categorical performance
cat_perf_results_viz = model_performance_results(
    results=mdn_categorical_results,
    model_name="MDN",
    method_name="Cross-validation log loss average",
    metric="log_loss",
)
fig = cat_perf_results_viz.plot(
    title="MDN categorical imputation cross-validation performance",
)
fig.show()

Model caching

MDN automatically caches trained models based on a hash of the input data. When you fit the model on the same data again, it loads from cache rather than retraining, significantly speeding up repeated analyses. Use force_retrain=True to bypass the cache if needed.
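
For example (a sketch; the cache location and retraining behaviour follow the constructor parameters documented above):

# First fit trains the network and caches it under model_dir
mdn_cached = MDN(model_dir="./microimpute_models")
fitted_cached = mdn_cached.fit(X_train, predictors, imputed_variables)

# Refitting on identical data loads the cached model instead of retraining
fitted_again = mdn_cached.fit(X_train, predictors, imputed_variables)

# Bypass the cache and train from scratch
mdn_fresh = MDN(model_dir="./microimpute_models", force_retrain=True)
fitted_fresh = mdn_fresh.fit(X_train, predictors, imputed_variables)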