Quantile Regression Forest (QRF) imputation#
This notebook demonstrates how to use MicroImpute’s QRF imputer to impute values using Quantile Regression Forests. QRF is a powerful machine learning technique that extends traditional random forests to predict the entire conditional distribution of a target variable.
The QRF model supports iterative imputation with a single object and workflow. Pass a list containing every variable you want to impute as imputed_variables, and the model will impute all of them without needing a separate fit and predict call for each one.
Setup and data preparation#
# Import necessary libraries
import pandas as pd
import numpy as np
import plotly.express as px
import plotly.graph_objects as go
from sklearn.datasets import load_diabetes
import warnings
# Suppress warnings so the rendered notebook output stays readable
warnings.filterwarnings("ignore")
# Widen the printed frame, cap the number of displayed columns, and keep
# each DataFrame row on a single line (set_option accepts pattern/value pairs)
pd.set_option(
    "display.width", 600,
    "display.max_columns", 10,
    "display.expand_frame_repr", False,
)
# Import MicroImpute tools
from microimpute.comparisons.data import preprocess_data
from microimpute.evaluations import *
from microimpute.models import QRF
from microimpute.config import QUANTILES
from microimpute.visualizations.plotting import model_performance_results
# Fetch scikit-learn's diabetes data and wrap its feature matrix in a
# DataFrame labelled with the dataset's feature names
diabetes = load_diabetes()
df = pd.DataFrame(data=diabetes["data"], columns=diabetes["feature_names"])
# Preview the leading rows
df.head()
age | sex | bmi | bp | s1 | s2 | s3 | s4 | s5 | s6 | |
---|---|---|---|---|---|---|---|---|---|---|
0 | 0.038076 | 0.050680 | 0.061696 | 0.021872 | -0.044223 | -0.034821 | -0.043401 | -0.002592 | 0.019907 | -0.017646 |
1 | -0.001882 | -0.044642 | -0.051474 | -0.026328 | -0.008449 | -0.019163 | 0.074412 | -0.039493 | -0.068332 | -0.092204 |
2 | 0.085299 | 0.050680 | 0.044451 | -0.005670 | -0.045599 | -0.034194 | -0.032356 | -0.002592 | 0.002861 | -0.025930 |
3 | -0.089063 | -0.044642 | -0.011595 | -0.036656 | 0.012191 | 0.024991 | -0.036038 | 0.034309 | 0.022688 | -0.009362 |
4 | 0.005383 | -0.044642 | -0.036385 | 0.021872 | 0.003935 | 0.015596 | 0.008142 | -0.002592 | -0.031988 | -0.046641 |
# Observed inputs, plus the two targets we will pretend are missing:
# 's1' (total serum cholesterol) and 's4' (total cholesterol/HDL ratio)
predictors = ["age", "sex", "bmi", "bp"]
imputed_variables = ["s1", "s4"]
# Restrict the data to exactly these columns
diabetes_df = df[[*predictors, *imputed_variables]]
# Summary statistics for the selected columns
diabetes_df.describe()
age | sex | bmi | bp | s1 | s4 | |
---|---|---|---|---|---|---|
count | 4.420000e+02 | 4.420000e+02 | 4.420000e+02 | 4.420000e+02 | 4.420000e+02 | 4.420000e+02 |
mean | -2.511817e-19 | 1.230790e-17 | -2.245564e-16 | -4.797570e-17 | -1.381499e-17 | -9.042540e-18 |
std | 4.761905e-02 | 4.761905e-02 | 4.761905e-02 | 4.761905e-02 | 4.761905e-02 | 4.761905e-02 |
min | -1.072256e-01 | -4.464164e-02 | -9.027530e-02 | -1.123988e-01 | -1.267807e-01 | -7.639450e-02 |
25% | -3.729927e-02 | -4.464164e-02 | -3.422907e-02 | -3.665608e-02 | -3.424784e-02 | -3.949338e-02 |
50% | 5.383060e-03 | -4.464164e-02 | -7.283766e-03 | -5.670422e-03 | -4.320866e-03 | -2.592262e-03 |
75% | 3.807591e-02 | 5.068012e-02 | 3.124802e-02 | 3.564379e-02 | 2.835801e-02 | 3.430886e-02 |
max | 1.107267e-01 | 5.068012e-02 | 1.705552e-01 | 1.320436e-01 | 1.539137e-01 | 1.852344e-01 |
# Split into train/test sets; preprocess_data also returns a mapping from any
# column it converted to dummy variables to the new dummy column names
X_train, X_test, dummy_info = preprocess_data(diabetes_df)
# Replace each converted column name with its dummy columns in whichever
# variable list (predictors or imputed_variables) referred to it
for original_col, encoded_cols in dummy_info["column_mapping"].items():
    if original_col in predictors:
        target_list = predictors
    elif original_col in imputed_variables:
        target_list = imputed_variables
    else:
        continue
    target_list.remove(original_col)
    target_list.extend(encoded_cols)
# Report how many records landed in each split
n_train, n_test = len(X_train), len(X_test)
print(f"Training set size: {n_train} records")
print(f"Testing set size: {n_test} records")
Found 1 numeric columns with unique values < 10, treating as categorical: ['sex']. Converting to dummy variables.
Training set size: 353 records
Testing set size: 89 records
Simulating missing data#
For this example, we’ll simulate missing data in our test set by removing the values we want to impute.
# Keep the untouched ground truth for the target columns before masking
actual_values = X_test[imputed_variables].copy()
# Work on a copy so X_test itself keeps its observed values
X_test_missing = X_test.copy()
# Blank out every target column to simulate missing data
X_test_missing[imputed_variables] = np.nan
X_test_missing.head()
age | bmi | bp | s1 | s4 | sex_0.05068011873981862 | |
---|---|---|---|---|---|---|
287 | 0.045341 | -0.006206 | -0.015999 | NaN | NaN | 0.0 |
211 | 0.092564 | 0.036907 | 0.021872 | NaN | NaN | 0.0 |
72 | 0.063504 | -0.004050 | -0.012556 | NaN | NaN | 1.0 |
321 | 0.096197 | 0.051996 | 0.079265 | NaN | NaN | 0.0 |
73 | 0.012648 | -0.020218 | -0.002228 | NaN | NaN | 1.0 |
Training and using the QRF imputer#
Now we’ll train the QRF imputer and use it to impute the missing values in our test set.
# Use the package's default quantile grid from microimpute.config
print("Modeling these quantiles: " + str(QUANTILES))
Modeling these quantiles: [0.05, 0.1, 0.15, 0.2, 0.25, 0.3, 0.35, 0.4, 0.45, 0.5, 0.55, 0.6, 0.65, 0.7, 0.75, 0.8, 0.85, 0.9, 0.95]
# Create a QRF imputer. The constructor is called with no arguments here;
# random-forest hyperparameters (n_estimators, min_samples_leaf, ...) are
# supplied to fit() below instead.
qrf_imputer = QRF()
# Train on the training split. The fitted object exposes the per-variable
# models through .models[<variable name>].
fitted_qrf_imputer = qrf_imputer.fit(
    X_train,
    predictors,
    imputed_variables,
    n_estimators=100,
    min_samples_leaf=5,
)
# Predict every quantile in QUANTILES for the masked test rows; the result is
# keyed by quantile, each entry holding a DataFrame of imputed values
imputed_values = fitted_qrf_imputer.predict(X_test_missing, QUANTILES)
# Median (0.5 quantile) imputations for the first few records
imputed_values[0.5].head()
s1 | s4 | |
---|---|---|
0 | -0.015328 | -0.039493 |
1 | 0.039710 | -0.002592 |
2 | 0.069981 | 0.034309 |
3 | 0.046589 | 0.034309 |
4 | 0.031454 | 0.034309 |
Evaluating the imputation results#
Now let’s compare the imputed values with the actual values to evaluate the performance of our imputer. To understand QRF’s ability to capture variability across quantiles, let us find and plot, for each data point, the quantile prediction closest to the true value.
# Quantile levels that were predicted, in the order predict() returned them
quantiles = list(imputed_values.keys())
# Stack the predictions into a 2D array of shape (n_values, n_quantiles).
# Each quantile's frame is first reordered to actual_values' column order, so
# that after row-major flattening, row i refers to the same (record, variable)
# pair as entry i of the flattened ground truth below.
pred_matrix = np.stack(
    [
        imputed_values[q][actual_values.columns].to_numpy().ravel()
        for q in quantiles
    ],
    axis=1,
)
# Ground truth, flattened the same row-major way
actual = actual_values.to_numpy().ravel()
# Absolute error of every quantile prediction: shape (n_values, n_quantiles)
abs_error = np.abs(pred_matrix - actual[:, None])
# Column index of the quantile prediction closest to each true value
closest_indices = abs_error.argmin(axis=1)
# Gather, for every row, the prediction at its closest quantile
closest_predictions = pred_matrix[np.arange(pred_matrix.shape[0]), closest_indices]
# Wrap as DataFrame for plotting
closest_df = pd.DataFrame(
    {
        "Actual": actual,
        "ClosestPrediction": closest_predictions,
    }
)
# Span the reference line over the values actually plotted (actual values and
# closest predictions) so that no point falls outside it
min_val = closest_df.min().min()
max_val = closest_df.max().max()
# Scatter of actual values vs. the closest quantile prediction
fig = px.scatter(
    closest_df,
    x="Actual",
    y="ClosestPrediction",
    opacity=0.7,
    title="Comparison of Actual vs. Imputed Values using QRF",
)
# Add the diagonal line (perfect prediction line)
fig.add_trace(
    go.Scatter(
        x=[min_val, max_val],
        y=[min_val, max_val],
        mode="lines",
        line=dict(color="red", dash="dash"),
        name="Perfect Prediction",
    )
)
# Label the y-axis for what it really shows: the best-matching quantile,
# not a single point estimate
fig.update_layout(
    xaxis_title="Actual Values",
    yaxis_title="Imputed Values (closest quantile)",
    width=750,
    height=600,
    template="plotly_white",
    margin=dict(l=50, r=50, t=80, b=50),  # Adjust margins
)
fig.show()
This scatter plot compares actual observed values with those imputed by a Quantile Regression Forest (QRF) model, providing a visual assessment of imputation accuracy. Each point represents one imputed value (a record’s s1 or s4 value). The x-axis shows the true value, and the y-axis shows the model’s prediction at whichever quantile came closest to that true value. Because each point uses the best of the predicted quantiles rather than a single point estimate, this view is optimistic: it shows how well QRF’s predicted distribution covers the true values rather than the accuracy of any single imputation. The red dashed line represents the ideal 1:1 relationship, where predictions perfectly match actual values. Most points cluster around this line, suggesting that the QRF model effectively captures the underlying structure of the data. Importantly, the model does not appear to systematically over- or under-predict across the range, and while performance at the extremes may be weaker, the overall pattern indicates that QRF provides a reasonably accurate and unbiased approach to imputing missing values. It is also worth considering the characteristics of the diabetes dataset, which appears to show a strong linear relationship between the predictors and the imputed variables. Because QRF does not rely on linearity assumptions, it is likely to be especially useful on datasets where such assumptions do not hold.
Examining quantile predictions#
QRF provides predictions at different quantiles, allowing us to capture the entire conditional distribution of the missing values.
# Compare the actual value with every quantile prediction for the first five
# test records. Column 0 of both actual_values and each prediction frame is
# the first imputed variable (s1).
quantiles_to_show = QUANTILES
comparison_df = pd.DataFrame(index=range(5))
# Add actual values
comparison_df["Actual"] = actual_values.iloc[:5, 0].to_numpy()
# Add one column per quantile. round() rather than int(): q * 100 can land
# just below an integer in floating point (e.g. 0.57 * 100 ==
# 56.99999999999999), which int() would truncate to the wrong label.
for q in quantiles_to_show:
    comparison_df[f"Q{round(q * 100)}"] = imputed_values[q].iloc[:5, 0].to_numpy()
comparison_df
Actual | Q5 | Q10 | Q15 | Q20 | ... | Q75 | Q80 | Q85 | Q90 | Q95 | |
---|---|---|---|---|---|---|---|---|---|---|---|
0 | 0.125019 | -0.046975 | -0.041472 | -0.005697 | 0.008063 | ... | 0.020446 | 0.030078 | 0.042462 | 0.063101 | 0.063101 |
1 | -0.024960 | -0.027712 | 0.008063 | -0.027712 | -0.027712 | ... | 0.043837 | 0.043837 | 0.052093 | 0.052093 | 0.063101 |
2 | 0.103003 | -0.041472 | -0.041472 | -0.008449 | -0.041472 | ... | 0.087868 | 0.087868 | 0.087868 | 0.087868 | 0.087868 |
3 | 0.054845 | -0.018080 | -0.018080 | -0.018080 | -0.018080 | ... | 0.034206 | 0.035582 | 0.046589 | 0.057597 | 0.060349 |
4 | 0.038334 | -0.046975 | -0.046975 | -0.046975 | -0.037344 | ... | 0.001183 | 0.003935 | 0.013567 | 0.013567 | 0.031454 |
5 rows × 20 columns
Visualizing prediction intervals#
By visualizing the prediction intervals of the model’s imputations we can better understand the uncertainty in our imputed values.
# Plot prediction intervals for the first imputed variable (column 0, s1)
# over the first n_records test records
n_records = 10
records = list(range(n_records))
actuals = actual_values.iloc[:n_records, 0].values
medians = imputed_values[0.5].iloc[:n_records, 0].values
q30 = imputed_values[0.3].iloc[:n_records, 0].values
q70 = imputed_values[0.7].iloc[:n_records, 0].values
q10 = imputed_values[0.1].iloc[:n_records, 0].values
q90 = imputed_values[0.9].iloc[:n_records, 0].values
# (lower, upper, colour, legend label) for each band. The wider band comes
# first so the narrower one is drawn on top of it.
interval_bands = [
    (q10, q90, "rgba(173, 216, 230, 0.3)", "80% PI (Q10-Q90)"),
    (q30, q70, "rgba(70, 130, 180, 0.5)", "40% PI (Q30-Q70)"),
]
fig = go.Figure()
# One thick vertical segment per record and band; hidden from the legend,
# which gets a single entry per band further down
for lower, upper, color, _label in interval_bands:
    for i in records:
        fig.add_trace(
            go.Scatter(
                x=[i, i],
                y=[lower[i], upper[i]],
                mode="lines",
                line=dict(width=10, color=color),
                hoverinfo="none",
                showlegend=False,
            )
        )
# Add actual values
fig.add_trace(
    go.Scatter(
        x=records,
        y=actuals,
        mode="markers",
        marker=dict(color="black", size=8),
        name="Actual",
    )
)
# Add median predictions
fig.add_trace(
    go.Scatter(
        x=records,
        y=medians,
        mode="markers",
        marker=dict(color="red", size=8),
        name="Median (Q50)",
    )
)
# Legend-only entries for the bands. The [None] coordinates draw nothing and
# do not affect autoscaling (a real dummy point such as (-1, 0) would stretch
# the x-axis to -1 and force the y-axis to include 0).
for _lower, _upper, color, label in interval_bands:
    fig.add_trace(
        go.Scatter(
            x=[None],
            y=[None],
            mode="lines",
            line=dict(color=color, width=10),
            name=label,
        )
    )
# Update layout with smaller width to fit in the book layout
fig.update_layout(
    title="QRF Imputation Prediction Intervals",
    xaxis=dict(
        title="Data Record Index",
        showgrid=True,
        gridwidth=1,
        gridcolor="rgba(211, 211, 211, 0.7)",
    ),
    yaxis=dict(
        title="Total Serum Cholesterol (s1)",
        showgrid=True,
        gridwidth=1,
        gridcolor="rgba(211, 211, 211, 0.7)",
    ),
    width=750,
    height=600,
    template="plotly_white",
    margin=dict(l=50, r=50, t=80, b=50),  # Adjust margins
    legend=dict(yanchor="top", y=0.99, xanchor="right", x=0.99),
)
fig.show()
This plot visualizes the prediction intervals produced by the Quantile Regression Forest (QRF) model for imputing total serum cholesterol values across ten data records. Each vertical bar represents an 80% (light blue) or 40% (dark blue) prediction interval, capturing the model’s estimated range of plausible values based on the Q10–Q90 and Q30–Q70 quantiles, respectively. Red dots mark the model’s median predictions (Q50), while black dots show the actual observed values. In most cases, the true values fall within the wider intervals, indicating that the QRF model is appropriately capturing uncertainty in its imputation. The fact that the intervals are sometimes asymmetrical around the median reflects the model’s flexibility in estimating skewed or heteroskedastic distributions. Overall, the plot demonstrates that the QRF model not only provides accurate point estimates but also yields informative prediction intervals that account for uncertainty in the imputed values.
Assessing the method’s performance#
To check whether our model is overfitting and ensure robust results we can perform cross-validation and visualize the results.
# The dummy-column remapping earlier replaced "sex" in `predictors`, so reset
# both lists to the original column names found in diabetes_df
predictors = ["age", "sex", "bmi", "bp"]
imputed_variables = ["s1", "s4"]
# Cross-validate QRF on the un-preprocessed subset built earlier
qrf_results = cross_validate_model(QRF, diabetes_df, predictors, imputed_variables)
qrf_results
[Parallel(n_jobs=-1)]: Using backend LokyBackend with 4 concurrent workers.
/opt/hostedtoolcache/Python/3.11.12/x64/lib/python3.11/site-packages/pydantic/_internal/_generate_schema.py:623: UserWarning: <built-in function array> is not a Python type (it may be an instance of an object), Pydantic will allow any object with no validation since we cannot even enforce that the input is an instance of the given type. To get rid of this error wrap the type with `pydantic.SkipValidation`.
warn(
/opt/hostedtoolcache/Python/3.11.12/x64/lib/python3.11/site-packages/pydantic/_internal/_generate_schema.py:623: UserWarning: <built-in function array> is not a Python type (it may be an instance of an object), Pydantic will allow any object with no validation since we cannot even enforce that the input is an instance of the given type. To get rid of this error wrap the type with `pydantic.SkipValidation`.
warn(
/opt/hostedtoolcache/Python/3.11.12/x64/lib/python3.11/site-packages/pydantic/_internal/_generate_schema.py:623: UserWarning: <built-in function array> is not a Python type (it may be an instance of an object), Pydantic will allow any object with no validation since we cannot even enforce that the input is an instance of the given type. To get rid of this error wrap the type with `pydantic.SkipValidation`.
warn(
/opt/hostedtoolcache/Python/3.11.12/x64/lib/python3.11/site-packages/pydantic/_internal/_generate_schema.py:623: UserWarning: <built-in function array> is not a Python type (it may be an instance of an object), Pydantic will allow any object with no validation since we cannot even enforce that the input is an instance of the given type. To get rid of this error wrap the type with `pydantic.SkipValidation`.
warn(
/opt/hostedtoolcache/Python/3.11.12/x64/lib/python3.11/site-packages/rpy2/rinterface/__init__.py:1211: UserWarning: Environment variable "LD_LIBRARY_PATH" redefined by R and overriding existing variable. Current: "/usr/lib/R/lib:/usr/lib/x86_64-linux-gnu:/usr/lib/jvm/temurin-17-jdk-amd64/lib/server:/usr/lib/R/lib:/usr/lib/x86_64-linux-gnu:/usr/lib/jvm/temurin-17-jdk-amd64/lib/server:/opt/hostedtoolcache/Python/3.11.12/x64/lib", R: "/usr/lib/R/lib:/usr/lib/x86_64-linux-gnu:/usr/lib/jvm/temurin-17-jdk-amd64/lib/server:/usr/lib/R/lib:/usr/lib/x86_64-linux-gnu:/usr/lib/jvm/temurin-17-jdk-amd64/lib/server:/usr/lib/R/lib:/usr/lib/x86_64-linux-gnu:/usr/lib/jvm/temurin-17-jdk-amd64/lib/server:/opt/hostedtoolcache/Python/3.11.12/x64/lib"
warnings.warn(
/opt/hostedtoolcache/Python/3.11.12/x64/lib/python3.11/site-packages/rpy2/rinterface/__init__.py:1211: UserWarning: Environment variable "R_LIBS_SITE" redefined by R and overriding existing variable. Current: "/usr/local/lib/R/site-library/:/usr/local/lib/R/site-library/:/usr/local/lib/R/site-library/:/usr/local/lib/R/site-library:/usr/lib/R/site-library:/usr/lib/R/library:/usr/lib/R/library:/usr/lib/R/library", R: "/usr/local/lib/R/site-library/:/usr/local/lib/R/site-library/:/usr/local/lib/R/site-library/:/usr/local/lib/R/site-library/:/usr/local/lib/R/site-library:/usr/lib/R/site-library:/usr/lib/R/library:/usr/lib/R/library:/usr/lib/R/library:/usr/lib/R/library"
warnings.warn(
/opt/hostedtoolcache/Python/3.11.12/x64/lib/python3.11/site-packages/rpy2/rinterface/__init__.py:1211: UserWarning: Environment variable "R_SESSION_TMPDIR" redefined by R and overriding existing variable. Current: "/tmp/Rtmp9oKGW8", R: "/tmp/RtmpHrOFEg"
warnings.warn(
/opt/hostedtoolcache/Python/3.11.12/x64/lib/python3.11/site-packages/rpy2/rinterface/__init__.py:1211: UserWarning: Environment variable "LD_LIBRARY_PATH" redefined by R and overriding existing variable. Current: "/usr/lib/R/lib:/usr/lib/x86_64-linux-gnu:/usr/lib/jvm/temurin-17-jdk-amd64/lib/server:/usr/lib/R/lib:/usr/lib/x86_64-linux-gnu:/usr/lib/jvm/temurin-17-jdk-amd64/lib/server:/opt/hostedtoolcache/Python/3.11.12/x64/lib", R: "/usr/lib/R/lib:/usr/lib/x86_64-linux-gnu:/usr/lib/jvm/temurin-17-jdk-amd64/lib/server:/usr/lib/R/lib:/usr/lib/x86_64-linux-gnu:/usr/lib/jvm/temurin-17-jdk-amd64/lib/server:/usr/lib/R/lib:/usr/lib/x86_64-linux-gnu:/usr/lib/jvm/temurin-17-jdk-amd64/lib/server:/opt/hostedtoolcache/Python/3.11.12/x64/lib"
warnings.warn(
/opt/hostedtoolcache/Python/3.11.12/x64/lib/python3.11/site-packages/rpy2/rinterface/__init__.py:1211: UserWarning: Environment variable "R_LIBS_SITE" redefined by R and overriding existing variable. Current: "/usr/local/lib/R/site-library/:/usr/local/lib/R/site-library/:/usr/local/lib/R/site-library/:/usr/local/lib/R/site-library:/usr/lib/R/site-library:/usr/lib/R/library:/usr/lib/R/library:/usr/lib/R/library", R: "/usr/local/lib/R/site-library/:/usr/local/lib/R/site-library/:/usr/local/lib/R/site-library/:/usr/local/lib/R/site-library/:/usr/local/lib/R/site-library:/usr/lib/R/site-library:/usr/lib/R/library:/usr/lib/R/library:/usr/lib/R/library:/usr/lib/R/library"
warnings.warn(
/opt/hostedtoolcache/Python/3.11.12/x64/lib/python3.11/site-packages/rpy2/rinterface/__init__.py:1211: UserWarning: Environment variable "R_SESSION_TMPDIR" redefined by R and overriding existing variable. Current: "/tmp/Rtmp9oKGW8", R: "/tmp/RtmpI8lQcp"
warnings.warn(
/opt/hostedtoolcache/Python/3.11.12/x64/lib/python3.11/site-packages/rpy2/rinterface/__init__.py:1211: UserWarning: Environment variable "LD_LIBRARY_PATH" redefined by R and overriding existing variable. Current: "/usr/lib/R/lib:/usr/lib/x86_64-linux-gnu:/usr/lib/jvm/temurin-17-jdk-amd64/lib/server:/usr/lib/R/lib:/usr/lib/x86_64-linux-gnu:/usr/lib/jvm/temurin-17-jdk-amd64/lib/server:/opt/hostedtoolcache/Python/3.11.12/x64/lib", R: "/usr/lib/R/lib:/usr/lib/x86_64-linux-gnu:/usr/lib/jvm/temurin-17-jdk-amd64/lib/server:/usr/lib/R/lib:/usr/lib/x86_64-linux-gnu:/usr/lib/jvm/temurin-17-jdk-amd64/lib/server:/usr/lib/R/lib:/usr/lib/x86_64-linux-gnu:/usr/lib/jvm/temurin-17-jdk-amd64/lib/server:/opt/hostedtoolcache/Python/3.11.12/x64/lib"
warnings.warn(
/opt/hostedtoolcache/Python/3.11.12/x64/lib/python3.11/site-packages/rpy2/rinterface/__init__.py:1211: UserWarning: Environment variable "R_LIBS_SITE" redefined by R and overriding existing variable. Current: "/usr/local/lib/R/site-library/:/usr/local/lib/R/site-library/:/usr/local/lib/R/site-library/:/usr/local/lib/R/site-library:/usr/lib/R/site-library:/usr/lib/R/library:/usr/lib/R/library:/usr/lib/R/library", R: "/usr/local/lib/R/site-library/:/usr/local/lib/R/site-library/:/usr/local/lib/R/site-library/:/usr/local/lib/R/site-library/:/usr/local/lib/R/site-library:/usr/lib/R/site-library:/usr/lib/R/library:/usr/lib/R/library:/usr/lib/R/library:/usr/lib/R/library"
warnings.warn(
/opt/hostedtoolcache/Python/3.11.12/x64/lib/python3.11/site-packages/rpy2/rinterface/__init__.py:1211: UserWarning: Environment variable "R_SESSION_TMPDIR" redefined by R and overriding existing variable. Current: "/tmp/Rtmp9oKGW8", R: "/tmp/RtmpKHAp98"
warnings.warn(
/opt/hostedtoolcache/Python/3.11.12/x64/lib/python3.11/site-packages/rpy2/rinterface/__init__.py:1211: UserWarning: Environment variable "LD_LIBRARY_PATH" redefined by R and overriding existing variable. Current: "/usr/lib/R/lib:/usr/lib/x86_64-linux-gnu:/usr/lib/jvm/temurin-17-jdk-amd64/lib/server:/usr/lib/R/lib:/usr/lib/x86_64-linux-gnu:/usr/lib/jvm/temurin-17-jdk-amd64/lib/server:/opt/hostedtoolcache/Python/3.11.12/x64/lib", R: "/usr/lib/R/lib:/usr/lib/x86_64-linux-gnu:/usr/lib/jvm/temurin-17-jdk-amd64/lib/server:/usr/lib/R/lib:/usr/lib/x86_64-linux-gnu:/usr/lib/jvm/temurin-17-jdk-amd64/lib/server:/usr/lib/R/lib:/usr/lib/x86_64-linux-gnu:/usr/lib/jvm/temurin-17-jdk-amd64/lib/server:/opt/hostedtoolcache/Python/3.11.12/x64/lib"
warnings.warn(
/opt/hostedtoolcache/Python/3.11.12/x64/lib/python3.11/site-packages/rpy2/rinterface/__init__.py:1211: UserWarning: Environment variable "R_LIBS_SITE" redefined by R and overriding existing variable. Current: "/usr/local/lib/R/site-library/:/usr/local/lib/R/site-library/:/usr/local/lib/R/site-library/:/usr/local/lib/R/site-library:/usr/lib/R/site-library:/usr/lib/R/library:/usr/lib/R/library:/usr/lib/R/library", R: "/usr/local/lib/R/site-library/:/usr/local/lib/R/site-library/:/usr/local/lib/R/site-library/:/usr/local/lib/R/site-library/:/usr/local/lib/R/site-library:/usr/lib/R/site-library:/usr/lib/R/library:/usr/lib/R/library:/usr/lib/R/library:/usr/lib/R/library"
warnings.warn(
/opt/hostedtoolcache/Python/3.11.12/x64/lib/python3.11/site-packages/rpy2/rinterface/__init__.py:1211: UserWarning: Environment variable "R_SESSION_TMPDIR" redefined by R and overriding existing variable. Current: "/tmp/Rtmp9oKGW8", R: "/tmp/RtmpKXTN60"
warnings.warn(
/opt/hostedtoolcache/Python/3.11.12/x64/lib/python3.11/site-packages/rpy2/rinterface/__init__.py:1211: UserWarning: Environment variable "LD_LIBRARY_PATH" redefined by R and overriding existing variable. Current: "/usr/lib/R/lib:/usr/lib/x86_64-linux-gnu:/usr/lib/jvm/temurin-17-jdk-amd64/lib/server:/usr/lib/R/lib:/usr/lib/x86_64-linux-gnu:/usr/lib/jvm/temurin-17-jdk-amd64/lib/server:/usr/lib/R/lib:/usr/lib/x86_64-linux-gnu:/usr/lib/jvm/temurin-17-jdk-amd64/lib/server:/opt/hostedtoolcache/Python/3.11.12/x64/lib", R: "/usr/lib/R/lib:/usr/lib/x86_64-linux-gnu:/usr/lib/jvm/temurin-17-jdk-amd64/lib/server:/usr/lib/R/lib:/usr/lib/x86_64-linux-gnu:/usr/lib/jvm/temurin-17-jdk-amd64/lib/server:/usr/lib/R/lib:/usr/lib/x86_64-linux-gnu:/usr/lib/jvm/temurin-17-jdk-amd64/lib/server:/usr/lib/R/lib:/usr/lib/x86_64-linux-gnu:/usr/lib/jvm/temurin-17-jdk-amd64/lib/server:/opt/hostedtoolcache/Python/3.11.12/x64/lib"
warnings.warn(
/opt/hostedtoolcache/Python/3.11.12/x64/lib/python3.11/site-packages/rpy2/rinterface/__init__.py:1211: UserWarning: Environment variable "LD_LIBRARY_PATH" redefined by R and overriding existing variable. Current: "/usr/lib/R/lib:/usr/lib/x86_64-linux-gnu:/usr/lib/jvm/temurin-17-jdk-amd64/lib/server:/usr/lib/R/lib:/usr/lib/x86_64-linux-gnu:/usr/lib/jvm/temurin-17-jdk-amd64/lib/server:/usr/lib/R/lib:/usr/lib/x86_64-linux-gnu:/usr/lib/jvm/temurin-17-jdk-amd64/lib/server:/opt/hostedtoolcache/Python/3.11.12/x64/lib", R: "/usr/lib/R/lib:/usr/lib/x86_64-linux-gnu:/usr/lib/jvm/temurin-17-jdk-amd64/lib/server:/usr/lib/R/lib:/usr/lib/x86_64-linux-gnu:/usr/lib/jvm/temurin-17-jdk-amd64/lib/server:/usr/lib/R/lib:/usr/lib/x86_64-linux-gnu:/usr/lib/jvm/temurin-17-jdk-amd64/lib/server:/usr/lib/R/lib:/usr/lib/x86_64-linux-gnu:/usr/lib/jvm/temurin-17-jdk-amd64/lib/server:/opt/hostedtoolcache/Python/3.11.12/x64/lib"
warnings.warn(
/opt/hostedtoolcache/Python/3.11.12/x64/lib/python3.11/site-packages/rpy2/rinterface/__init__.py:1211: UserWarning: Environment variable "LD_LIBRARY_PATH" redefined by R and overriding existing variable. Current: "/usr/lib/R/lib:/usr/lib/x86_64-linux-gnu:/usr/lib/jvm/temurin-17-jdk-amd64/lib/server:/usr/lib/R/lib:/usr/lib/x86_64-linux-gnu:/usr/lib/jvm/temurin-17-jdk-amd64/lib/server:/usr/lib/R/lib:/usr/lib/x86_64-linux-gnu:/usr/lib/jvm/temurin-17-jdk-amd64/lib/server:/opt/hostedtoolcache/Python/3.11.12/x64/lib", R: "/usr/lib/R/lib:/usr/lib/x86_64-linux-gnu:/usr/lib/jvm/temurin-17-jdk-amd64/lib/server:/usr/lib/R/lib:/usr/lib/x86_64-linux-gnu:/usr/lib/jvm/temurin-17-jdk-amd64/lib/server:/usr/lib/R/lib:/usr/lib/x86_64-linux-gnu:/usr/lib/jvm/temurin-17-jdk-amd64/lib/server:/usr/lib/R/lib:/usr/lib/x86_64-linux-gnu:/usr/lib/jvm/temurin-17-jdk-amd64/lib/server:/opt/hostedtoolcache/Python/3.11.12/x64/lib"
warnings.warn(
/opt/hostedtoolcache/Python/3.11.12/x64/lib/python3.11/site-packages/rpy2/rinterface/__init__.py:1211: UserWarning: Environment variable "R_LIBS_SITE" redefined by R and overriding existing variable. Current: "/usr/local/lib/R/site-library/:/usr/local/lib/R/site-library/:/usr/local/lib/R/site-library/:/usr/local/lib/R/site-library/:/usr/local/lib/R/site-library:/usr/lib/R/site-library:/usr/lib/R/library:/usr/lib/R/library:/usr/lib/R/library:/usr/lib/R/library", R: "/usr/local/lib/R/site-library/:/usr/local/lib/R/site-library/:/usr/local/lib/R/site-library/:/usr/local/lib/R/site-library/:/usr/local/lib/R/site-library/:/usr/local/lib/R/site-library/:/usr/local/lib/R/site-library:/usr/lib/R/site-library:/usr/lib/R/library:/usr/lib/R/library:/usr/lib/R/library:/usr/lib/R/library:/usr/lib/R/library:/usr/lib/R/library"
warnings.warn(
/opt/hostedtoolcache/Python/3.11.12/x64/lib/python3.11/site-packages/rpy2/rinterface/__init__.py:1211: UserWarning: Environment variable "R_SESSION_TMPDIR" redefined by R and overriding existing variable. Current: "/tmp/RtmpHrOFEg", R: "/tmp/RtmpgxQt83"
warnings.warn(
/opt/hostedtoolcache/Python/3.11.12/x64/lib/python3.11/site-packages/rpy2/rinterface/__init__.py:1211: UserWarning: Environment variable "R_LIBS_SITE" redefined by R and overriding existing variable. Current: "/usr/local/lib/R/site-library/:/usr/local/lib/R/site-library/:/usr/local/lib/R/site-library/:/usr/local/lib/R/site-library/:/usr/local/lib/R/site-library:/usr/lib/R/site-library:/usr/lib/R/library:/usr/lib/R/library:/usr/lib/R/library:/usr/lib/R/library", R: "/usr/local/lib/R/site-library/:/usr/local/lib/R/site-library/:/usr/local/lib/R/site-library/:/usr/local/lib/R/site-library/:/usr/local/lib/R/site-library/:/usr/local/lib/R/site-library/:/usr/local/lib/R/site-library:/usr/lib/R/site-library:/usr/lib/R/library:/usr/lib/R/library:/usr/lib/R/library:/usr/lib/R/library:/usr/lib/R/library:/usr/lib/R/library"
warnings.warn(
/opt/hostedtoolcache/Python/3.11.12/x64/lib/python3.11/site-packages/rpy2/rinterface/__init__.py:1211: UserWarning: Environment variable "R_SESSION_TMPDIR" redefined by R and overriding existing variable. Current: "/tmp/RtmpI8lQcp", R: "/tmp/RtmprmoIVp"
warnings.warn(
/opt/hostedtoolcache/Python/3.11.12/x64/lib/python3.11/site-packages/rpy2/rinterface/__init__.py:1211: UserWarning: Environment variable "R_LIBS_SITE" redefined by R and overriding existing variable. Current: "/usr/local/lib/R/site-library/:/usr/local/lib/R/site-library/:/usr/local/lib/R/site-library/:/usr/local/lib/R/site-library/:/usr/local/lib/R/site-library:/usr/lib/R/site-library:/usr/lib/R/library:/usr/lib/R/library:/usr/lib/R/library:/usr/lib/R/library", R: "/usr/local/lib/R/site-library/:/usr/local/lib/R/site-library/:/usr/local/lib/R/site-library/:/usr/local/lib/R/site-library/:/usr/local/lib/R/site-library/:/usr/local/lib/R/site-library/:/usr/local/lib/R/site-library:/usr/lib/R/site-library:/usr/lib/R/library:/usr/lib/R/library:/usr/lib/R/library:/usr/lib/R/library:/usr/lib/R/library:/usr/lib/R/library"
warnings.warn(
/opt/hostedtoolcache/Python/3.11.12/x64/lib/python3.11/site-packages/rpy2/rinterface/__init__.py:1211: UserWarning: Environment variable "R_SESSION_TMPDIR" redefined by R and overriding existing variable. Current: "/tmp/RtmpKHAp98", R: "/tmp/Rtmpc3H4k8"
warnings.warn(
/opt/hostedtoolcache/Python/3.11.12/x64/lib/python3.11/site-packages/rpy2/rinterface/__init__.py:1211: UserWarning: Environment variable "LD_LIBRARY_PATH" redefined by R and overriding existing variable. Current: "/usr/lib/R/lib:/usr/lib/x86_64-linux-gnu:/usr/lib/jvm/temurin-17-jdk-amd64/lib/server:/usr/lib/R/lib:/usr/lib/x86_64-linux-gnu:/usr/lib/jvm/temurin-17-jdk-amd64/lib/server:/usr/lib/R/lib:/usr/lib/x86_64-linux-gnu:/usr/lib/jvm/temurin-17-jdk-amd64/lib/server:/opt/hostedtoolcache/Python/3.11.12/x64/lib", R: "/usr/lib/R/lib:/usr/lib/x86_64-linux-gnu:/usr/lib/jvm/temurin-17-jdk-amd64/lib/server:/usr/lib/R/lib:/usr/lib/x86_64-linux-gnu:/usr/lib/jvm/temurin-17-jdk-amd64/lib/server:/usr/lib/R/lib:/usr/lib/x86_64-linux-gnu:/usr/lib/jvm/temurin-17-jdk-amd64/lib/server:/usr/lib/R/lib:/usr/lib/x86_64-linux-gnu:/usr/lib/jvm/temurin-17-jdk-amd64/lib/server:/opt/hostedtoolcache/Python/3.11.12/x64/lib"
warnings.warn(
/opt/hostedtoolcache/Python/3.11.12/x64/lib/python3.11/site-packages/rpy2/rinterface/__init__.py:1211: UserWarning: Environment variable "R_LIBS_SITE" redefined by R and overriding existing variable. Current: "/usr/local/lib/R/site-library/:/usr/local/lib/R/site-library/:/usr/local/lib/R/site-library/:/usr/local/lib/R/site-library/:/usr/local/lib/R/site-library:/usr/lib/R/site-library:/usr/lib/R/library:/usr/lib/R/library:/usr/lib/R/library:/usr/lib/R/library", R: "/usr/local/lib/R/site-library/:/usr/local/lib/R/site-library/:/usr/local/lib/R/site-library/:/usr/local/lib/R/site-library/:/usr/local/lib/R/site-library/:/usr/local/lib/R/site-library/:/usr/local/lib/R/site-library:/usr/lib/R/site-library:/usr/lib/R/library:/usr/lib/R/library:/usr/lib/R/library:/usr/lib/R/library:/usr/lib/R/library:/usr/lib/R/library"
warnings.warn(
/opt/hostedtoolcache/Python/3.11.12/x64/lib/python3.11/site-packages/rpy2/rinterface/__init__.py:1211: UserWarning: Environment variable "R_SESSION_TMPDIR" redefined by R and overriding existing variable. Current: "/tmp/RtmpKXTN60", R: "/tmp/RtmpxqkaIR"
warnings.warn(
[Parallel(n_jobs=-1)]: Done 2 out of 5 | elapsed: 4.6s remaining: 6.9s
[Parallel(n_jobs=-1)]: Done 3 out of 5 | elapsed: 4.6s remaining: 3.1s
[Parallel(n_jobs=-1)]: Done 5 out of 5 | elapsed: 5.7s finished
0.05 | 0.10 | 0.15 | 0.20 | 0.25 | ... | 0.75 | 0.80 | 0.85 | 0.90 | 0.95 | |
---|---|---|---|---|---|---|---|---|---|---|---|
train | 0.001632 | 0.003243 | 0.004381 | 0.005803 | 0.006759 | ... | 0.004199 | 0.003736 | 0.003177 | 0.002626 | 0.001612 |
test | 0.005358 | 0.008299 | 0.011374 | 0.013236 | 0.016153 | ... | 0.017515 | 0.015623 | 0.013717 | 0.010847 | 0.007159 |
2 rows × 19 columns
# Build the performance visualisation from the cross-validation losses,
# then render it
perf_results_viz = model_performance_results(
    results=qrf_results,
    model_name="QRF",
    method_name="Cross-Validation Quantile Loss Average",
)
fig = perf_results_viz.plot(title="QRF Cross-Validation Performance")
fig.show()
Tuning the QRF model#
The QRF imputer supports various parameters that can be adjusted to improve performance. If you already know hyperparameter values that work well for your dataset, you can set them explicitly, as shown below. Alternatively, automatic hyperparameter tuning for the target dataset is enabled by setting the parameter tune_hyperparameters to True.
# Show the hyperparameters of the forest fitted for the first imputed
# variable in the list
first_target = imputed_variables[0]
print(fitted_qrf_imputer.models[first_target].qrf.get_params())
{'bootstrap': True, 'ccp_alpha': 0.0, 'criterion': 'squared_error', 'default_quantiles': 0.5, 'max_depth': None, 'max_features': 1.0, 'max_leaf_nodes': None, 'max_samples': None, 'max_samples_leaf': 1, 'min_impurity_decrease': 0.0, 'min_samples_leaf': 5, 'min_samples_split': 2, 'min_weight_fraction_leaf': 0.0, 'monotonic_cst': None, 'n_estimators': 100, 'n_jobs': None, 'oob_score': False, 'random_state': 42, 'verbose': 0, 'warm_start': False}
# Hyperparameters given as keyword arguments to fit() end up on the
# underlying forest (see the get_params() output below).
# NOTE(review): this fits on the full `df`, which includes the rows held out
# in X_test; fine for showing parameter passing, but not for evaluation.
fitted_qrf_imputer = qrf_imputer.fit(
    X_train=df,
    predictors=predictors,
    imputed_variables=imputed_variables,
    max_depth=5,
    min_samples_leaf=10,
    n_estimators=200,
)
print(fitted_qrf_imputer.models[imputed_variables[0]].qrf.get_params())
{'bootstrap': True, 'ccp_alpha': 0.0, 'criterion': 'squared_error', 'default_quantiles': 0.5, 'max_depth': 5, 'max_features': 1.0, 'max_leaf_nodes': None, 'max_samples': None, 'max_samples_leaf': 1, 'min_impurity_decrease': 0.0, 'min_samples_leaf': 10, 'min_samples_split': 2, 'min_weight_fraction_leaf': 0.0, 'monotonic_cst': None, 'n_estimators': 200, 'n_jobs': None, 'oob_score': False, 'random_state': 42, 'verbose': 0, 'warm_start': False}
# Let QRF search for hyperparameters suited to this dataset. With
# tune_hyperparameters=True the call returns the fitted imputer together
# with the best parameters found.
# NOTE(review): this uses the private _fit method; confirm whether the public
# fit() accepts tune_hyperparameters so the example can avoid private API.
fitted_qrf_imputer, best_tuned_params = qrf_imputer._fit(
    tune_hyperparameters=True,
    X_train=df,
    predictors=predictors,
    imputed_variables=imputed_variables,
)
print(best_tuned_params)
{'n_estimators': 92, 'min_samples_split': 15, 'min_samples_leaf': 3, 'max_features': 0.41424062149385665, 'bootstrap': True}