Mixture Density Network

The MDN model uses deep neural networks with mixture density outputs to predict missing values by learning complex, potentially multi-modal conditional distributions. Built on PyTorch Tabular, this approach combines the flexibility of neural networks with the probabilistic richness of Gaussian mixture models.
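For reference, the standard mixture density formulation models the conditional density of a numerical target $y$ given features $x$ as a weighted sum of $K$ Gaussian components, with all parameters produced by the network:

$$
p(y \mid x) = \sum_{k=1}^{K} \pi_k(x)\,\mathcal{N}\!\bigl(y;\ \mu_k(x),\ \sigma_k^2(x)\bigr),
\qquad \sum_{k=1}^{K} \pi_k(x) = 1,\quad \pi_k(x) \ge 0
$$

Here $\pi_k$ are the mixing coefficients, and $\mu_k$, $\sigma_k^2$ are the mean and variance of the $k$-th component.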

Variable type support

MDN automatically adapts to your target variable types. For numerical variables, it uses a mixture density network head that models the full conditional distribution as a mixture of Gaussians, enabling it to capture complex, multi-modal relationships. For categorical and boolean variables, it switches to a neural classifier with appropriate output layers. This automatic detection means you can pass mixed variable types without manual configuration.
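As a rough illustration, this head selection can be thought of as a dtype dispatch like the sketch below. The function name and checks are illustrative assumptions, not the library's actual code:

```python
# Minimal sketch of automatic head selection by dtype (illustrative only).
import pandas as pd
from pandas.api.types import is_bool_dtype, is_numeric_dtype

def choose_head(target: pd.Series) -> str:
    # Check bool first: pandas treats bool columns as numeric.
    if is_bool_dtype(target) or not is_numeric_dtype(target):
        return "classifier"        # categorical / boolean -> softmax head
    return "mixture_density"       # numerical -> Gaussian mixture head
```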

How it works

The MDN imputer uses a shared backbone architecture (configurable dense layers with dropout and batch normalization options) that feeds into specialized output heads. For numerical targets, the mixture density head outputs parameters for a mixture of Gaussian distributions: mixing coefficients, means, and variances for each component. Predictions are generated by stochastically sampling from this learned distribution rather than returning point estimates.
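The following is a minimal PyTorch sketch of such a mixture density head, assuming a backbone that emits `hidden_dim` features. Class and parameter names are illustrative, not the library's API:

```python
# Sketch of a mixture density head: one linear layer per parameter group.
import torch
import torch.nn as nn

class MixtureDensityHead(nn.Module):
    def __init__(self, hidden_dim: int, n_components: int):
        super().__init__()
        self.pi = nn.Linear(hidden_dim, n_components)         # mixing logits
        self.mu = nn.Linear(hidden_dim, n_components)         # component means
        self.log_sigma = nn.Linear(hidden_dim, n_components)  # log std devs

    def forward(self, h: torch.Tensor):
        pi = torch.softmax(self.pi(h), dim=-1)                # mixing coefficients
        mu = self.mu(h)
        sigma = torch.exp(self.log_sigma(h)).clamp(min=1e-6)  # positive std devs
        return pi, mu, sigma

    def sample(self, h: torch.Tensor) -> torch.Tensor:
        # Stochastic imputation: pick a component, then draw from its Gaussian.
        pi, mu, sigma = self.forward(h)
        k = torch.multinomial(pi, num_samples=1)              # (batch, 1) indices
        chosen_mu = mu.gather(-1, k).squeeze(-1)
        chosen_sigma = sigma.gather(-1, k).squeeze(-1)
        return torch.normal(chosen_mu, chosen_sigma)
```

Sampling rather than returning the mixture mean is what preserves multi-modality: averaging the components would collapse distinct modes into an implausible in-between value.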

For categorical and boolean targets, the model uses a standard classification head with softmax outputs. Predictions are made by sampling from the predicted probability distribution, preserving the stochastic nature of imputation.
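In the same spirit, a hedged sketch of stochastic categorical imputation: sample the class label from the predicted probabilities instead of taking the argmax.

```python
# Sample a class index from softmax probabilities (illustrative sketch).
import torch

def sample_category(logits: torch.Tensor) -> torch.Tensor:
    probs = torch.softmax(logits, dim=-1)                 # (batch, n_classes)
    return torch.multinomial(probs, num_samples=1).squeeze(-1)
```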

The model supports automatic caching based on data hashes, avoiding redundant retraining when the same data is encountered again. Hyperparameter tuning via Optuna is available for optimizing the number of Gaussian components and learning rate.
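A hypothetical sketch of both conveniences is shown below, assuming pandas input. The hash-based cache key and the Optuna search space are assumptions about how such a system could be wired up, and `train_and_validate` is a placeholder, not a real function:

```python
# Illustrative sketch: data fingerprint for caching, plus an Optuna objective.
import hashlib

import optuna
import pandas as pd

def data_hash(df: pd.DataFrame) -> str:
    # Stable fingerprint of the training data, reusable as a cache key.
    row_hashes = pd.util.hash_pandas_object(df, index=True).values
    return hashlib.sha256(row_hashes.tobytes()).hexdigest()

def objective(trial: optuna.Trial) -> float:
    n_components = trial.suggest_int("num_gaussian", 2, 10)
    lr = trial.suggest_float("learning_rate", 1e-4, 1e-2, log=True)
    # Train the MDN with these hyperparameters and return validation loss.
    return train_and_validate(n_components, lr)  # hypothetical helper

study = optuna.create_study(direction="minimize")
# study.optimize(objective, n_trials=20)
```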

Key features

MDN offers several advantages for complex imputation tasks. The mixture density approach can model multi-modal distributions that simpler methods cannot capture, making it suitable for variables with complex conditional distributions. The neural network backbone can learn non-linear relationships without requiring explicit feature engineering.

Training leverages GPU acceleration when available and includes early stopping to prevent overfitting. The automatic model caching system speeds up repeated analyses on the same dataset.

Installation note

MDN requires the pytorch-tabular package, which is an optional dependency. Install it with:

```bash
pip install pytorch_tabular
```