The MDN model uses neural networks with mixture density outputs to predict missing values. Built on PyTorch Tabular, it learns conditional distributions as mixtures of Gaussians, which lets it capture multi-modal relationships.
Variable type support¶
MDN automatically adapts to your target variable types. For numerical variables, it uses a mixture density network head that models the full conditional distribution as a mixture of Gaussians, enabling it to capture complex, multi-modal relationships. For categorical and boolean variables, it switches to a neural classifier with appropriate output layers. This automatic detection means you can pass mixed variable types without manual configuration.
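The automatic routing can be pictured with plain pandas dtype checks. This is a minimal sketch of the idea, not the library's actual detection code; the column names and the `heads` mapping are illustrative:

```python
import pandas as pd

# Toy frame mixing numerical, boolean, and categorical columns with missing values.
df = pd.DataFrame({
    "income": [50000.0, None, 72000.0],
    "employed": [True, False, None],
    "region": pd.Series(["north", "south", None], dtype="category"),
})

# Hypothetical routing similar in spirit to MDN's automatic detection:
# numerical columns get a mixture density head, everything else a classifier head.
heads = {}
for col in df.columns:
    if pd.api.types.is_numeric_dtype(df[col]) and not pd.api.types.is_bool_dtype(df[col]):
        heads[col] = "mixture density"
    else:
        heads[col] = "classifier"
```

Here `income` would be routed to the mixture density head, while `employed` and `region` fall through to the classifier head.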
How it works¶
The MDN imputer uses a shared backbone architecture (configurable dense layers with dropout and batch normalization options) that feeds into specialized output heads. For numerical targets, the mixture density head outputs parameters for a mixture of Gaussian distributions: mixing coefficients, means, and variances for each component. Predictions are generated by stochastically sampling from this learned distribution rather than returning point estimates.
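The numerical path can be sketched in a few lines of PyTorch. `MDNHead` and `sample` below are illustrative names, not the library's API: the head maps backbone features to mixing coefficients, means, and standard deviations, and imputation draws a value from the resulting mixture rather than returning a point estimate:

```python
import torch
import torch.nn as nn

class MDNHead(nn.Module):
    """Illustrative mixture density head: backbone features -> mixture parameters."""
    def __init__(self, in_dim: int, n_components: int = 3):
        super().__init__()
        self.pi = nn.Linear(in_dim, n_components)         # mixing coefficients (logits)
        self.mu = nn.Linear(in_dim, n_components)         # component means
        self.log_sigma = nn.Linear(in_dim, n_components)  # log std devs, for positivity

    def forward(self, h):
        weights = torch.softmax(self.pi(h), dim=-1)  # sums to 1 per row
        means = self.mu(h)
        sigmas = torch.exp(self.log_sigma(h))
        return weights, means, sigmas

def sample(weights, means, sigmas):
    """Draw one value per row: pick a component, then sample its Gaussian."""
    comp = torch.multinomial(weights, num_samples=1)
    mu = means.gather(1, comp)
    sigma = sigmas.gather(1, comp)
    return (mu + sigma * torch.randn_like(mu)).squeeze(1)

backbone_out = torch.randn(4, 16)       # stand-in for shared-backbone features
head = MDNHead(16, n_components=3)
w, m, s = head(backbone_out)
samples = sample(w, m, s)               # one stochastic imputation per row
```

Repeated calls to `sample` yield different imputations for the same inputs, which is what lets downstream analyses reflect imputation uncertainty.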
For categorical and boolean targets, the model uses a standard classification head with softmax outputs. Predictions are made by sampling from the predicted probability distribution, preserving the stochastic nature of imputation.
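Sampling from the softmax output, rather than taking the most likely class, looks like this in a minimal NumPy sketch (the category labels and probabilities are invented for illustration):

```python
import numpy as np

categories = np.array(["low", "medium", "high"])
probs = np.array([0.7, 0.2, 0.1])   # illustrative softmax output for one row
rng = np.random.default_rng(0)

# Impute by sampling from the predicted distribution instead of argmax,
# so repeated imputations reproduce the model's uncertainty.
draws = categories[rng.choice(len(categories), size=10_000, p=probs)]
freq_low = (draws == "low").mean()  # approaches 0.7 over many draws
```

An argmax would always return `"low"` here; stochastic sampling instead returns `"medium"` or `"high"` in roughly 30% of imputations.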
The model supports automatic caching based on data hashes, avoiding redundant retraining when the same data is encountered again. Hyperparameter tuning via Optuna is available for optimizing the number of Gaussian components and learning rate.
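Hash-based caching amounts to keying trained models by a digest of the training data. The sketch below shows the pattern with hypothetical helpers (`data_hash`, `fit_or_reuse`); it is not the library's implementation:

```python
import hashlib
import pandas as pd

def data_hash(df: pd.DataFrame) -> str:
    """Content hash of a DataFrame (hypothetical stand-in for the real cache key)."""
    return hashlib.sha256(df.to_csv(index=True).encode()).hexdigest()

_cache: dict = {}

def fit_or_reuse(df: pd.DataFrame, train):
    """Train only when this exact data has not been seen before."""
    key = data_hash(df)
    if key not in _cache:
        _cache[key] = train(df)
    return _cache[key]

calls = []
train_fn = lambda d: calls.append(1) or "model"   # records each actual training run
model = fit_or_reuse(pd.DataFrame({"x": [1, 2]}), train_fn)
model2 = fit_or_reuse(pd.DataFrame({"x": [1, 2]}), train_fn)
# training ran once; the second call with identical data hit the cache
```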
Key features¶
MDN can model multi-modal distributions that simpler methods cannot capture, making it well suited to variables with complex conditional distributions. The neural network backbone learns non-linear relationships without explicit feature engineering.
Training leverages GPU acceleration when available and includes early stopping to prevent overfitting. The automatic model caching system speeds up repeated analyses on the same dataset.
Installation note¶
MDN requires the pytorch-tabular package, which is an optional dependency. Install it with:
```bash
pip install pytorch_tabular
```