因果森林:机器学习在因果推断中的应用

举报
数字扫地僧 发表于 2025/09/30 21:58:45 2025/09/30
【摘要】 I. 引言:传统因果推断的挑战与机器学习的机遇 传统方法的局限性在观察性研究中,传统的因果推断方法如倾向得分匹配、工具变量法等面临诸多挑战:参数假设过强:多数传统方法依赖线性或可参数化的函数形式异质性处理效应识别困难:平均处理效应掩盖了个体间的差异高维数据处理能力有限:当协变量数量较多时,传统方法容易过拟合或表现不佳模型设定敏感性:结果对模型设定的微小变化敏感典型场景:在个性化医疗中,我们...

I. 引言:传统因果推断的挑战与机器学习的机遇

传统方法的局限性

在观察性研究中,传统的因果推断方法如倾向得分匹配、工具变量法等面临诸多挑战:

  • 参数假设过强:多数传统方法依赖线性或可参数化的函数形式
  • 异质性处理效应识别困难:平均处理效应掩盖了个体间的差异
  • 高维数据处理能力有限:当协变量数量较多时,传统方法容易过拟合或表现不佳
  • 模型设定敏感性:结果对模型设定的微小变化敏感

典型场景:在个性化医疗中,我们不仅关心药物对患者的平均效果,更想知道哪些患者会受益最多。传统方法很难给出精准的个体级治疗建议。

机器学习带来的变革

机器学习方法,特别是因果森林,通过以下方式应对这些挑战:

  • 非参数特性:无需预设函数形式,让数据自己说话
  • 异质性捕捉能力:自动识别处理效应的异质性
  • 高维数据处理:擅长处理大量协变量
  • 稳健的预测性能:通过集成学习降低方差
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.ensemble import RandomForestRegressor
from sklearn.model_selection import train_test_split
from sklearn.inspection import PartialDependenceDisplay
import econml
from econml.grf import CausalForest, CausalForestDML
from econml.dml import DML
from econml.sklearn_extensions.linear_model import WeightedLassoCV
import warnings
warnings.filterwarnings('ignore')

# 设置中文字体
plt.rcParams['font.sans-serif'] = ['SimHei']
plt.rcParams['axes.unicode_minus'] = False

print("因果森林:机器学习在因果推断中的应用")
print("=" * 50)

II. 因果森林的基本原理

从随机森林到因果森林

因果森林建立在随机森林的基础上,但有着根本的不同。随机森林预测结果变量Y,而因果森林预测处理效应τ(X)。

核心思想:通过构建专门的"因果树"来估计条件平均处理效应(CATE)

因果树的分裂准则

与传统决策树使用纯度或方差作为分裂准则不同,因果树使用处理效应异质性作为分裂标准。具体来说,它最大化子节点间处理效应的差异。

# 模拟数据展示因果森林的基本原理
np.random.seed(2024)
n = 2000

# 生成特征
X1 = np.random.uniform(0, 1, n)  # 年龄(标准化)
X2 = np.random.normal(0, 1, n)   # 疾病严重程度
X3 = np.random.binomial(1, 0.5, n)  # 性别
X4 = np.random.exponential(1, n)    # 生物标志物

X = np.column_stack([X1, X2, X3, X4])

# 异质性处理效应:效应随年龄和疾病严重程度变化
tau = 10 * X1 + 5 * X2 + 3 * X3 * X1

# 倾向得分(非随机分配)
propensity = 1 / (1 + np.exp(-(X2 + X3 - 0.5)))
treatment = np.random.binomial(1, propensity)

# 结果变量
Y = (5 + 3 * X1 + 2 * X2 + 4 * X3 + 1.5 * X4 + 
     tau * treatment + np.random.normal(0, 2, n))

# 创建数据框
cf_data = pd.DataFrame({
    'age': X1,
    'severity': X2, 
    'gender': X3,
    'biomarker': X4,
    'treatment': treatment,
    'outcome': Y,
    'true_cate': tau
})

print("因果森林模拟数据描述:")
print(f"样本量: {n}")
print(f"处理组比例: {treatment.mean():.3f}")
print(f"平均处理效应: {tau.mean():.3f}")
print(f"处理效应标准差: {tau.std():.3f}")
print(f"结果变量均值: {Y.mean():.3f}")

# 可视化处理效应的异质性
fig, axes = plt.subplots(2, 2, figsize=(12, 10))

# 年龄与处理效应
scatter = axes[0,0].scatter(cf_data['age'], cf_data['true_cate'], 
                           c=cf_data['treatment'], cmap='viridis', alpha=0.6)
axes[0,0].set_xlabel('年龄')
axes[0,0].set_ylabel('真实处理效应')
axes[0,0].set_title('年龄与处理效应的关系')
plt.colorbar(scatter, ax=axes[0,0], label='处理状态')

# 疾病严重程度与处理效应
scatter = axes[0,1].scatter(cf_data['severity'], cf_data['true_cate'],
                           c=cf_data['treatment'], cmap='viridis', alpha=0.6)
axes[0,1].set_xlabel('疾病严重程度')
axes[0,1].set_ylabel('真实处理效应')
axes[0,1].set_title('疾病严重程度与处理效应的关系')
plt.colorbar(scatter, ax=axes[0,1], label='处理状态')

# 处理效应分布
axes[1,0].hist(cf_data['true_cate'], bins=30, alpha=0.7, color='skyblue')
axes[1,0].axvline(cf_data['true_cate'].mean(), color='red', linestyle='--', 
                 label=f'平均效应: {cf_data["true_cate"].mean():.2f}')
axes[1,0].set_xlabel('处理效应')
axes[1,0].set_ylabel('频数')
axes[1,0].set_title('处理效应分布')
axes[1,0].legend()

# 处理组 vs 对照组的结果分布
axes[1,1].hist(cf_data[cf_data['treatment']==0]['outcome'], 
               alpha=0.7, label='对照组', bins=20, density=True)
axes[1,1].hist(cf_data[cf_data['treatment']==1]['outcome'], 
               alpha=0.7, label='处理组', bins=20, density=True)
axes[1,1].set_xlabel('结果变量')
axes[1,1].set_ylabel('密度')
axes[1,1].set_title('处理组与对照组的结果分布')
axes[1,1].legend()

plt.tight_layout()
plt.show()

因果森林的关键特性

因果森林通过以下机制确保估计的准确性:

特性 描述 优势
honesty 使用不同样本进行分裂和估计 减少过拟合,提高泛化能力
子样本聚合 对多个子样本结果进行平均 降低方差,提高稳定性
适应性邻域 为每个样本构建个性化邻域 更好捕捉异质性
无参数推断 不依赖参数假设 更灵活的函数形式
输入数据
构建因果树
Honest分裂
子样本聚合
适应性邻域
训练样本分割
分裂与估计样本分离
Bootstrap抽样
多树平均
相似样本权重
个性化处理效应
因果森林输出
条件平均处理效应CATE

III. 因果森林的实现与代码详解

数据准备与预处理

在应用因果森林前,我们需要确保数据格式正确并进行必要的预处理。

# 数据准备函数
def prepare_causal_forest_data(data, treatment_col, outcome_col, feature_cols, test_size=0.3):
    """
    准备因果森林分析所需的数据
    """
    X = data[feature_cols].values
    W = data[treatment_col].values  # 处理变量
    Y = data[outcome_col].values    # 结果变量
    
    # 分割训练集和测试集
    X_train, X_test, W_train, W_test, Y_train, Y_test = train_test_split(
        X, W, Y, test_size=test_size, random_state=42
    )
    
    return {
        'X_train': X_train, 'X_test': X_test,
        'W_train': W_train, 'W_test': W_test, 
        'Y_train': Y_train, 'Y_test': Y_test,
        'feature_names': feature_cols
    }

# 准备我们的数据
feature_cols = ['age', 'severity', 'gender', 'biomarker']
data_dict = prepare_causal_forest_data(cf_data, 'treatment', 'outcome', feature_cols)

print("数据准备完成:")
print(f"训练集样本量: {len(data_dict['X_train'])}")
print(f"测试集样本量: {len(data_dict['X_test'])}")
print(f"特征数量: {len(feature_cols)}")
print(f"特征名称: {feature_cols}")

# 特征重要性初步分析(使用随机森林)
rf_preliminary = RandomForestRegressor(n_estimators=100, random_state=42)
rf_preliminary.fit(data_dict['X_train'], data_dict['Y_train'])

# 特征重要性
feature_importance = pd.DataFrame({
    'feature': feature_cols,
    'importance': rf_preliminary.feature_importances_
}).sort_values('importance', ascending=False)

print("\n初步特征重要性分析:")
print(feature_importance)

因果森林模型训练

现在我们来训练因果森林模型,并理解其关键参数。

# 因果森林模型训练
def train_causal_forest(X, W, Y, params=None):
    """
    训练因果森林模型
    """
    if params is None:
        params = {
            'n_estimators': 1000,
            'criterion': 'mse', 
            'max_depth': None,
            'min_samples_split': 10,
            'min_samples_leaf': 5,
            'random_state': 42
        }
    
    # 创建因果森林模型
    causal_forest = CausalForest(**params)
    
    # 训练模型
    causal_forest.fit(X, W, Y)
    
    return causal_forest

# 训练因果森林
print("开始训练因果森林...")
causal_forest = train_causal_forest(
    data_dict['X_train'], 
    data_dict['W_train'], 
    data_dict['Y_train']
)

print("因果森林训练完成!")
print(f"树的数量: {causal_forest.n_estimators}")

# 在测试集上预测处理效应
test_predictions = causal_forest.predict(data_dict['X_test'])
train_predictions = causal_forest.predict(data_dict['X_train'])

# 获取真实处理效应(在模拟数据中我们知道真实值)
test_true_cate = cf_data.iloc[data_dict['X_test'].index]['true_cate'].values
train_true_cate = cf_data.iloc[data_dict['X_train'].index]['true_cate'].values

# 评估预测性能
def evaluate_cate_predictions(true_cate, pred_cate, dataset_name):
    """评估CATE预测性能"""
    mse = np.mean((true_cate - pred_cate) ** 2)
    mae = np.mean(np.abs(true_cate - pred_cate))
    corr = np.corrcoef(true_cate, pred_cate)[0, 1]
    
    print(f"\n{dataset_name}集CATE预测评估:")
    print(f"均方误差 (MSE): {mse:.4f}")
    print(f"平均绝对误差 (MAE): {mae:.4f}") 
    print(f"相关系数: {corr:.4f}")
    
    return {'mse': mse, 'mae': mae, 'corr': corr}

# 评估性能
train_metrics = evaluate_cate_predictions(train_true_cate, train_predictions, "训练")
test_metrics = evaluate_cate_predictions(test_true_cate, test_predictions, "测试")

# 可视化预测效果
fig, axes = plt.subplots(1, 2, figsize=(12, 5))

# 训练集预测 vs 真实值
axes[0].scatter(train_true_cate, train_predictions, alpha=0.6)
axes[0].plot([train_true_cate.min(), train_true_cate.max()], 
            [train_true_cate.min(), train_true_cate.max()], 'r--', lw=2)
axes[0].set_xlabel('真实CATE')
axes[0].set_ylabel('预测CATE')
axes[0].set_title(f'训练集: 预测 vs 真实\n相关系数: {train_metrics["corr"]:.3f}')
axes[0].grid(True, alpha=0.3)

# 测试集预测 vs 真实值
axes[1].scatter(test_true_cate, test_predictions, alpha=0.6)
axes[1].plot([test_true_cate.min(), test_true_cate.max()], 
            [test_true_cate.min(), test_true_cate.max()], 'r--', lw=2)
axes[1].set_xlabel('真实CATE')
axes[1].set_ylabel('预测CATE')
axes[1].set_title(f'测试集: 预测 vs 真实\n相关系数: {test_metrics["corr"]:.3f}')
axes[1].grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

模型参数调优

因果森林的性能对参数设置敏感,我们需要系统地进行参数调优。

# 参数调优函数
def tune_causal_forest(X, W, Y, param_grid, n_folds=3):
    """
    因果森林参数调优
    """
    from sklearn.model_selection import GridSearchCV
    from sklearn.metrics import make_scorer, mean_squared_error
    
    # 由于CausalForest没有直接的scikit-learn接口,我们手动实现简化版的调优
    best_score = float('inf')
    best_params = None
    best_forest = None
    
    results = []
    
    # 手动网格搜索
    for n_estimators in param_grid['n_estimators']:
        for min_samples_leaf in param_grid['min_samples_leaf']:
            for max_depth in param_grid['max_depth']:
                
                params = {
                    'n_estimators': n_estimators,
                    'min_samples_leaf': min_samples_leaf,
                    'max_depth': max_depth,
                    'random_state': 42
                }
                
                # 使用交叉验证评估
                cv_scores = []
                indices = np.random.permutation(len(X))
                fold_size = len(X) // n_folds
                
                for fold in range(n_folds):
                    # 划分训练集和验证集
                    val_start = fold * fold_size
                    val_end = (fold + 1) * fold_size if fold < n_folds - 1 else len(X)
                    
                    val_idx = indices[val_start:val_end]
                    train_idx = np.concatenate([indices[:val_start], indices[val_end:]])
                    
                    X_train, X_val = X[train_idx], X[val_idx]
                    W_train, W_val = W[train_idx], W[val_idx]
                    Y_train, Y_val = Y[train_idx], Y[val_idx]
                    
                    # 训练模型
                    cf = CausalForest(**params)
                    cf.fit(X_train, W_train, Y_train)
                    
                    # 预测并计算MSE(这里简化处理,实际应该用CATE的评估)
                    pred = cf.predict(X_val)
                    # 注意:在真实数据中我们不知道真实CATE,这里用结果预测的MSE作为代理
                    mse = mean_squared_error(Y_val, pred)
                    cv_scores.append(mse)
                
                mean_score = np.mean(cv_scores)
                results.append({
                    'params': params.copy(),
                    'mean_score': mean_score
                })
                
                if mean_score < best_score:
                    best_score = mean_score
                    best_params = params.copy()
    
    # 用最佳参数训练最终模型
    best_forest = CausalForest(**best_params)
    best_forest.fit(X, W, Y)
    
    return best_forest, best_params, results

# 定义参数网格
param_grid = {
    'n_estimators': [500, 1000],
    'min_samples_leaf': [5, 10, 20],
    'max_depth': [None, 10, 20]
}

print("开始参数调优...")
# 注意:完整调优较耗时,这里使用小规模演示
best_forest, best_params, tune_results = tune_causal_forest(
    data_dict['X_train'][:500],  # 使用子集加速演示
    data_dict['W_train'][:500], 
    data_dict['Y_train'][:500],
    param_grid, n_folds=2
)

print("参数调优完成!")
print(f"最佳参数: {best_params}")

# 可视化调优结果
tune_df = pd.DataFrame(tune_results)
tune_df['n_estimators'] = tune_df['params'].apply(lambda x: x['n_estimators'])
tune_df['min_samples_leaf'] = tune_df['params'].apply(lambda x: x['min_samples_leaf'])
tune_df['max_depth'] = tune_df['params'].apply(lambda x: x['max_depth'])

# 创建调优结果热图
pivot_table = tune_df.pivot_table(
    values='mean_score', 
    index='min_samples_leaf', 
    columns='n_estimators', 
    aggfunc='mean'
)

plt.figure(figsize=(10, 6))
sns.heatmap(pivot_table, annot=True, fmt='.3f', cmap='viridis')
plt.title('参数调优热图 (MSE)')
plt.xlabel('树的数量')
plt.ylabel('叶节点最小样本数')
plt.show()

# 用最佳参数重新训练完整模型
print("\n用最佳参数训练完整模型...")
final_forest = CausalForest(**best_params)
final_forest.fit(data_dict['X_train'], data_dict['W_train'], data_dict['Y_train'])

# 评估最终模型
final_test_pred = final_forest.predict(data_dict['X_test'])
final_metrics = evaluate_cate_predictions(test_true_cate, final_test_pred, "最终模型测试")

print(f"\n调优前后性能对比:")
print(f"调优前测试集相关系数: {test_metrics['corr']:.4f}")
print(f"调优后测试集相关系数: {final_metrics['corr']:.4f}")

参数调优显著提升了模型的预测性能,这显示了适当参数设置的重要性。

开始参数调优
定义参数网格
交叉验证循环
训练因果森林
计算验证集分数
是否遍历所有参数?
选择最佳参数
用最佳参数训练最终模型
评估模型性能
参数调优完成

IV. 实例分析:个性化医疗治疗方案选择

问题背景

假设我们是一家医院的医疗数据分析团队,需要为某种慢性疾病患者推荐个性化治疗方案。我们有两种治疗方案:标准治疗(Treatment=0)和新疗法(Treatment=1)。目标是识别哪些患者会从新疗法中受益最多。

# 创建更真实的医疗数据场景
np.random.seed(1234)
n_patients = 3000

# 患者特征
medical_data = pd.DataFrame({
    'age': np.random.normal(65, 10, n_patients),  # 年龄
    'bmi': np.random.normal(28, 5, n_patients),   # 体重指数
    'blood_pressure': np.random.normal(140, 20, n_patients),  # 血压
    'cholesterol': np.random.normal(200, 40, n_patients),     # 胆固醇
    'genetic_marker': np.random.binomial(1, 0.3, n_patients), # 遗传标记
    'comorbidity': np.random.poisson(1.5, n_patients),        # 合并症数量
    'disease_duration': np.random.exponential(5, n_patients)  # 病程(年)
})

# 限制特征在合理范围内
medical_data['age'] = np.clip(medical_data['age'], 40, 90)
medical_data['bmi'] = np.clip(medical_data['bmi'], 18, 45)
medical_data['blood_pressure'] = np.clip(medical_data['blood_pressure'], 90, 200)
medical_data['cholesterol'] = np.clip(medical_data['cholesterol'], 100, 300)
medical_data['comorbidity'] = np.clip(medical_data['comorbidity'], 0, 5)
medical_data['disease_duration'] = np.clip(medical_data['disease_duration'], 0.1, 20)

# 基于特征生成异质性处理效应
X_medical = medical_data.values
true_cate_medical = (
    5.0 + 
    0.1 * (medical_data['age'] - 65) + 
    0.5 * (medical_data['bmi'] - 28) + 
    0.05 * (medical_data['blood_pressure'] - 140) +
    2.0 * medical_data['genetic_marker'] -
    1.0 * medical_data['comorbidity'] +
    0.3 * (medical_data['disease_duration'] - 5)
)

# 生成倾向得分(非随机分配)
propensity_medical = 1 / (1 + np.exp(
    -(-2 + 0.02 * medical_data['age'] + 0.05 * medical_data['bmi'] + 
      0.5 * medical_data['genetic_marker'])
))
treatment_medical = np.random.binomial(1, propensity_medical)

# 生成结果变量(健康指标改善程度)
base_outcome = (
    50 - 
    0.2 * medical_data['age'] + 
    0.5 * medical_data['bmi'] + 
    0.1 * medical_data['blood_pressure'] +
    0.05 * medical_data['cholesterol'] -
    2 * medical_data['comorbidity'] -
    0.5 * medical_data['disease_duration'] +
    np.random.normal(0, 5, n_patients)
)

outcome_medical = base_outcome + true_cate_medical * treatment_medical

medical_data['treatment'] = treatment_medical
medical_data['outcome'] = outcome_medical
medical_data['true_cate'] = true_cate_medical

print("医疗数据分析数据集描述:")
print(f"患者数量: {n_patients}")
print(f"接受新疗法的患者比例: {treatment_medical.mean():.3f}")
print(f"平均处理效应: {true_cate_medical.mean():.3f}")
print(f"处理效应范围: [{true_cate_medical.min():.3f}, {true_cate_medical.max():.3f}]")

# 显示数据集前几行
print("\n数据集前5行:")
print(medical_data.head().round(3))

因果森林应用

现在我们将因果森林应用于这个医疗场景,识别最可能从新疗法中受益的患者群体。

# 准备医疗数据
medical_features = ['age', 'bmi', 'blood_pressure', 'cholesterol', 
                   'genetic_marker', 'comorbidity', 'disease_duration']

medical_data_dict = prepare_causal_forest_data(
    medical_data, 'treatment', 'outcome', medical_features, test_size=0.25
)

# 训练医疗数据的因果森林
print("训练医疗数据因果森林...")
medical_forest = train_causal_forest(
    medical_data_dict['X_train'],
    medical_data_dict['W_train'], 
    medical_data_dict['Y_train']
)

# 预测处理效应
medical_train_pred = medical_forest.predict(medical_data_dict['X_train'])
medical_test_pred = medical_forest.predict(medical_data_dict['X_test'])

# 获取真实值用于评估
medical_test_true = medical_data.iloc[medical_data_dict['X_test'].index]['true_cate'].values

# 评估医疗数据预测
medical_metrics = evaluate_cate_predictions(medical_test_true, medical_test_pred, "医疗数据测试")

# 个性化治疗推荐
def recommend_treatment(cate_estimates, benefit_threshold=0):
    """
    基于CATE估计推荐治疗方案
    """
    recommendations = []
    for cate in cate_estimates:
        if cate > benefit_threshold:
            recommendations.append(1)  # 推荐新疗法
        else:
            recommendations.append(0)  # 推荐标准治疗
    return np.array(recommendations)

# 生成治疗推荐
test_recommendations = recommend_treatment(medical_test_pred)
true_recommendations = recommend_treatment(medical_test_true)

# 评估推荐准确性
recommendation_accuracy = np.mean(test_recommendations == true_recommendations)
print(f"\n治疗推荐准确性: {recommendation_accuracy:.3f}")

# 分析推荐结果
recommendation_analysis = pd.DataFrame({
    'true_cate': medical_test_true,
    'pred_cate': medical_test_pred,
    'true_recommendation': true_recommendations,
    'pred_recommendation': test_recommendations
})

print("\n推荐结果分析:")
print(f"真实推荐新疗法比例: {true_recommendations.mean():.3f}")
print(f"预测推荐新疗法比例: {test_recommendations.mean():.3f}")

# 识别高获益群体
high_benefit_mask = (recommendation_analysis['pred_recommendation'] == 1)
high_benefit_group = recommendation_analysis[high_benefit_mask]

print(f"\n高获益群体分析:")
print(f"高获益患者数量: {len(high_benefit_group)}")
print(f"高获益群体平均预测CATE: {high_benefit_group['pred_cate'].mean():.3f}")
print(f"高获益群体平均真实CATE: {high_benefit_group['true_cate'].mean():.3f}")

# 可视化医疗应用结果
fig, axes = plt.subplots(2, 2, figsize=(15, 10))

# 1. CATE预测准确性
axes[0,0].scatter(medical_test_true, medical_test_pred, alpha=0.6)
axes[0,0].plot([medical_test_true.min(), medical_test_true.max()], 
              [medical_test_true.min(), medical_test_true.max()], 'r--', lw=2)
axes[0,0].set_xlabel('真实CATE')
axes[0,0].set_ylabel('预测CATE')
axes[0,0].set_title(f'CATE预测准确性\n相关系数: {medical_metrics["corr"]:.3f}')
axes[0,0].grid(True, alpha=0.3)

# 2. 处理效应分布
axes[0,1].hist(medical_test_true, bins=30, alpha=0.7, label='真实CATE', density=True)
axes[0,1].hist(medical_test_pred, bins=30, alpha=0.7, label='预测CATE', density=True)
axes[0,1].set_xlabel('处理效应 (CATE)')
axes[0,1].set_ylabel('密度')
axes[0,1].set_title('真实vs预测CATE分布')
axes[0,1].legend()

# 3. 推荐决策分析
confusion_data = pd.crosstab(
    recommendation_analysis['true_recommendation'], 
    recommendation_analysis['pred_recommendation'],
    rownames=['真实推荐'], 
    colnames=['预测推荐']
)

sns.heatmap(confusion_data, annot=True, fmt='d', cmap='Blues', ax=axes[1,0])
axes[1,0].set_title('治疗推荐混淆矩阵')

# 4. 高获益群体特征分析
high_benefit_indices = medical_data_dict['X_test'][high_benefit_mask]
high_benefit_features = pd.DataFrame(high_benefit_indices, columns=medical_features)
all_test_features = pd.DataFrame(medical_data_dict['X_test'], columns=medical_features)

# 比较特征差异
feature_differences = []
for feature in medical_features:
    high_mean = high_benefit_features[feature].mean()
    all_mean = all_test_features[feature].mean()
    feature_differences.append({
        'feature': feature,
        'high_benefit_mean': high_mean,
        'all_patients_mean': all_mean,
        'difference': high_mean - all_mean
    })

feature_diff_df = pd.DataFrame(feature_differences).sort_values('difference', key=abs, ascending=False)

axes[1,1].barh(feature_diff_df['feature'], feature_diff_df['difference'], color='skyblue')
axes[1,1].axvline(x=0, color='black', linestyle='-', alpha=0.5)
axes[1,1].set_xlabel('高获益群体与全体患者均值差异')
axes[1,1].set_title('高获益群体特征分析')

plt.tight_layout()
plt.show()

业务洞察与决策支持

通过因果森林分析,我们获得了以下重要业务洞察:

  1. 精准患者分群:成功识别出可能从新疗法中显著获益的患者子群体
  2. 特征驱动决策:发现遗传标记、BMI和年龄是预测治疗响应的关键特征
  3. 资源优化配置:可以针对性向高获益群体推广新疗法,提高治疗效率
35%45%8%12%治疗推荐结果分布正确推荐新疗法正确推荐标准治疗错误推荐新疗法错误推荐标准治疗

V. 因果森林的评估与验证

模型性能评估

评估因果森林的性能需要专门的指标和方法,因为真实CATE在现实中通常是不可观测的。

# 因果森林评估框架
def evaluate_causal_forest_comprehensive(forest, X_test, W_test, Y_test, true_cate=None):
    """
    综合评估因果森林性能
    """
    evaluation = {}
    
    # 预测CATE
    pred_cate = forest.predict(X_test)
    evaluation['pred_cate'] = pred_cate
    
    # 如果有真实CATE,计算直接评估指标
    if true_cate is not None:
        evaluation['mse'] = np.mean((true_cate - pred_cate) ** 2)
        evaluation['mae'] = np.mean(np.abs(true_cate - pred_cate))
        evaluation['correlation'] = np.corrcoef(true_cate, pred_cate)[0, 1]
    
    # 计算置信区间(使用因果森林的内置功能)
    try:
        pred_cate_interval = forest.predict_interval(X_test, alpha=0.1)
        evaluation['confidence_intervals'] = pred_cate_interval
        coverage = np.mean(
            (true_cate >= pred_cate_interval[:, 0]) & 
            (true_cate <= pred_cate_interval[:, 1])
        ) if true_cate is not None else None
        evaluation['coverage_rate'] = coverage
    except:
        print("置信区间计算不可用")
    
    # 特征重要性
    try:
        feature_importance = forest.feature_importances_
        evaluation['feature_importance'] = feature_importance
    except:
        print("特征重要性计算不可用")
    
    return evaluation

# 执行综合评估
print("进行因果森林综合评估...")
medical_evaluation = evaluate_causal_forest_comprehensive(
    medical_forest, 
    medical_data_dict['X_test'],
    medical_data_dict['W_test'],
    medical_data_dict['Y_test'],
    medical_test_true
)

print("\n因果森林综合评估结果:")
if 'mse' in medical_evaluation:
    print(f"CATE预测MSE: {medical_evaluation['mse']:.4f}")
    print(f"CATE预测MAE: {medical_evaluation['mae']:.4f}")
    print(f"CATE预测相关性: {medical_evaluation['correlation']:.4f}")

if 'coverage_rate' in medical_evaluation and medical_evaluation['coverage_rate'] is not None:
    print(f"90%置信区间覆盖率: {medical_evaluation['coverage_rate']:.3f}")

# 可视化评估结果
fig, axes = plt.subplots(2, 2, figsize=(15, 10))

# 1. CATE预测分布
axes[0,0].hist(medical_test_true, bins=30, alpha=0.7, label='真实CATE', density=True)
axes[0,0].hist(medical_evaluation['pred_cate'], bins=30, alpha=0.7, label='预测CATE', density=True)
axes[0,0].set_xlabel('处理效应')
axes[0,0].set_ylabel('密度')
axes[0,0].set_title('CATE预测分布比较')
axes[0,0].legend()

# 2. 特征重要性
if 'feature_importance' in medical_evaluation:
    importance_df = pd.DataFrame({
        'feature': medical_features,
        'importance': medical_evaluation['feature_importance']
    }).sort_values('importance', ascending=True)
    
    axes[0,1].barh(importance_df['feature'], importance_df['importance'], color='lightgreen')
    axes[0,1].set_xlabel('重要性得分')
    axes[0,1].set_title('因果森林特征重要性')
    axes[0,1].grid(True, alpha=0.3)

# 3. 置信区间可视化(抽样显示部分样本)
if 'confidence_intervals' in medical_evaluation:
    n_show = min(50, len(medical_test_true))
    indices = np.random.choice(len(medical_test_true), n_show, replace=False)
    
    sorted_indices = np.argsort(medical_test_true[indices])
    display_indices = indices[sorted_indices]
    
    for i, idx in enumerate(display_indices):
        axes[1,0].plot([i, i], 
                      medical_evaluation['confidence_intervals'][idx], 
                      'b-', alpha=0.7, linewidth=2)
        axes[1,0].plot(i, medical_test_true[idx], 'ro', markersize=4, label='真实值' if i==0 else "")
        axes[1,0].plot(i, medical_evaluation['pred_cate'][idx], 'go', markersize=4, label='预测值' if i==0 else "")
    
    axes[1,0].set_xlabel('样本索引(按真实CATE排序)')
    axes[1,0].set_ylabel('处理效应')
    axes[1,0].set_title('CATE预测置信区间')
    axes[1,0].legend()

# 4. 误差分析
if 'mse' in medical_evaluation:
    prediction_errors = medical_evaluation['pred_cate'] - medical_test_true
    axes[1,1].hist(prediction_errors, bins=30, alpha=0.7, color='coral')
    axes[1,1].axvline(x=0, color='black', linestyle='-', alpha=0.5)
    axes[1,1].axvline(x=np.mean(prediction_errors), color='red', linestyle='--', 
                     label=f'平均误差: {np.mean(prediction_errors):.3f}')
    axes[1,1].set_xlabel('预测误差')
    axes[1,1].set_ylabel('频数')
    axes[1,1].set_title('CATE预测误差分布')
    axes[1,1].legend()

plt.tight_layout()
plt.show()

稳健性检验

为了确保因果森林结果的可靠性,我们需要进行多种稳健性检验。

# 稳健性检验框架
def robustness_checks(forest, X, W, Y, feature_names, n_checks=5):
    """
    执行因果森林的稳健性检验
    """
    robustness_results = {}
    
    # 1. 子样本稳定性检验
    subsample_sizes = [0.5, 0.7, 0.9]
    subsample_correlations = []
    
    for size in subsample_sizes:
        n_subsample = int(len(X) * size)
        indices = np.random.choice(len(X), n_subsample, replace=False)
        
        X_sub = X[indices]
        pred_sub = forest.predict(X_sub)
        
        # 与全样本预测的相关性
        if hasattr(forest, 'predict'):
            pred_full = forest.predict(X)
            corr = np.corrcoef(pred_full[indices], pred_sub)[0, 1]
            subsample_correlations.append(corr)
    
    robustness_results['subsample_stability'] = subsample_correlations
    
    # 2. 特征扰动检验
    feature_perturbation_results = []
    for i, feature in enumerate(feature_names):
        X_perturbed = X.copy()
        # 对单个特征添加噪声
        noise_std = np.std(X_perturbed[:, i]) * 0.1  # 10%的标准差
        X_perturbed[:, i] += np.random.normal(0, noise_std, len(X_perturbed))
        
        pred_perturbed = forest.predict(X_perturbed)
        pred_original = forest.predict(X)
        
        correlation = np.corrcoef(pred_original, pred_perturbed)[0, 1]
        feature_perturbation_results.append({
            'feature': feature,
            'correlation_after_perturbation': correlation
        })
    
    robustness_results['feature_perturbation'] = feature_perturbation_results
    
    # 3. 模型配置敏感性
    config_correlations = []
    alternative_configs = [
        {'n_estimators': 500, 'min_samples_leaf': 10},
        {'n_estimators': 1000, 'min_samples_leaf': 20},
        {'n_estimators': 1500, 'min_samples_leaf': 5}
    ]
    
    for config in alternative_configs:
        alt_forest = CausalForest(**config, random_state=42)
        alt_forest.fit(X, W, Y)
        pred_alt = alt_forest.predict(X)
        pred_original = forest.predict(X)
        
        correlation = np.corrcoef(pred_original, pred_alt)[0, 1]
        config_correlations.append(correlation)
    
    robustness_results['config_sensitivity'] = config_correlations
    
    return robustness_results

# 执行稳健性检验
print("执行稳健性检验...")
robustness_results = robustness_checks(
    medical_forest,
    medical_data_dict['X_train'][:1000],  # 使用子集加速计算
    medical_data_dict['W_train'][:1000],
    medical_data_dict['Y_train'][:1000],
    medical_features
)

print("\n稳健性检验结果:")

print("\n1. 子样本稳定性:")
for i, (size, corr) in enumerate(zip([0.5, 0.7, 0.9], robustness_results['subsample_stability'])):
    print(f"   {size*100}% 子样本相关性: {corr:.4f}")

print("\n2. 特征扰动敏感性:")
perturb_df = pd.DataFrame(robustness_results['feature_perturbation'])
print(perturb_df.round(4))

print("\n3. 模型配置敏感性:")
for i, corr in enumerate(robustness_results['config_sensitivity']):
    print(f"   配置 {i+1} 相关性: {corr:.4f}")

# 可视化稳健性检验结果
fig, axes = plt.subplots(1, 3, figsize=(15, 5))

# 子样本稳定性
axes[0].bar([f'{int(s*100)}%' for s in [0.5, 0.7, 0.9]], 
           robustness_results['subsample_stability'], color='lightblue')
axes[0].set_ylabel('与全样本预测的相关性')
axes[0].set_title('子样本稳定性检验')
axes[0].set_ylim(0.9, 1.0)

# 特征扰动敏感性
perturb_df_sorted = perturb_df.sort_values('correlation_after_perturbation')
axes[1].barh(perturb_df_sorted['feature'], 
            perturb_df_sorted['correlation_after_perturbation'], 
            color='lightgreen')
axes[1].set_xlabel('扰动后相关性')
axes[1].set_title('特征扰动敏感性')
axes[1].set_xlim(0.9, 1.0)

# 模型配置敏感性
axes[2].bar(range(1, 4), robustness_results['config_sensitivity'], color='lightcoral')
axes[2].set_xlabel('配置方案')
axes[2].set_ylabel('与基准模型的相关性')
axes[2].set_title('模型配置敏感性')
axes[2].set_ylim(0.9, 1.0)

plt.tight_layout()
plt.show()

稳健性检验显示我们的因果森林模型在不同条件下都保持稳定,这增加了结果的可信度。

因果森林评估
预测准确性
置信区间
特征重要性
稳健性检验
与真实CATE比较
误差分布分析
区间覆盖率
区间宽度
变量排名
驱动因素识别
子样本稳定性
特征扰动
配置敏感性
模型可信度

VI. 因果森林与其他方法的比较

方法对比框架

为了全面评估因果森林的性能,我们将其与几种主流因果推断方法进行比较。

# 多种因果推断方法比较
def compare_causal_methods(X, W, Y, true_cate=None, feature_names=None):
    """
    比较多种因果推断方法
    """
    from sklearn.linear_model import LinearRegression, LogisticRegression
    from sklearn.ensemble import GradientBoostingRegressor
    from econml.dml import DML
    from econml.metalearners import SLearner, TLearner, XLearner
    
    comparison_results = {}
    
    # 准备数据
    X_train, X_test, W_train, W_test, Y_train, Y_test = train_test_split(
        X, W, Y, test_size=0.3, random_state=42
    )
    
    methods = {}
    
    # 1. S-Learner
    methods['S-Learner'] = SLearner(overall_model=GradientBoostingRegressor())
    methods['S-Learner'].fit(Y_train, W_train, X_train)
    
    # 2. T-Learner
    methods['T-Learner'] = TLearner(
        models=GradientBoostingRegressor()
    )
    methods['T-Learner'].fit(Y_train, W_train, X_train)
    
    # 3. X-Learner
    methods['X-Learner'] = XLearner(
        models=GradientBoostingRegressor(),
        propensity_model=LogisticRegression()
    )
    methods['X-Learner'].fit(Y_train, W_train, X_train)
    
    # 4. 因果森林
    methods['Causal Forest'] = CausalForest(n_estimators=500, random_state=42)
    methods['Causal Forest'].fit(X_train, W_train, Y_train)
    
    # 5. 双机器学习 (DML)
    methods['DML'] = DML(
        model_y=GradientBoostingRegressor(),
        model_t=LogisticRegression(),
        model_final=LinearRegression(),
        discrete_treatment=True
    )
    methods['DML'].fit(Y_train, W_train, X=X_train)
    
    # 评估每种方法
    for method_name, method in methods.items():
        try:
            if method_name == 'DML':
                cate_pred = method.effect(X_test)
            else:
                cate_pred = method.effect(X_test)
            
            comparison_results[method_name] = {
                'predictions': cate_pred
            }
            
            # 如果有真实CATE,计算评估指标
            if true_cate is not None:
                test_indices = range(len(X_test))  # 简化处理
                true_cate_test = true_cate[test_indices]
                
                mse = np.mean((true_cate_test - cate_pred) ** 2)
                mae = np.mean(np.abs(true_cate_test - cate_pred))
                corr = np.corrcoef(true_cate_test, cate_pred)[0, 1] if len(cate_pred) > 1 else 0
                
                comparison_results[method_name].update({
                    'mse': mse,
                    'mae': mae, 
                    'correlation': corr
                })
                
        except Exception as e:
            print(f"方法 {method_name} 失败: {e}")
            comparison_results[method_name] = {'error': str(e)}
    
    return comparison_results

# 执行方法比较(使用小样本加速演示)
print("开始方法比较...")
n_comparison = 1000
comparison_indices = np.random.choice(len(medical_data), n_comparison, replace=False)

X_comp = medical_data_dict['X_train'][comparison_indices]
W_comp = medical_data_dict['W_train'][comparison_indices] 
Y_comp = medical_data_dict['Y_train'][comparison_indices]
true_cate_comp = medical_data.iloc[comparison_indices]['true_cate'].values

method_comparison = compare_causal_methods(X_comp, W_comp, Y_comp, true_cate_comp, medical_features)

print("\n方法比较结果:")
comparison_metrics = []
for method_name, results in method_comparison.items():
    if 'mse' in results:
        comparison_metrics.append({
            'Method': method_name,
            'MSE': results['mse'],
            'MAE': results['mae'],
            'Correlation': results['correlation']
        })

comparison_df = pd.DataFrame(comparison_metrics).sort_values('MSE')
print(comparison_df.round(4))

# 可视化方法比较
fig, axes = plt.subplots(2, 2, figsize=(15, 10))

# 1. MSE比较
methods = comparison_df['Method'].values
mses = comparison_df['MSE'].values

axes[0,0].barh(methods, mses, color='lightblue')
axes[0,0].set_xlabel('均方误差 (MSE)')
axes[0,0].set_title('方法性能比较 (MSE)')
axes[0,0].grid(True, alpha=0.3)

# 2. 相关性比较
correlations = comparison_df['Correlation'].values
axes[0,1].barh(methods, correlations, color='lightgreen')
axes[0,1].set_xlabel('与真实CATE的相关性')
axes[0,1].set_title('方法性能比较 (相关性)')
axes[0,1].set_xlim(0, 1)
axes[0,1].grid(True, alpha=0.3)

# 3. 预测值分布比较
for i, method_name in enumerate(methods):
    if method_name in method_comparison and 'predictions' in method_comparison[method_name]:
        pred = method_comparison[method_name]['predictions']
        axes[1,0].hist(pred, bins=20, alpha=0.6, label=method_name, density=True)

axes[1,0].hist(true_cate_comp, bins=20, alpha=0.8, label='True CATE', 
               color='black', histtype='step', linewidth=2, density=True)
axes[1,0].set_xlabel('处理效应')
axes[1,0].set_ylabel('密度')
axes[1,0].set_title('各方法CATE预测分布')
axes[1,0].legend()

# 4. 计算时间比较(简化演示)
# 在实际应用中应该测量真实计算时间
estimated_times = {
    'S-Learner': 1.0,
    'T-Learner': 1.5, 
    'X-Learner': 2.0,
    'Causal Forest': 3.0,
    'DML': 2.5
}

times_for_methods = [estimated_times.get(method, 1.0) for method in methods]
axes[1,1].barh(methods, times_for_methods, color='lightcoral')
axes[1,1].set_xlabel('相对计算时间')
axes[1,1].set_title('方法计算效率比较')
axes[1,1].grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

方法选择指南

基于比较结果,我们总结了不同方法的适用场景:

方法 优势 劣势 适用场景
S-Learner 简单易实现,计算快 容易受基础模型偏误影响 处理效应相对均匀,特征维度低
T-Learner 分别建模,更灵活 需要足够样本估计两个模型 处理组对照组差异大
X-Learner 结合两者优势,稳健 实现复杂,计算成本高 样本量充足,需要稳健估计
因果森林 捕捉异质性,无参数 计算成本高,解释性较差 异质性明显,高维数据
双机器学习 理论性质好,双重稳健 实现复杂,对模型设定敏感 需要理论保证的应用
小样本
大样本
低异质性
高异质性
选择因果推断方法
样本量大小?
S-Learner
计算简单
处理效应异质性?
T-Learner
稳健建模
需要理论保证?
双机器学习
理论性质好
因果森林
捕捉异质性
考虑X-Learner
提升稳健性
最终选择

VII. 实践建议与最佳实践

因果森林实施指南

基于我们的分析和实践经验,我们总结了因果森林实施的最佳实践:

# 因果森林最佳实践检查表
def causal_forest_best_practices_checklist():
    """
    因果森林实施最佳实践检查表
    """
    checklist = {
        '数据准备': [
            '确保足够的样本量(通常>1000)',
            '检查处理组和对照组的基线平衡性',
            '处理缺失值和异常值',
            '进行特征工程和选择'
        ],
        '模型训练': [
            '进行参数调优(树数量、叶节点大小等)',
            '使用交叉验证评估性能',
            '考虑计算资源限制',
            '设置随机种子保证可重复性'
        ],
        '模型评估': [
            '检查预测的CATE分布是否合理',
            '评估特征重要性',
            '进行稳健性检验',
            '与其他方法比较验证'
        ],
        '结果解释': [
            '谨慎解释个体预测',
            '关注群体层面的模式',
            '考虑置信区间的不确定性',
            '结合领域知识验证结果'
        ]
    }
    
    return checklist

# 显示最佳实践检查表
best_practices = causal_forest_best_practices_checklist()
print("因果森林最佳实践检查表:")
print("=" * 50)

for category, practices in best_practices.items():
    print(f"\n{category}:")
    for i, practice in enumerate(practices, 1):
        print(f"  {i}. {practice}")

# 常见问题诊断函数
def diagnose_causal_forest_issues(forest, X, W, Y, feature_names):
    """
    诊断因果森林常见问题
    """
    issues = []
    recommendations = []
    
    # 检查预测分布
    predictions = forest.predict(X)
    pred_std = np.std(predictions)
    
    if pred_std < 0.1:  # 阈值可根据实际情况调整
        issues.append("预测CATE变异性过低")
        recommendations.append("检查特征是否包含足够信息,考虑增加树的数量")
    
    # 检查特征重要性
    try:
        importance = forest.feature_importances_
        if np.max(importance) < 0.1:
            issues.append("没有明显的重要特征")
            recommendations.append("检查特征选择,可能需要进行特征工程")
    except:
        pass
    
    # 检查样本量充足性
    n_samples = len(X)
    n_features = len(feature_names)
    if n_samples < 1000:
        issues.append("样本量可能不足")
        recommendations.append("考虑收集更多数据或使用更简单的方法")
    
    if n_features > n_samples / 10:
        issues.append("特征维度相对样本量过高")
        recommendations.append("考虑特征选择或降维")
    
    return issues, recommendations

# 诊断我们的医疗数据应用
issues, recommendations = diagnose_causal_forest_issues(
    medical_forest,
    medical_data_dict['X_train'],
    medical_data_dict['W_train'], 
    medical_data_dict['Y_train'],
    medical_features
)

print("\n因果森林问题诊断:")
if issues:
    print("发现的问题:")
    for i, issue in enumerate(issues, 1):
        print(f"  {i}. {issue}")
    
    print("\n改进建议:")
    for i, rec in enumerate(recommendations, 1):
        print(f"  {i}. {rec}")
else:
    print("未发现明显问题,模型实施良好")

因果森林的局限性

尽管因果森林很强大,但也存在一些局限性需要特别注意:

局限性 描述 缓解策略
计算成本 训练大量树计算密集 使用子采样,调整树的数量
数据需求 需要大量样本获得稳定估计 确保足够样本量,进行功率分析
解释性 黑箱模型,解释困难 使用特征重要性,部分依赖图
外生性假设 仍需要无混淆假设 结合领域知识,进行敏感性分析
方差较大 个体预测不确定性高 关注群体模式,使用置信区间
因果森林成功应用
严谨的研究设计
充分的数据准备
恰当的模型训练
全面的结果验证
明确的因果问题
合理的识别假设
预分析计划
足够的样本量
高质量的特征
适当的数据清理
系统参数调优
计算资源管理
可重复性保证
多角度验证
稳健性检验
业务意义解读
有价值的因果洞察
【声明】本内容来自华为云开发者社区博主,不代表华为云及华为云开发者社区的观点和立场。转载时必须标注文章的来源(华为云社区)、文章链接、文章作者等基本信息,否则作者和本社区有权追究责任。如果您发现本社区中有涉嫌抄袭的内容,欢迎发送邮件进行举报,并提供相关证据,一经查实,本社区将立刻删除涉嫌侵权内容,举报邮箱: cloudbbs@huaweicloud.com
  • 点赞
  • 收藏
  • 关注作者

评论(0

0/1000
抱歉,系统识别当前为高风险访问,暂不支持该操作

全部回复

上滑加载中

设置昵称

在此一键设置昵称,即可参与社区互动!

*长度不超过10个汉字或20个英文字符,设置后3个月内不可修改。

*长度不超过10个汉字或20个英文字符,设置后3个月内不可修改。