因果推断在生产系统的工程化落地:挑战与解决方案

数字扫地僧 发表于 2025/11/29 23:58:17
【摘要】 I. 因果推断的核心理论框架 1.1 潜在结果框架与识别挑战 因果推断的数学基础建立在反事实推理之上。对于单元 $i$,处理变量 $T_i \in \{0,1\}$ 和结果变量 $Y_i$,定义潜在结果 $Y_i(1)$ 和 $Y_i(0)$ 分别表示接受处理和未接受处理时的结果。个体因果效应为 $\tau_i = Y_i(1) - Y_i(...

I. 因果推断的核心理论框架

1.1 潜在结果框架与识别挑战

因果推断的数学基础建立在反事实推理之上。对于单元 $i$,处理变量 $T_i \in \{0,1\}$ 和结果变量 $Y_i$,定义潜在结果 $Y_i(1)$ 和 $Y_i(0)$ 分别表示接受处理和未接受处理时的结果。个体因果效应为 $\tau_i = Y_i(1) - Y_i(0)$。

关键问题在于因果识别:我们只能观测到 $Y_i = T_i Y_i(1) + (1-T_i) Y_i(0)$,永远无法同时观测 $Y_i(1)$ 和 $Y_i(0)$。识别策略的核心是消除选择偏差:

$$E[Y_i \mid T_i=1] - E[Y_i \mid T_i=0] = \underbrace{E[Y_i(1) - Y_i(0)]}_{\text{因果效应}} + \underbrace{\text{选择偏差}}_{\text{混杂因素}}$$

实例分析:电商优惠券效果评估

某电商平台向用户发放满100减10元优惠券,观测数据显示领券用户的订单转化率比未领券用户高15%。但这并非真实的因果效应,因为领取优惠券的用户本身就是更活跃、购买意向更强的高价值用户(选择偏差)。生产系统必须识别并剔除这类偏差。
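
下面用一段最小化的模拟代码直观展示上式的分解(示意性草图,用户价值分布、转化率与真实效应等参数均为假设值):

# selection_bias_demo.py —— 模拟"领券用户本身转化率更高"带来的选择偏差(参数为假设)
import numpy as np

rng = np.random.default_rng(42)
n = 200_000

user_value = rng.normal(0, 1, n)                        # 混杂因素:用户价值
p_treat = 1 / (1 + np.exp(-1.5 * user_value))           # 高价值用户更倾向领券
T = rng.binomial(1, p_treat)

base_rate = 1 / (1 + np.exp(-(user_value - 2)))         # 不领券时的转化率
true_effect = 0.05                                       # 设定的真实因果效应:+5个百分点
Y0 = rng.binomial(1, base_rate)
Y1 = rng.binomial(1, np.clip(base_rate + true_effect, 0, 1))
Y = T * Y1 + (1 - T) * Y0                                # 每个用户只能观测到其中之一

naive_diff = Y[T == 1].mean() - Y[T == 0].mean()
print(f"朴素组间差: {naive_diff:.3f}")                    # 明显大于0.05,混入了选择偏差
print(f"真实ATE(仅模拟中可知): {(Y1 - Y0).mean():.3f}")   # 约等于0.05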

| 识别方法 | 核心假设 | 数据要求 | 工程复杂度 | 适用场景 |
| --- | --- | --- | --- | --- |
| 随机实验(RCT) | 可忽略性 | 可控的随机分配 | — | 流量充足,成本低 |
| 倾向得分匹配(PSM) | 可忽略性+重叠假设 | 丰富的协变量 | — | 观察性研究 |
| 工具变量(IV) | 排他性约束 | 强工具变量 | — | 存在内生性 |
| 双重差分(DID) | 平行趋势 | 面板数据 | — | 政策评估 |
| 断点回归(RDD) | 连续性假设 | 连续驱动变量 | — | 阈值政策 |

1.2 因果图与后门准则

贝叶斯网络创始人Pearl提出的**因果图(DAG)**为混杂控制提供了可视化工具。后门准则指出:若变量集 $Z$ 阻断了所有从 $T$ 到 $Y$ 的后门路径,则对 $Z$ 进行分层即可识别因果效应。

(图:优惠券场景因果图 —— 节点包括历史行为、用户价值、优惠券领取、订单转化;"用户价值"同时指向"优惠券领取"与"订单转化",构成后门路径与虚假相关,阻断策略为按用户价值分层(高价值用户 / 低价值用户)。)

代码实现:因果图验证

# causal_graph.py
import itertools
import networkx as nx
from typing import List, Tuple, Set

class CausalGraphValidator:
    """因果图结构验证器"""
    
    def __init__(self, edges: List[Tuple[str, str]]):
        """
        初始化因果图
        
        参数:
            edges: 边列表,格式为 (cause, effect)
                  例如: [("用户价值", "优惠券领取"), ("用户价值", "订单转化")]
        """
        self.graph = nx.DiGraph()
        self.graph.add_edges_from(edges)
        
        # 验证无环性
        if not nx.is_directed_acyclic_graph(self.graph):
            raise ValueError("因果图必须是无环图(DAG)")
    
    def find_backdoor_paths(self, treatment: str, outcome: str) -> List[List[str]]:
        """
        查找所有后门路径
        后门路径:从treatment出发,通过父节点到达outcome的路径
        
        返回:路径列表
        """
        all_paths = []
        
        # 找到treatment的所有父节点(直接箭头指入的节点)
        treatment_parents = list(self.graph.predecessors(treatment))
        
        for parent in treatment_parents:
            # 寻找从parent到outcome的所有路径(不经过treatment)
            for path in nx.all_simple_paths(
                self.graph, 
                source=parent, 
                target=outcome,
                cutoff=10  # 限制路径长度
            ):
                if treatment not in path:
                    all_paths.append([treatment] + path)
        
        return all_paths
    
    def is_valid_adjustment_set(self, treatment: str, 
                                outcome: str, 
                                adjustment_set: Set[str]) -> bool:
        """
        验证调整集是否满足后门准则
        
        后门准则要求:
        1. 调整集阻断了所有后门路径
        2. 不包含treatment的后代节点
        """
        # 条件1:阻断所有后门路径
        backdoor_paths = self.find_backdoor_paths(treatment, outcome)
        
        for path in backdoor_paths:
            # 路径中是否有adjustment_set的节点
            path_nodes = set(path)
            if not path_nodes & adjustment_set:
                # 存在未被阻断的后门路径
                return False
        
        # 条件2:不包含treatment的后代节点
        descendants = nx.descendants(self.graph, treatment)
        if adjustment_set & descendants:
            return False
        
        return True
    
    def get_minimal_adjustment_sets(self, treatment: str, 
                                     outcome: str) -> List[Set[str]]:
        """
        获取最小调整集(基于图算法)
        """
        # 使用networkx的近似算法寻找最小顶点割
        # 这对应于阻断所有后门路径的最小节点集
        try:
            # 构建道德图(moral graph)后寻找分隔集
            # 此处简化处理,实际需实现完整算法
            backdoor_paths = self.find_backdoor_paths(treatment, outcome)
            
            # 收集所有后门路径上的非treatment/outcome节点
            candidate_nodes = set()
            for path in backdoor_paths:
                candidate_nodes.update(path[1:-1])  # 排除首尾
            
            # 暴力搜索最小子集(生产环境应使用贪心或优化算法)
            minimal_sets = []
            for r in range(1, len(candidate_nodes) + 1):
                for subset in itertools.combinations(candidate_nodes, r):
                    if self.is_valid_adjustment_set(treatment, outcome, set(subset)):
                        minimal_sets.append(set(subset))
                        break
                if minimal_sets:
                    break
            
            return minimal_sets
            
        except Exception as e:
            print(f"计算最小调整集失败: {e}")
            return []

# 优惠券场景实例
edges = [
    ("用户历史消费", "用户价值"),
    ("用户活跃度", "用户价值"),
    ("用户价值", "优惠券领取"),
    ("用户价值", "订单转化"),
    ("优惠券领取", "订单转化"),
    ("商品类别", "订单转化")
]

validator = CausalGraphValidator(edges)

# 查找后门路径
backdoor_paths = validator.find_backdoor_paths("优惠券领取", "订单转化")
print(f"发现后门路径: {backdoor_paths}")
# 输出: [['优惠券领取', '用户价值', '订单转化']]

# 验证调整集
is_valid = validator.is_valid_adjustment_set(
    "优惠券领取", 
    "订单转化", 
    {"用户价值"}
)
print(f"调整集是否有效: {is_valid}")
# 输出: True
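
在此基础上,也可以直接调用上面定义的 get_minimal_adjustment_sets 自动搜索最小调整集(延续上例的 validator):

# 搜索最小调整集(暴力搜索版,节点较多时应换用贪心或图割算法)
minimal_sets = validator.get_minimal_adjustment_sets("优惠券领取", "订单转化")
print(f"最小调整集: {minimal_sets}")
# 预期输出: [{'用户价值'}]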

II. 生产系统落地的四大工程挑战

2.1 数据质量与特征工程挑战

生产环境中的数据缺陷会系统性破坏因果识别的假设:

| 挑战类型 | 具体表现 | 对因果推断的影响 | 工程化检测方案 |
| --- | --- | --- | --- |
| 混杂变量缺失 | 用户心理偏好未采集 | 可忽略性假设失效 | 因果图敏感性分析 |
| 测量误差 | 用户收入自报告偏差 | 协变量去混杂能力下降 | 多源数据交叉验证 |
| 时序倒置 | 结果变量先于处理变量入库 | 因果方向颠倒 | 时间戳完整性校验 |
| 样本选择偏差 | 仅分析下单用户 | 总体效应无法外推 | 逆概率加权修正 |
| 数据缺失非随机 | 高价值用户拒绝授权 | 重叠假设被破坏 | 多重插补+敏感性分析 |

实例分析:金融风控中的数据陷阱

某银行构建信用卡额度调整策略时,发现"用户征信查询次数"是关键特征。但生产日志显示:

  1. 缺失机制:查询次数缺失的用户恰是征信白户(信用历史短),这类用户的风险更高
  2. 时序混淆:征信查询发生在额度调整后(用户因额度提升而申请贷款)
  3. 测量噪声:部分查询被重复记录,导致数值虚高

这些缺陷导致简单的倾向得分匹配严重低估额度提升对风险的真实影响,模型上线后坏账率超预期增长40%。
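
针对上表提到的"时间戳完整性校验",下面给出一个最简化的检测草图(字段名与阈值均为假设,需按实际日志 schema 调整):

# temporal_integrity_check.py —— 时序倒置检测草图(字段名为假设)
import pandas as pd

def check_temporal_order(df: pd.DataFrame,
                         cause_ts: str = "credit_inquiry_time",
                         effect_ts: str = "limit_adjust_time") -> pd.DataFrame:
    """标记"原因"时间戳不早于"结果"时间戳的记录(疑似时序倒置、因果方向颠倒)"""
    df = df.copy()
    df[cause_ts] = pd.to_datetime(df[cause_ts])
    df[effect_ts] = pd.to_datetime(df[effect_ts])
    df["temporal_violation"] = df[cause_ts] >= df[effect_ts]

    violation_rate = df["temporal_violation"].mean()
    if violation_rate > 0.01:  # 经验阈值,超过则建议人工排查
        print(f"警告: {violation_rate:.1%} 的样本存在时序倒置,估计前应剔除或修正")
    return df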

2.2 实时性与计算复杂度挑战

因果推断算法通常涉及矩阵求逆、迭代优化等计算密集型操作,与生产系统毫秒级响应要求存在根本冲突:

**算法复杂度对比**

| 算法 | 计算瓶颈 | 优化策略 |
| --- | --- | --- |
| 倾向得分匹配 | O(n²) 距离计算 | 倒排索引+近似最近邻 |
| 双重机器学习 | 交叉拟合需10+次重训练 | 增量学习+离线预训练 |
| 工具变量估计 | 两阶段最小二乘求逆 | 矩阵分解+缓存分解结果 |
| 因果森林 | 树结构构建耗时 | 预建森林+在线更新叶子权重 |

工程化解决方案架构

# caching_and_precompute.py
import hashlib
import pickle
import numpy as np
import redis
from functools import lru_cache

class ProductionCausalEngine:
    """生产级因果推断引擎(带缓存优化)"""
    
    def __init__(self, redis_client: redis.Redis, model_path: str):
        self.redis = redis_client
        self.model_path = model_path
        
        # 预加载离线训练组件
        self._load_pretrained_components()
        
        # 初始化在线学习缓存
        self.response_cache = {}
        self.propensity_cache = LRUCache(maxsize=10000)
    
    def _load_pretrained_components(self):
        """加载离线预计算结果"""
        # 1. 倾向得分模型(复杂模型,离线训练)
        with open(f"{self.model_path}/propensity_model.pkl", 'rb') as f:
            self.propensity_model = pickle.load(f)
        
        # 2. 协方差矩阵的Cholesky分解(避免在线求逆)
        self.chol_cov = np.load(f"{self.model_path}/cholesky_decomp.npy")
        
        # 3. 特征重要性权重(用于加速匹配)
        self.feature_weights = np.load(f"{self.model_path}/feature_weights.npy")
    
    @lru_cache(maxsize=1024)
    def compute_propensity_score(self, user_features_tuple: tuple) -> float:
        """
        带缓存的倾向得分计算
        
        参数:
            user_features_tuple: 用户特征元组(可哈希,用于缓存)
        """
        # 检查Redis缓存
        feature_hash = hashlib.md5(str(user_features_tuple).encode()).hexdigest()
        cached = self.redis.get(f"propensity:{feature_hash}")
        
        if cached:
            return float(cached)
        
        # 计算倾向得分
        features = np.array(user_features_tuple).reshape(1, -1)
        propensity = self.propensity_model.predict_proba(features)[0, 1]
        
        # 写入缓存(TTL=1小时)
        self.redis.setex(f"propensity:{feature_hash}", 3600, propensity)
        
        return propensity
    
    def fast_matching(self, treated_user: dict, control_pool: list, k: int = 5) -> list:
        """
        快速匹配算法(使用预计算距离)
        
        优化点:
        1. 使用加权欧氏距离,权重已离线计算
        2. 仅在候选集内计算距离,非全局搜索
        3. 使用KDTree加速最近邻查找
        """
        from scipy.spatial import cKDTree
        
        # 构建特征向量
        treated_vec = np.array([treated_user[f] for f in self.feature_names])
        
        # 加载预构建的KDTree(每6小时重建一次)
        tree_key = f"kdtree:{self.model_version}"
        tree_data = self.redis.get(tree_key)
        
        if tree_data:
            tree = pickle.loads(tree_data)
        else:
            # 从数据库加载控制组特征矩阵
            control_matrix = self._load_control_matrix()
            tree = cKDTree(control_matrix)
            self.redis.setex(tree_key, 21600, pickle.dumps(tree))
        
        # 查询k近邻
        weighted_vec = treated_vec * self.feature_weights
        distances, indices = tree.query(weighted_vec, k=k)
        
        # 返回最匹配的control用户ID
        return [control_pool[idx]["user_id"] for idx in indices]
    
    def online_cate_estimation(self, user_features: dict) -> float:
        """
        在线条件平均因果效应估计
        
        实现增量双重机器学习,避免全量重训练
        """
        # 1. 倾向得分(缓存或快速计算)
        ps = self.compute_propensity_score(tuple(user_features.values()))
        
        # 2. 结果预测(轻量级在线模型)
        y0_pred = self.outcome_model_0.partial_fit_predict(user_features)
        y1_pred = self.outcome_model_1.partial_fit_predict(user_features)
        
        # 3. 插件式CATE估计:CATE(x) = E[Y(1)-Y(0)|X=x] ≈ y1_pred - y0_pred
        #    (双重鲁棒的倾向加权残差修正需要观测到的结果,更适合离线批量计算;ps可用于极端样本裁剪)
        cate = y1_pred - y0_pred
        
        return cate

class LRUCache:
    """简单LRU缓存实现"""
    def __init__(self, maxsize=128):
        self.cache = {}
        self.access_order = []
        self.maxsize = maxsize
    
    def get(self, key):
        if key in self.cache:
            self.access_order.remove(key)
            self.access_order.append(key)
            return self.cache[key]
        return None
    
    def set(self, key, value):
        if key in self.cache:
            self.access_order.remove(key)
        elif len(self.cache) >= self.maxsize:
            oldest = self.access_order.pop(0)
            del self.cache[oldest]
        
        self.cache[key] = value
        self.access_order.append(key)

2.3 因果假设的可验证性与监控

生产环境无法验证可忽略性假设(Ignorability),但必须建立假设违背的检测与熔断机制:

| 监控维度 | 检测指标 | 告警阈值 | 自动响应 |
| --- | --- | --- | --- |
| 协变量平衡性 | 标准化均值差 SMD | SMD > 0.1 | 触发特征工程 pipeline |
| 倾向得分重叠度 | 倾向得分分布重叠率 | 重叠率 < 0.8 | 缩小实验范围 |
| 工具变量强度 | 第一阶段 F 统计量 | F < 10 | 暂停 IV 估计 |
| 平行趋势检验 | DID 预趋势系数 | p < 0.1 | 切换至合成控制法 |
| 稳定性 | 参数估计日漂移 | 相对变化 > 15% | 启动模型重训练 |

# assumption_monitoring.py
import time
import warnings
from collections import deque
from typing import List

import numpy as np
import pandas as pd
from scipy import stats

class CausalAssumptionMonitor:
    """因果假设监控系统"""
    
    def __init__(self, redis_client, alert_webhook: str):
        self.redis = redis_client
        self.alert_webhook = alert_webhook
        self.metrics_history = {}
    
    def compute_smd(self, treated: pd.DataFrame, 
                    control: pd.DataFrame, 
                    features: List[str]) -> dict:
        """
        计算标准化均值差(Standardized Mean Difference)
        用于检验倾向得分匹配后的协变量平衡性
        
        SMD = (μ_treated - μ_control) / √((σ²_treated + σ²_control)/2)
        SMD > 0.1 表示不平衡
        """
        smd_results = {}
        
        for feature in features:
            mean_t = treated[feature].mean()
            mean_c = control[feature].mean()
            
            var_t = treated[feature].var()
            var_c = control[feature].var()
            
            pooled_sd = np.sqrt((var_t + var_c) / 2)
            smd = abs(mean_t - mean_c) / pooled_sd if pooled_sd > 0 else np.inf
            
            smd_results[feature] = {
                "smd": smd,
                "is_balanced": smd < 0.1,
                "mean_diff": mean_t - mean_c
            }
            
            # 记录历史趋势
            metric_key = f"smd:{feature}"
            self._record_metric(metric_key, smd)
        
        return smd_results
    
    def check_overlap_assumption(self, propensity_scores: dict) -> dict:
        """
        检验重叠假设(Common Support)
        要求处理组和对照组的倾向得分分布有显著重叠
        
        返回重叠率等指标
        """
        treated_ps = propensity_scores["treated"]
        control_ps = propensity_scores["control"]
        
        # 计算分布重叠区域
        min_t, max_t = min(treated_ps), max(treated_ps)
        min_c, max_c = min(control_ps), max(control_ps)
        
        overlap_min = max(min_t, min_c)
        overlap_max = min(max_t, max_c)
        
        # 计算重叠区域内的样本比例
        treated_in_overlap = sum(1 for ps in treated_ps 
                               if overlap_min <= ps <= overlap_max)
        control_in_overlap = sum(1 for ps in control_ps 
                               if overlap_min <= ps <= overlap_max)
        
        overlap_rate_t = treated_in_overlap / len(treated_ps)
        overlap_rate_c = control_in_overlap / len(control_ps)
        overall_overlap = min(overlap_rate_t, overlap_rate_c)
        
        # 检查极端区域样本
        treated_trimmed = sum(1 for ps in treated_ps 
                             if ps < overlap_min or ps > overlap_max)
        
        return {
            "overlap_rate": overall_overlap,
            "treated_trimmed_ratio": treated_trimmed / len(treated_ps),
            "is_satisfied": overall_overlap > 0.8,
            "recommended_ps_range": [overlap_min, overlap_max]
        }
    
    def check_iv_strength(self, first_stage_results: dict) -> dict:
        """
        检验工具变量的强度(第一阶段F统计量)
        F < 10 表示弱工具变量,会放大内生性带来的估计偏差
        
        弱工具变量会放大第二阶段的估计方差
        """
        f_statistic = first_stage_results["f_statistic"]
        partial_r2 = first_stage_results["partial_r2"]
        
        # Stock-Yogo临界值参考
        weak_iv_threshold = 10
        
        alert_level = "CRITICAL" if f_statistic < 10 else \
                     "WARNING" if f_statistic < 16.38 else "OK"
        
        return {
            "f_statistic": f_statistic,
            "partial_r2": partial_r2,
            "is_strong": f_statistic > weak_iv_threshold,
            "alert_level": alert_level,
            "recommendation": "暂停IV分析" if alert_level == "CRITICAL" else \
                            "谨慎解读结果" if alert_level == "WARNING" else "继续监控"
        }
    
    def check_parallel_trends(self, pre_treatment_outcomes: dict) -> dict:
        """
        双重差分(DID)的平行趋势检验
        
        在policy实施前,处理组和对照组应有相同趋势
        通过检验政策前时期的交互项系数是否显著
        """
        # 构造回归:Y_it = α + β * Treat_i + Σγ_k * Time_k + 
        #                     Σδ_k * (Treat_i × Time_k) + ε_it
        
        # 重点检验政策前时期的δ_k是否联合显著
        pre_coeffs = pre_treatment_outcomes["interaction_coeffs"]
        pre_pvalues = pre_treatment_outcomes["pvalues"]
        
        # 联合F检验
        joint_f_stat = pre_treatment_outcomes.get("joint_f_stat", None)
        
        # 如果任一时间段的交互项显著,则违反平行趋势
        significant_pre_periods = sum(1 for p in pre_pvalues if p < 0.1)
        
        return {
            "has_violation": significant_pre_periods > 0,
            "significant_periods": significant_pre_periods,
            "joint_f_pvalue": pre_treatment_outcomes.get("joint_f_pvalue", 1.0),
            "recommendation": "更换识别策略(如合成控制)" if significant_pre_periods > 0 \
                             else "满足平行趋势假设"
        }
    
    def run_comprehensive_check(self, analysis_context: dict) -> dict:
        """
        综合假设检验流水线
        
        返回假设违背的汇总报告
        """
        results = {
            "timestamp": pd.Timestamp.now().isoformat(),
            "checks": {},
            "overall_risk": "LOW"
        }
        
        # 1. 倾向得分重叠度检查
        if "propensity_scores" in analysis_context:
            ps_check = self.check_overlap_assumption(
                analysis_context["propensity_scores"]
            )
            results["checks"]["overlap_assumption"] = ps_check
            if not ps_check["is_satisfied"]:
                results["overall_risk"] = "HIGH"
        
        # 2. 工具变量强度检查
        if "first_stage" in analysis_context:
            iv_check = self.check_iv_strength(analysis_context["first_stage"])
            results["checks"]["iv_strength"] = iv_check
            if iv_check["alert_level"] == "CRITICAL":
                results["overall_risk"] = "HIGH"
        
        # 3. 协变量平衡性检查
        if "matched_samples" in analysis_context:
            balance_check = self.compute_smd(
                analysis_context["matched_samples"]["treated"],
                analysis_context["matched_samples"]["control"],
                analysis_context["features"]
            )
            results["checks"]["covariate_balance"] = balance_check
            
            unbalanced_features = sum(1 for v in balance_check.values() 
                                    if not v["is_balanced"])
            if unbalanced_features > 0:
                results["overall_risk"] = "MEDIUM"
        
        # 4. 平行趋势检查(DID场景)
        if "pre_treatment_outcomes" in analysis_context:
            did_check = self.check_parallel_trends(
                analysis_context["pre_treatment_outcomes"]
            )
            results["checks"]["parallel_trends"] = did_check
            if did_check["has_violation"]:
                results["overall_risk"] = "HIGH"
        
        # 触发告警
        if results["overall_risk"] in ["HIGH", "MEDIUM"]:
            self._send_alert(results)
        
        return results
    
    def _record_metric(self, metric_key: str, value: float):
        """记录指标历史用于趋势分析"""
        if metric_key not in self.metrics_history:
            self.metrics_history[metric_key] = deque(maxlen=100)
        
        self.metrics_history[metric_key].append({
            "value": value,
            "timestamp": time.time()
        })
    
    def _send_alert(self, check_results: dict):
        """发送告警通知"""
        import requests
        
        message = {
            "text": f"因果假设违背告警: 风险等级 {check_results['overall_risk']}",
            "details": check_results,
            "timestamp": check_results["timestamp"]
        }
        
        try:
            requests.post(self.alert_webhook, json=message, timeout=5)
        except Exception as e:
            warnings.warn(f"告警发送失败: {e}")

III. 实例分析:电商平台智能定价系统

3.1 业务背景与问题定义

某头部电商平台在"618"大促期间,需要对1000+个SKU进行动态定价。业务核心问题是:价格调整对商品销量的真实因果效应是多少? 传统做法是直接拟合价格-销量回归模型,但存在严重内生性:

  • 需求反向因果:销量高的商品定价更高(溢价能力)
  • 竞争混淆:竞品价格变动同时影响本商品定价和销量
  • 库存约束:库存紧张时提价,同时销量受库存限制

目标:构建生产级因果推断引擎,实时估计每个SKU的价格弹性,支持动态定价决策。

3.2 因果图构建与识别策略选择

步骤1:专家知识驱动的因果图构建

通过业务访谈和数据分析,我们绘制出如下因果图:

(图:定价因果图与识别策略 —— 节点包括价格、成本、竞品价格、需求、库存水平、节假日、用户热度、销量与收入;混杂路径主要经由用户热度与需求。识别策略标注:成本作为IV支撑工具变量法,满减阈值支撑断点回归,大促前后对比支撑双重差分。)

关键混淆路径

  • 用户热度(U)→ 价格(P):热门商品提价
  • 用户热度(U)→ 销量(S):热门商品天然卖得好

识别策略对比与选择

| 策略 | 可行性 | 数据要求 | 估计精度 | 工程成本 | 最终选择 |
| --- | --- | --- | --- | --- | --- |
| 随机实验 | 低(价格不能随机) | 无需 | — | — | ❌ 业务不可接受 |
| 倾向得分匹配 | — | 用户画像完整 | — | — | ⚠️ 辅助验证 |
| 工具变量法 | — | 成本数据可用 | — | — | ✅ 主策略 |
| 断点回归 | — | 满减规则明确 | — | — | ⚠️ 局部效应 |
| 双重差分 | — | 历史面板数据 | — | — | ✅ 鲁棒性检验 |

选择工具变量法的核心理由

  1. 排他性约束:商品成本仅通过定价影响销量,不直接影响消费者需求(品牌商品例外,需剔除)
  2. 相关性:成本是定价的核心决定因素(R² ≈ 0.72)
  3. 数据可得性:采购成本在ERP系统实时同步
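
进入完整流水线之前,可以先用一个简单的第一阶段回归粗查"成本→价格"的相关性(示意草图,列名沿用下文流水线输出的 cost_rmb、price,具体 R² 取决于数据):

# iv_relevance_check.py —— 工具变量相关性粗查(示意)
import pandas as pd
from sklearn.linear_model import LinearRegression

def quick_first_stage_r2(df: pd.DataFrame) -> float:
    """用成本(工具变量)解释价格(内生处理变量)的R²;过低提示弱工具变量风险"""
    X = df[["cost_rmb"]].values
    y = df["price"].values
    r2 = LinearRegression().fit(X, y).score(X, y)
    flag = ",需警惕弱工具变量" if r2 < 0.1 else ""
    print(f"第一阶段 R² = {r2:.3f}{flag}")
    return r2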

3.3 完整数据处理流水线

数据流架构

# data_pipeline.py
from pyspark.sql import SparkSession, DataFrame
from pyspark.sql.functions import *
from pyspark.sql.types import *
from pyspark.sql.window import Window
import pandas as pd
from datetime import datetime, timedelta

class CausalDataPipeline:
    """生产级因果数据流水线"""
    
    def __init__(self, spark: SparkSession, config: dict):
        self.spark = spark
        self.config = config
        self.table_mappings = config["tables"]
    
    def extract_price_inventory_data(self, dt: str) -> DataFrame:
        """
        从ODS层抽取价格、库存、成本数据
        
        数据源:
        - dim_sku_price: 价格快照表(15分钟粒度)
        - ods_inventory: 实时库存表
        - dim_product_cost: 成本表(T+1更新)
        
        清洗逻辑:
        - 剔除负价格、异常波动价(日涨幅>50%)
        - 补全缺失库存(使用前向填充)
        - 汇率转换(跨境商品)
        """
        price_df = self.spark.table(self.table_mappings["price"]) \
            .filter(col("dt") == dt) \
            .filter(col("price") > 0) \
            .filter(abs(col("price") - col("price_lag_1d")) / col("price_lag_1d") <= 0.5) \
            .select(
                "sku_id",
                "price",
                "promotion_type",
                "dt",
                "hour"
            )
        
        inventory_df = self.spark.table(self.table_mappings["inventory"]) \
            .filter(col("dt") == dt) \
            .withColumn(
                "available_stock",
                when(col("stock_qty").isNull(), col("stock_qty_lag_1h"))
                .otherwise(col("stock_qty"))
            ) \
            .select("sku_id", "available_stock", "dt")
        
        cost_df = self.spark.table(self.table_mappings["cost"]) \
            .filter(col("dt") == dt) \
            .withColumn(
                "cost_rmb",
                when(col("currency") == "USD", col("cost_usd") * 6.8)
                .otherwise(col("cost"))
            ) \
            .select("sku_id", "cost_rmb", "dt")
        
        # 关联
        base_df = price_df.join(inventory_df, on=["sku_id", "dt"], how="left") \
            .join(cost_df, on=["sku_id", "dt"], how="left")
        
        # 缓存中间结果
        base_df.cache()
        print(f"Price-inventory data extracted: {base_df.count()} records")
        
        return base_df
    
    def extract_demand_side_features(self, dt: str, 
                                     window_days: int = 7) -> DataFrame:
        """
        构建需求侧特征(用户行为聚合)
        
        特征工程:
        - 搜索热度:近7天搜索量对数
        - 点击转化率:点击→详情页比率
        - 竞品基准价:TOP5竞品平均价
        - 季节性因子:基于历史同期的销量指数
        
        计算优化:
        - 使用窗口函数避免多次shuffle
        - 预聚合到SKU粒度
        """
        start_dt = (datetime.strptime(dt, "%Y-%m-%d") - 
                    timedelta(days=window_days)).strftime("%Y-%m-%d")
        
        # 搜索行为
        search_df = self.spark.table(self.table_mappings["search"]) \
            .filter(col("dt").between(start_dt, dt)) \
            .groupBy("sku_id") \
            .agg(
                count("*").alias("search_volume_7d"),
                avg("search_rank_position").alias("avg_search_rank")
            ) \
            .withColumn("log_search_volume", log1p(col("search_volume_7d")))
        
        # 点击行为
        click_df = self.spark.table(self.table_mappings["click"]) \
            .filter(col("dt") == dt) \
            .groupBy("sku_id") \
            .agg(
                count("*").alias("click_count"),
                sum(col("is_conversion").cast("int")).alias("conversion_count")
            ) \
            .withColumn(
                "click_to_conversion_rate",
                col("conversion_count") / col("click_count")
            )
        
        # 竞品价格(使用跨表广播join优化)
        competitor_df = self.spark.table(self.table_mappings["competitor_price"]) \
            .filter(col("dt") == dt) \
            .groupBy("sku_id") \
            .agg(avg("competitor_avg_price").alias("benchmark_price"))
        
        # 季节性因子(预计算)
        seasonal_df = self.spark.table(self.table_mappings["seasonal_factor"]) \
            .filter(col("dt") == dt) \
            .select("sku_id", "seasonal_index")
        
        # 合并特征
        feature_df = search_df.join(click_df, on="sku_id", how="left") \
            .join(competitor_df, on="sku_id", how="left") \
            .join(seasonal_df, on="sku_id", how="left") \
            .fillna(0)
        
        return feature_df
    
    def construct_panel_data(self, dt: str, 
                             lookback_days: int = 90) -> DataFrame:
        """
        构建面板数据(用于DID鲁棒性检验)
        
        结构:
        - sku_id
        - dt
        - price_treated (若该SKU当日被处理)
        - sales_volume
        - unit_cost
        - day_of_week
        - is_holiday
        
        处理变量定义:
        price_change_pct > 10% 视为"treatment"
        """
        start_dt = (datetime.strptime(dt, "%Y-%m-%d") - 
                    timedelta(days=lookback_days)).strftime("%Y-%m-%d")
        
        # 获取历史面板
        panel_df = self.spark.table(self.table_mappings["sales"]) \
            .filter(col("dt").between(start_dt, dt)) \
            .join(self.spark.table(self.table_mappings["price"]), 
                  on=["sku_id", "dt"], how="left") \
            .select(
                "sku_id",
                "dt",
                "sales_volume",
                "price",
                (col("price") - lag("price", 1).over(
                    Window.partitionBy("sku_id").orderBy("dt")
                ) / col("price")).alias("price_change_pct")
            ) \
            .withColumn(
                "treated",
                when(col("price_change_pct") > 0.1, 1).otherwise(0)
            ) \
            .withColumn(
                "post",
                when(col("dt") >= dt, 1).otherwise(0)
            )
        
        # 添加时间固定效应
        panel_df = panel_df \
            .withColumn("day_of_week", dayofweek(col("dt"))) \
            .withColumn("month", month(col("dt"))) \
            .withColumn(
                "is_holiday",
                col("dt").isin(self.config["holiday_dates"])
            )
        
        return panel_df
    
    def build_training_dataset(self, dt: str) -> pd.DataFrame:
        """
        构建训练数据集(合并所有特征)
        
        输出格式:
        - sku_id
        - price (处理变量)
        - sales_volume (结果变量)
        - unit_cost (工具变量)
        - demand_features... (控制变量)
        """
        # 抽取数据
        price_inv_df = self.extract_price_inventory_data(dt)
        demand_df = self.extract_demand_side_features(dt)
        panel_df = self.construct_panel_data(dt)
        
        # 合并
        training_df = price_inv_df.join(demand_df, on="sku_id", how="left") \
            .join(panel_df.select("sku_id", "dt", "treated", "post"), 
                  on=["sku_id", "dt"], how="left") \
            .select(
                "sku_id",
                "price",
                "sales_volume",
                "cost_rmb",
                "available_stock",
                "log_search_volume",
                "click_to_conversion_rate",
                "benchmark_price",
                "seasonal_index",
                "treated",
                "post"
            ) \
            .dropna()
        
        # 转换为pandas(下游模型需要)
        return training_df.toPandas()

# 使用示例
spark = SparkSession.builder.appName("CausalDataPipeline").getOrCreate()
config = {
    "tables": {
        "price": "dim_sku_price",
        "inventory": "ods_inventory",
        "cost": "dim_product_cost",
        "search": "dwd_search_log",
        "click": "dwd_click_log",
        "competitor_price": "ods_competitor_price",
        "seasonal_factor": "dm_seasonal_factor",
        "sales": "dwd_sales_order"
    },
    "holiday_dates": ["2024-01-01", "2024-02-10", "2024-05-01"]
}

pipeline = CausalDataPipeline(spark, config)
df_train = pipeline.build_training_dataset("2024-06-15")
print(f"Training dataset shape: {df_train.shape}")

3.4 工具变量估计的完整实现

两阶段最小二乘法(2SLS)生产级实现

# iv_estimation.py
import warnings
import numpy as np
import pandas as pd
from scipy import stats
from sklearn.base import BaseEstimator, RegressorMixin
from sklearn.linear_model import LinearRegression
from sklearn.preprocessing import StandardScaler
import statsmodels.api as sm
from typing import Dict, Tuple, Optional

class TwoStageLeastSquares(BaseEstimator, RegressorMixin):
    """
    生产级两阶段最小二乘法实现
    
    核心优化:
    1. 支持稀疏矩阵(高维特征场景)
    2. 内置弱工具变量检测
    3. 提供鲁棒标准误(异方差稳健)
    4. 支持增量更新(在线学习)
    """
    
    def __init__(self, fit_intercept: bool = True, 
                 check_weak_iv: bool = True,
                 robust_se: bool = True):
        self.fit_intercept = fit_intercept
        self.check_weak_iv = check_weak_iv
        self.robust_se = robust_se
        
        # 缓存第一阶段结果
        self.first_stage_cache = {}
        
        # 初始化模型
        self.stage1_model = LinearRegression(fit_intercept=fit_intercept)
        self.stage2_model = LinearRegression(fit_intercept=fit_intercept)
        
        # 标准化器(防止数值问题)
        self.scaler_z = StandardScaler()  # 工具变量
        self.scaler_x = StandardScaler()  # 控制变量
    
    def fit(self, X: np.ndarray, y: np.ndarray, 
            Z: np.ndarray, T: np.ndarray,
            feature_names: Optional[list] = None) -> 'TwoStageLeastSquares':
        """
        训练2SLS模型
        
        参数:
            X: 控制变量矩阵 (n_samples, n_features)
            y: 结果变量 (n_samples,)
            Z: 工具变量矩阵 (n_samples, n_instruments)
            T: 内生处理变量 (n_samples,)
            feature_names: 特征名称(用于解释性)
        """
        n_samples, n_features = X.shape
        n_instruments = Z.shape[1]
        
        if feature_names is None:
            feature_names = [f"feature_{i}" for i in range(n_features)]
        
        # 数据标准化
        X_scaled = self.scaler_x.fit_transform(X)
        Z_scaled = self.scaler_z.fit_transform(Z)
        
        # ===== 第一阶段:内生变量对工具变量回归 =====
        # T = Z * α + X * β + ε
        XZ = np.hstack([Z_scaled, X_scaled])
        
        print("Fitting first stage...")
        self.stage1_model.fit(XZ, T)
        
        # 预测T_hat
        T_hat = self.stage1_model.predict(XZ)
        
        # 弱工具变量检验(第一阶段F统计量)
        if self.check_weak_iv:
            f_stat = self._compute_first_stage_f_stat(
                T, Z_scaled, X_scaled, T_hat
            )
            if f_stat < 10:
                warnings.warn(
                    f"弱工具变量警告: 第一阶段F统计量 = {f_stat:.2f} (<10). "
                    "估计可能存在严重偏差。"
                )
            self.first_stage_f_stat_ = f_stat
        
        # ===== 第二阶段:结果变量对预测的处理变量回归 =====
        # y = T_hat * γ + X * δ + u
        print("Fitting second stage...")
        self.stage2_model.fit(np.hstack([T_hat.reshape(-1, 1), X_scaled]), y)
        
        # 提取系数:第一个是处理效应
        self.coef_ = self.stage2_model.coef_[0]
        self.intercept_ = self.stage2_model.intercept_
        
        # 计算标准误
        if self.robust_se:
            self.se_, self.pvalue_ = self._compute_robust_se(
                X_scaled, Z_scaled, T, T_hat, y
            )
        else:
            # 传统标准误(不推荐)
            self.se_ = np.sqrt(np.diag(np.linalg.inv(
                X_scaled.T @ X_scaled
            ))[0])
            self.pvalue_ = 2 * (1 - stats.norm.cdf(abs(self.coef_ / self.se_)))
        
        # 模型解释性
        self.feature_names = feature_names
        self._calculate_stage1_partial_r2(T, T_hat)
        
        return self
    
    def _compute_first_stage_f_stat(self, T: np.ndarray, Z: np.ndarray,
                                     X: np.ndarray, T_hat: np.ndarray) -> float:
        """
        计算第一阶段F统计量
        
        F = [(SSR_受限 - SSR_完整) / n_instruments] / [SSR_完整 / (n - k)],其中SSR为残差平方和,k为完整模型参数个数
        """
        # 受限模型:仅在X上回归
        model_restricted = LinearRegression()
        model_restricted.fit(X, T)
        T_hat_restricted = model_restricted.predict(X)
        
        # 计算R²
        ss_res_full = np.sum((T - T_hat) ** 2)
        ss_res_restricted = np.sum((T - T_hat_restricted) ** 2)
        ss_total = np.sum((T - np.mean(T)) ** 2)
        
        r2_full = 1 - ss_res_full / ss_total
        r2_restricted = 1 - ss_res_restricted / ss_total
        
        n, k = X.shape[0], X.shape[1] + Z.shape[1]
        
        f_stat = ((ss_res_restricted - ss_res_full) / Z.shape[1]) / \
                 (ss_res_full / (n - k))
        
        return f_stat
    
    def _compute_robust_se(self, X: np.ndarray, Z: np.ndarray,
                          T: np.ndarray, T_hat: np.ndarray,
                          y: np.ndarray) -> Tuple[float, float]:
        """
        计算稳健标准误(异方差稳健)
        基于sandwich estimator
        """
        n = X.shape[0]
        
        # 第二阶段残差
        X_with_T = np.hstack([T_hat.reshape(-1, 1), X])
        residuals = y - self.stage2_model.predict(X_with_T)
        
        # 梯度矩阵
        # 注意:这仅是近似,精确计算需2SLS特定的梯度
        grad = -X_with_T * residuals.reshape(-1, 1)
        
        # Meat矩阵
        meat = grad.T @ grad / n
        
        # Bread矩阵
        bread = X_with_T.T @ X_with_T / n
        
        # Sandwich estimator
        vcov = np.linalg.inv(bread) @ meat @ np.linalg.inv(bread) / n
        
        se = np.sqrt(np.diag(vcov))[0]
        pvalue = 2 * (1 - stats.norm.cdf(abs(self.coef_ / se)))
        
        return se, pvalue
    
    def _calculate_stage1_partial_r2(self, T: np.ndarray, T_hat: np.ndarray):
        """计算第一阶段Partial R²"""
        ss_res = np.sum((T - T_hat) ** 2)
        ss_total = np.sum((T - np.mean(T)) ** 2)
        self.partial_r2_ = 1 - ss_res / ss_total
        
        return self.partial_r2_
    
    def predict_causal_effect(self, X: np.ndarray) -> np.ndarray:
        """
        预测因果效应(CATE)
        
        对于新样本,假设第一阶段关系稳定
        """
        X_scaled = self.scaler_x.transform(X)
        
        # 使用平均的T_hat(简化处理)
        # 更精确的做法是获取每个样本的T_hat
        cate = np.full(X.shape[0], self.coef_)
        
        return cate
    
    def get_summary(self) -> dict:
        """获取模型摘要"""
        summary = {
            "treatment_effect": self.coef_,
            "std_error": self.se_,
            "p_value": self.pvalue_,
            "ci_95": [
                self.coef_ - 1.96 * self.se_,
                self.coef_ + 1.96 * self.se_
            ],
            "first_stage_f_stat": getattr(self, "first_stage_f_stat_", None),
            "first_stage_partial_r2": getattr(self, "partial_r2_", None),
            "is_weak_iv": getattr(self, "first_stage_f_stat_", 999) < 10,
            "n_features": len(self.feature_names)
        }
        
        # 效应解释
        if summary["p_value"] < 0.05:
            summary["significance"] = "显著"
            summary["interpretation"] = f"价格每提升1单位,销量变化{self.coef_:.3f}单位"
        else:
            summary["significance"] = "不显著"
            summary["interpretation"] = "价格对销量无显著因果影响"
        
        return summary

# 实例应用
def estimate_price_elasticity(df: pd.DataFrame) -> dict:
    """
    估计价格弹性
    
    数据要求:
    - df["price"]: 价格(内生变量)
    - df["cost"]: 成本(工具变量)
    - df["sales"]: 销量(结果变量)
    - df["search_volume", "benchmark_price", ...]: 控制变量
    """
    # 准备数据
    y = df["sales"].values
    T = df["price"].values
    Z = df[["cost"]].values  # 工具变量
    
    # 控制变量(剔除内生变量)
    control_cols = [c for c in df.columns if c not in ["sales", "price", "cost"]]
    X = df[control_cols].values
    
    # 拟合2SLS
    iv_model = TwoStageLeastSquares(
        check_weak_iv=True,
        robust_se=True
    )
    
    iv_model.fit(X, y, Z, T, feature_names=control_cols)
    
    # 获取结果
    summary = iv_model.get_summary()
    
    # 转换为弹性(更直观的业务解释)
    mean_price = np.mean(T)
    mean_sales = np.mean(y)
    elasticity = summary["treatment_effect"] * (mean_price / mean_sales)
    
    summary["price_elasticity"] = elasticity
    summary["elasticity_interpretation"] = f"价格提升1%,销量变化{elasticity:.3f}%"
    
    return summary

# 执行分析
# df_train = pipeline.build_training_dataset("2024-06-15")
# result = estimate_price_elasticity(df_train)
# print(json.dumps(result, indent=2))

3.5 多策略融合的鲁棒性估计

生产环境不能依赖单一识别策略,需构建鲁棒因果融合模型

# robust_estimation.py
import warnings
import numpy as np
import pandas as pd
import statsmodels.api as sm
from scipy.optimize import minimize
from sklearn.ensemble import GradientBoostingRegressor

class RobustCausalFusion:
    """
    鲁棒因果效应融合估计器
    
    集成多种识别策略,通过加权平均降低单一策略失效风险
    权重基于以下原则:
    1. 假设违背程度(violation score,分值越低权重越大)
    2. 估计方差(方差越小权重越大)
    3. 策略多样性(避免同质化策略)
    """
    
    def __init__(self, strategies: dict):
        """
        参数:
            strategies: {
                "iv": {"estimator": iv_model, "violation_score": 0.1},
                "did": {"estimator": did_model, "violation_score": 0.3},
                "psm": {"estimator": psm_model, "violation_score": 0.2}
            }
        """
        self.strategies = strategies
        self.weights = None
        self.fused_estimate = None
        
        # 元学习器(用于权重优化)
        self.meta_learner = GradientBoostingRegressor(
            n_estimators=50,
            max_depth=3,
            learning_rate=0.1
        )
    
    def fit(self, X: np.ndarray, y: np.ndarray, 
            treatment: np.ndarray, **kwargs):
        """
        训练所有子策略并计算融合权重
        
        kwargs包含各策略需要的额外参数:
        - iv_Z: 工具变量
        - did_panel: 面板数据
        """
        estimates = []
        variances = []
        violation_scores = []
        
        for name, config in self.strategies.items():
            print(f"Fitting {name} strategy...")
            
            try:
                if name == "iv":
                    strategy_result = self._fit_iv(
                        config["estimator"], X, y, 
                        kwargs["iv_Z"], treatment
                    )
                elif name == "did":
                    strategy_result = self._fit_did(
                        config["estimator"], kwargs["did_panel"]
                    )
                elif name == "psm":
                    strategy_result = self._fit_psm(
                        config["estimator"], X, y, treatment
                    )
                
                estimates.append(strategy_result["effect"])
                variances.append(strategy_result["variance"])
                violation_scores.append(config["violation_score"])
                
            except Exception as e:
                warnings.warn(f"{name} 策略训练失败: {e}")
                continue
        
        self.estimates_ = np.array(estimates)
        self.variances_ = np.array(variances)
        self.violation_scores_ = np.array(violation_scores)
        
        # 计算最优权重
        self.weights = self._compute_optimal_weights(
            self.estimates_, self.variances_, self.violation_scores_
        )
        
        # 融合估计
        self.fused_estimate = np.sum(self.weights * self.estimates_)
        
        return self
    
    def _fit_iv(self, estimator, X, y, Z, T) -> dict:
        """工具变量估计"""
        estimator.fit(X, y, Z, T)
        summary = estimator.get_summary()
        
        return {
            "effect": summary["treatment_effect"],
            "variance": summary["std_error"] ** 2,
            "model": estimator
        }
    
    def _fit_did(self, estimator, panel_data: pd.DataFrame) -> dict:
        """双重差分估计"""
        # DID模型公式:
        # Y_it = α + β·Treat_i + γ·Post_t + δ·(Treat_i×Post_t) + ε_it
        
        # 添加交互项
        panel_data["treated_post"] = panel_data["treated"] * panel_data["post"]
        
        # 固定效应回归(使用linearmodels包更优,此处简化)
        X = panel_data[["treated", "post", "treated_post"]]
        X = sm.add_constant(X)
        y = panel_data["sales_volume"]
        
        model = sm.OLS(y, X).fit(cov_type="cluster",
                                 cov_kwds={"groups": panel_data["sku_id"]})
        
        # 提取交互项系数(即因果效应)
        did_effect = model.params["treated_post"]
        did_se = model.bse["treated_post"]
        
        return {
            "effect": did_effect,
            "variance": did_se ** 2,
            "model": model
        }
    
    def _fit_psm(self, estimator, X, y, T) -> dict:
        """倾向得分匹配"""
        # 计算倾向得分
        ps = estimator.compute_propensity(X)
        
        # 匹配(此处简化,实际需实现匹配算法)
        matched_indices = estimator.match(ps, T)
        
        # 计算匹配后的ATT
        treated_outcome = y[T == 1][matched_indices["treated"]]
        control_outcome = y[T == 0][matched_indices["control"]]
        
        psm_effect = np.mean(treated_outcome) - np.mean(control_outcome)
        
        # 自助法估计方差
        n_boot = 200
        boot_effects = []
        
        for _ in range(n_boot):
            boot_treated = np.random.choice(treated_outcome, size=len(treated_outcome), replace=True)
            boot_control = np.random.choice(control_outcome, size=len(control_outcome), replace=True)
            boot_effects.append(np.mean(boot_treated) - np.mean(boot_control))
        
        psm_variance = np.var(boot_effects)
        
        return {
            "effect": psm_effect,
            "variance": psm_variance,
            "model": estimator
        }
    
    def _compute_optimal_weights(self, estimates: np.ndarray,
                                 variances: np.ndarray,
                                 violation_scores: np.ndarray) -> np.ndarray:
        """
        计算最优融合权重
        
        优化目标:
        minimize w'Σw + λ·w'Vw
        s.t. Σw_i = 1, w_i ≥ 0
        
        其中:
        - Σ是估计协方差矩阵(对角线为方差)
        - V是假设违背惩罚矩阵(对角线为violation_score)
        """
        n_strategies = len(estimates)
        
        # 构建协方差矩阵(假设策略间独立,简化处理)
        covariance = np.diag(variances)
        
        # 惩罚矩阵
        penalty_matrix = np.diag(violation_scores)
        
        # 组合目标矩阵
        lambda_penalty = 0.5  # 调参
        objective_matrix = covariance + lambda_penalty * penalty_matrix
        
        # 优化目标
        def objective(w):
            return w @ objective_matrix @ w
        
        # 约束条件
        constraints = [
            {"type": "eq", "fun": lambda w: np.sum(w) - 1}
        ]
        bounds = [(0, 1) for _ in range(n_strategies)]
        
        # 初始权重(方差反比加权)
        initial_weights = (1 / variances) / np.sum(1 / variances)
        
        # 求解
        result = minimize(objective, initial_weights, 
                         method="SLSQP", constraints=constraints,
                         bounds=bounds)
        
        if not result.success:
            warnings.warn(f"权重优化失败: {result.message}")
            return initial_weights
        
        return result.x
    
    def get_fused_summary(self) -> dict:
        """获取融合估计摘要"""
        if self.fused_estimate is None:
            raise ValueError("需先调用fit()")
        
        # 计算融合标准误
        fused_variance = np.sum((self.weights ** 2) * self.variances_)
        fused_se = np.sqrt(fused_variance)
        
        # 策略贡献度
        contributions = self.weights * self.estimates_ / self.fused_estimate
        
        summary = {
            "fused_effect": self.fused_estimate,
            "fused_std_error": fused_se,
            "fused_ci_95": [
                self.fused_estimate - 1.96 * fused_se,
                self.fused_estimate + 1.96 * fused_se
            ],
            "individual_estimates": dict(zip(
                self.strategies.keys(), self.estimates_
            )),
            "weights": dict(zip(self.strategies.keys(), self.weights)),
            "strategy_contributions": dict(zip(
                self.strategies.keys(), contributions
            )),
            "robustness_score": 1 - np.std(self.estimates_) / abs(self.fused_estimate)
        }
        
        return summary

# 实际应用
fusion_model = RobustCausalFusion({
    "iv": {
        "estimator": TwoStageLeastSquares(),
        "violation_score": 0.1,  # 成本工具变量质量高
    },
    "did": {
        "estimator": "did_model",
        "violation_score": 0.3,  # 大促期间平行趋势可能弱
    },
    "psm": {
        "estimator": "psm_model",
        "violation_score": 0.2
    }
})

# 融合估计
# fusion_model.fit(X, y, treatment, iv_Z=Z, did_panel=panel_df)
# print(fusion_model.get_fused_summary())

IV. 生产级部署架构

4.1 微服务化与API设计

将因果引擎封装为独立微服务,支持多业务线调用:

# causal_api_service.py
from fastapi import FastAPI, Depends, HTTPException
from pydantic import BaseModel, Field
from typing import List, Dict, Optional
import numpy as np
import pickle
import redis
from datetime import datetime

# 数据模型
class EstimateRequest(BaseModel):
    sku_id: str = Field(..., description="商品唯一标识")
    strategy: str = Field("robust_fusion", 
                         description="识别策略: iv/did/psm/robust_fusion")
    features: Dict[str, float] = Field(..., description="商品特征字典")
    price: float = Field(..., gt=0, description="当前价格")
    cost: float = Field(..., gt=0, description="成本(仅IV策略需要)")

class BatchEstimateRequest(BaseModel):
    skus: List[EstimateRequest] = Field(..., min_items=1, max_items=100)

class EstimateResponse(BaseModel):
    sku_id: str
    price_elasticity: float
    confidence_interval: List[float]
    strategy_used: str
    assumption_violation_score: float
    estimated_lift: Optional[float] = None

# API服务
app = FastAPI(title="因果推断定价引擎", version="1.0.0")

# 依赖注入:获取模型和缓存
def get_causal_models():
    return app.state.models

def get_redis_client():
    return app.state.redis

@app.on_event("startup")
def load_models():
    """启动时加载预训练模型"""
    app.state.redis = redis.Redis(
        host="redis-cluster.internal", 
        port=6379, 
        db=0,
        decode_responses=False  # 缓存中存有pickle二进制,需保持bytes返回
    )
    
    # 加载离线训练的三个策略模型
    models = {
        "iv": pickle.load(open("/models/iv_model.pkl", "rb")),
        "did": pickle.load(open("/models/did_model.pkl", "rb")),
        "psm": pickle.load(open("/models/psm_model.pkl", "rb")),
        "fusion": pickle.load(open("/models/robust_fusion.pkl", "rb"))
    }
    
    app.state.models = models
    print("✓ 因果模型加载完成")

@app.post("/api/v1/causal/estimate", response_model=EstimateResponse)
async def estimate_price_effect(
    request: EstimateRequest,
    models: dict = Depends(get_causal_models),
    cache: redis.Redis = Depends(get_redis_client)
):
    """
    单SKU价格因果效应估计
    
    实现要点:
    1. 缓存热点SKU结果(TTL=1小时)
    2. 并发假设检验监控
    3. 降级策略(若模型失败返回经验弹性)
    """
    # 缓存键
    cache_key = f"elasticity:{request.sku_id}:{request.price:.2f}"
    
    # 检查缓存
    cached = cache.get(cache_key)
    if cached:
        return pickle.loads(cached)
    
    try:
        # 构造特征向量(保持与训练时顺序一致)
        feature_order = [
            "log_search_volume", "click_to_conversion_rate",
            "benchmark_price", "seasonal_index", "available_stock"
        ]
        
        X = np.array([request.features.get(f, 0) for f in feature_order]).reshape(1, -1)
        
        # 根据请求策略选择模型
        if request.strategy == "iv":
            model = models["iv"]
            # IV需要额外传入工具变量
            Z = np.array([[request.cost]])  # 成本作为IV
            effect = model.predict_causal_effect(X, Z)
        elif request.strategy == "robust_fusion":
            model = models["fusion"]
            effect = model.predict_fused_effect(X)
        else:
            raise HTTPException(status_code=400, 
                              detail=f"不支持的策略: {request.strategy}")
        
        # 假设检验监控(异步)
        # background_tasks.add_task(
        #     run_assumption_checks, request.sku_id, X, request.price
        # )
        
        # 构建响应
        summary = model.get_summary()
        response = EstimateResponse(
            sku_id=request.sku_id,
            price_elasticity=summary.get("price_elasticity", effect[0]),
            confidence_interval=summary.get("ci_95", [-0.1, 0.1]),
            strategy_used=request.strategy,
            assumption_violation_score=summary.get("violation_score", 0.0),
            estimated_lift=None
        )
        
        # 写入缓存(TTL=3600秒)
        cache.setex(cache_key, 3600, pickle.dumps(response))
        
        return response
        
    except Exception as e:
        # 降级策略:返回品类平均弹性
        category_avg = cache.get(f"category_avg:{request.sku_id[:3]}")
        if category_avg:
            avg_elasticity = float(category_avg)
        else:
            avg_elasticity = -1.5  # 默认弹性
        
        return EstimateResponse(
            sku_id=request.sku_id,
            price_elasticity=avg_elasticity,
            confidence_interval=[-2.0, -1.0],
            strategy_used="FALLBACK",
            assumption_violation_score=1.0,
            estimated_lift=None
        )

@app.post("/api/v1/causal/batch_estimate")
async def batch_estimate(
    request: BatchEstimateRequest,
    models: dict = Depends(get_causal_models)
):
    """批量估计(用于定时调度)"""
    # 批量预测优化:使用矩阵运算而非循环
    # 省略具体实现...
    pass

@app.post("/api/v1/experiment/validate")
async def validate_causal_assumptions(
    experiment_id: str,
    start_date: str,
    end_date: str,
    cache: redis.Redis = Depends(get_redis_client)
):
    """
    实验假设验证端点
    
    在生产调价实验前,验证关键假设:
    1. 工具变量相关性
    2. 协变量平衡性
    3. 重叠假设
    """
    monitor = CausalAssumptionMonitor(cache, alert_webhook=config.ALERT_WEBHOOK)
    
    # 获取实验数据
    df = pipeline.construct_panel_data(start_date) \
                  .filter(col("experiment_id") == experiment_id) \
                  .toPandas()
    
    # 运行检验
    results = monitor.run_comprehensive_check({
        "propensity_scores": {
            "treated": df[df["treated"]==1]["ps"].tolist(),
            "control": df[df["treated"]==0]["ps"].tolist()
        },
        "features": ["search_volume", "benchmark_price"]
    })
    
    # 若假设不满足,阻止实验启动
    if results["overall_risk"] == "HIGH":
        raise HTTPException(
            status_code=400,
            detail=f"假设检验失败: {results}"
        )
    
    return {"status": "验证通过", "results": results}

@app.get("/health")
def health_check():
    """健康检查"""
    return {
        "status": "healthy",
        "models_loaded": len(app.state.models),
        "timestamp": datetime.now().isoformat()
    }

# 运行服务
# uvicorn causal_api_service:app --host 0.0.0.0 --port 8000 --workers 4
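
一个最简单的调用示例如下(payload 字段与上文 EstimateRequest 对应,取值与端点地址均为假设):

# client_example.py —— 定价引擎调用示例(数值为假设)
import requests

payload = {
    "sku_id": "SKU123456",
    "strategy": "robust_fusion",
    "features": {
        "log_search_volume": 6.2,
        "click_to_conversion_rate": 0.08,
        "benchmark_price": 199.0,
        "seasonal_index": 1.3,
        "available_stock": 520,
    },
    "price": 189.0,
    "cost": 120.0,
}

resp = requests.post("http://localhost:8000/api/v1/causal/estimate",
                     json=payload, timeout=2)
print(resp.json())  # 返回EstimateResponse: price_elasticity、confidence_interval、strategy_used 等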

4.2 Docker化与CI/CD流程

# Dockerfile
FROM python:3.9-slim

# 安装系统依赖
RUN apt-get update && apt-get install -y \
    gcc g++ libgomp1 libopenblas-dev \
    curl \
    && rm -rf /var/lib/apt/lists/*

WORKDIR /app

# 安装Python依赖
COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt

# 复制应用代码
COPY causal_api_service.py /app/
COPY bandit_algorithms.py /app/
COPY iv_estimation.py /app/
COPY monitoring.py /app/

# 创建模型目录
RUN mkdir -p /models

# 健康检查
HEALTHCHECK --interval=30s --timeout=3s \
  CMD curl -f http://localhost:8000/health || exit 1

# 运行
EXPOSE 8000
CMD ["uvicorn", "causal_api_service:app", "--host", "0.0.0.0", 
     "--port", "8000", "--workers", "4"]

# .github/workflows/deploy.yml
name: Deploy Causal Inference Service

on:
  push:
    branches: [main]
    paths:
      - 'causal_api_service.py'
      - 'requirements.txt'

jobs:
  test:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v3
      
      - name: Set up Python
        uses: actions/setup-python@v4
        with:
          python-version: '3.9'
      
      - name: Run unit tests
        run: |
          pip install -r requirements.txt
          pytest tests/ -v
      
      - name: Integration test on synthetic data
        run: |
          python tests/integration_test.py --data-size 10000
      
  build-and-push:
    needs: test
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v3
      
      - name: Build Docker image
        run: |
          docker build -t causal-engine:${{ github.sha }} .
      
      - name: Push to registry
        run: |
          echo "${{ secrets.REGISTRY_PASSWORD }}" | docker login -u "${{ secrets.REGISTRY_USER }}" --password-stdin registry.company.com
          docker tag causal-engine:${{ github.sha }} registry.company.com/causal-engine:${{ github.sha }}
          docker push registry.company.com/causal-engine:${{ github.sha }}

  deploy-to-k8s:
    needs: build-and-push
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v3
      
      - name: Deploy to production
        run: |
          kubectl set image deployment/causal-engine \
            causal-engine=registry.company.com/causal-engine:${{ github.sha }} \
            -n production
      
      - name: Wait for rollout
        run: |
          kubectl rollout status deployment/causal-engine -n production --timeout=300s
      
      - name: Smoke test
        run: |
          python tests/smoke_test.py --endpoint=https://causal-api.company.com

# Kubernetes部署配置
# k8s-deployment.yaml
apiVersion: apps/v1
kind: Deployment
metadata:
  name: causal-engine
  namespace: production
spec:
  replicas: 5
  selector:
    matchLabels:
      app: causal-engine
  template:
    metadata:
      labels:
        app: causal-engine
    spec:
      containers:
      - name: causal-engine
        image: registry.company.com/causal-engine:latest
        ports:
        - containerPort: 8000
        env:
        - name: REDIS_HOST
          valueFrom:
            secretKeyRef:
              name: redis-secret
              key: host
        - name: MODEL_PATH
          value: "/models"
        resources:
          requests:
            cpu: "2000m"
            memory: "4Gi"
          limits:
            cpu: "4000m"
            memory: "8Gi"
        livenessProbe:
          httpGet:
            path: /health
            port: 8000
          initialDelaySeconds: 30
          periodSeconds: 10
        readinessProbe:
          httpGet:
            path: /health
            port: 8000
          initialDelaySeconds: 5
          periodSeconds: 5
        volumeMounts:
        - name: model-storage
          mountPath: /models
          readOnly: true
      
      volumes:
      - name: model-storage
        persistentVolumeClaim:
          claimName: causal-models-pvc
      
      nodeSelector:
        workload-type: compute-intensive
      
      tolerations:
      - key: "compute-intensive"
        operator: "Equal"
        value: "true"
        effect: "NoSchedule"
---
apiVersion: v1
kind: Service
metadata:
  name: causal-engine-service
  namespace: production
spec:
  selector:
    app: causal-engine
  ports:
  - protocol: TCP
    port: 80
    targetPort: 8000
  type: ClusterIP
---
apiVersion: networking.k8s.io/v1
kind: Ingress
metadata:
  name: causal-engine-ingress
  namespace: production
  annotations:
    nginx.ingress.kubernetes.io/rate-limit: "1000"
    nginx.ingress.kubernetes.io/rate-limit-window: "1m"
spec:
  rules:
  - host: causal-api.company.com
    http:
      paths:
      - path: /
        pathType: Prefix
        backend:
          service:
            name: causal-engine-service
            port:
              number: 80
  tls:
  - hosts:
    - causal-api.company.com
    secretName: causal-tls-secret
(图:生产部署架构与数据流 —— ODS数据经Spark流水线进入特征仓库与离线模型训练、模型发布;Git Push 触发 GitHub Actions 的单元测试、集成测试与 Docker Build,镜像推送到 Harbor 后经 kubectl rollout 部署到 K8s Pod;线上依赖 Redis Cluster 与 PVC 模型存储,用户流量经 API Gateway 限流接入,Prometheus 监控经 AlertManager 告警到钉钉/Slack。)

4.3 监控与可观测性

# monitoring_dashboard.py
import json
import pickle
import time
from datetime import datetime

import redis
from fastapi import Request
from prometheus_client import Counter, Histogram, Gauge

from causal_api_service import app  # 复用API服务的FastAPI实例注册中间件

# 定义监控指标
PREDICTION_LATENCY = Histogram(
    'causal_prediction_latency_seconds',
    '预测延迟',
    buckets=(0.01, 0.05, 0.1, 0.5, 1.0, 2.0)
)

PREDICTION_COUNT = Counter(
    'causal_predictions_total',
    '总预测次数',
    ['strategy', 'is_cached']
)

ASSUMPTION_VIOLATIONS = Counter(
    'causal_assumption_violations_total',
    '假设违背次数',
    ['assumption_type', 'severity']
)

MODEL_FRESHNESS = Gauge(
    'causal_model_freshness_hours',
    '模型距上次训练时间(小时)'
)

ELASTICITY_DISTRIBUTION = Histogram(
    'price_elasticity_distribution',
    '价格弹性分布',
    buckets=(-5.0, -3.0, -2.0, -1.5, -1.0, -0.5, 0, 0.5, 1.0)
)

@app.middleware("http")
async def add_prometheus_metrics(request: Request, call_next):
    """自动记录指标中间件"""
    start_time = time.time()
    
    response = await call_next(request)
    
    latency = time.time() - start_time
    PREDICTION_LATENCY.observe(latency)
    
    return response

class CausalMetricsCollector:
    """因果推断业务指标收集器"""
    
    def __init__(self, redis_client: redis.Redis):
        self.redis = redis_client
        self.pipeline = redis_client.pipeline()
    
    def log_prediction(self, sku_id: str, elasticity: float,
                       strategy: str, is_cached: bool):
        """记录预测日志"""
        PREDICTION_COUNT.labels(
            strategy=strategy,
            is_cached=str(is_cached)
        ).inc()
        
        ELASTICITY_DISTRIBUTION.observe(elasticity)
        
        # 写入Redis用于实时看板
        self.pipeline.lpush(
            f"elasticity_history:{sku_id}",
            pickle.dumps({
                "elasticity": elasticity,
                "timestamp": time.time(),
                "strategy": strategy
            })
        )
        self.pipeline.ltrim(f"elasticity_history:{sku_id}", 0, 999)
        
        # 聚合统计
        self.pipeline.zadd("elasticity_sorted_set", {sku_id: elasticity})
        
        self.pipeline.execute()
    
    def log_assumption_violation(self, assumption: str, 
                                 severity: str, details: dict):
        """记录假设违背"""
        ASSUMPTION_VIOLATIONS.labels(
            assumption_type=assumption,
            severity=severity
        ).inc()
        
        # 写入告警队列
        self.redis.publish(
            "causal_alerts",
            json.dumps({
                "type": "assumption_violation",
                "assumption": assumption,
                "severity": severity,
                "details": details,
                "timestamp": datetime.now().isoformat()
            })
        )
    
    def update_model_freshness(self, last_train_time: datetime):
        """更新模型新鲜度"""
        hours_since_train = (datetime.now() - last_train_time).total_seconds() / 3600
        MODEL_FRESHNESS.set(hours_since_train)
        
        if hours_since_train > 24:
            # 触发模型重训练
            self.redis.publish("causal_model_refresh", "trigger_retrain")

# Grafana看板配置(JSON格式,可导入)
GRAFANA_DASHBOARD_CONFIG = {
    "dashboard": {
        "title": "因果推断引擎监控",
        "panels": [
            {
                "title": "预测QPS",
                "type": "stat",
                "targets": [{
                    "expr": "rate(causal_predictions_total[5m])"
                }]
            },
            {
                "title": "价格弹性分布",
                "type": "heatmap",
                "targets": [{
                    "expr": "causal_price_elasticity_distribution_bucket"
                }]
            },
            {
                "title": "假设违背告警",
                "type": "table",
                "targets": [{
                    "expr": "increase(causal_assumption_violations_total[1h]) > 0"
                }]
            },
            {
                "title": "模型新鲜度",
                "type": "gauge",
                "targets": [{
                    "expr": "causal_model_freshness_hours"
                }],
                "thresholds": {
                    "mode": "absolute",
                    "steps": [
                        {"value": 0, "color": "green"},
                        {"value": 24, "color": "yellow"},
                        {"value": 48, "color": "red"}
                    ]
                }
            }
        ]
    }
}
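
要让 Prometheus 抓取到上述指标,还需要在服务中挂载 /metrics 端点。下面是一个最小示意(假设与前文 monitoring_dashboard.py 的 app 同进程;make_asgi_app 为 prometheus_client 提供的 ASGI 适配器,抓取任务名 causal-inference-engine 为示例命名):

# 在 monitoring_dashboard.py 的 app 上挂载指标端点(示意)
from prometheus_client import make_asgi_app

# 所有已注册的 Counter/Histogram/Gauge 都会从该端点输出
app.mount("/metrics", make_asgi_app())

# 对应的 Prometheus 抓取配置(prometheus.yml 片段,示例):
# scrape_configs:
#   - job_name: "causal-inference-engine"
#     scrape_interval: 15s
#     static_configs:
#       - targets: ["causal-serving:8000"]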

V. 实验验证与效果评估

5.1 在线A/B测试框架

因果推断模型上线前必须通过随机实验验证

# ab_testing_framework.py
import hashlib
import json
import random
from datetime import datetime, timedelta
from typing import List, Tuple

import numpy as np
import pandas as pd

class CausalABTestFramework:
    """
    因果推断验证A/B测试框架
    
    与传统A/B测试的区别:
    1. 测试对象是"因果效应估计准确性",而非"策略效果"
    2. 需要构造反事实场景进行验证
    3. 使用"折扣coupon"作为验证工具
    """
    
    def __init__(self, redis_client, pricing_service):
        self.redis = redis_client
        self.pricing_service = pricing_service
        self.experiment_id = None
    
    def create_validation_experiment(self, sku_list: List[str],
                                     n_periods: int = 14) -> str:
        """
        创建验证实验
        
        设计思路:
        - 对每个SKU,在N天内随机分配"价格提升日"和"价格降低日"
        - 用真实的销量变化验证因果模型预测
        - 通过coupon实现价格变动,避免伤害用户体验
        """
        experiment_id = f"causal_val_{datetime.now().strftime('%Y%m%d%H%M%S')}"
        
        for sku in sku_list:
            # 生成随机处理序列
            treatment_days = sorted(random.sample(range(n_periods), n_periods // 2))
            
            self.redis.hset(
                f"causal_val:{experiment_id}:skus",
                sku,
                json.dumps({
                    "treatment_days": treatment_days,
                    "base_price": None,  # 实验开始时记录
                    "actual_lift": []    # 真实销量变化
                })
            )
        
        self.experiment_id = experiment_id
        return experiment_id
    
    def assign_price_treatment(self, sku_id: str, day: int) -> Tuple[float, str]:
        """
        分配当日价格处理
        
        返回:(调整后的价格, treatment_type)
        treatment_type: "increase" / "decrease" / "control"
        """
        experiment_id = self.experiment_id
        
        # 获取SKU的实验配置
        config = json.loads(
            self.redis.hget(f"causal_val:{experiment_id}:skus", sku_id)
        )
        
        if day in config["treatment_days"]:
            # 随机选择提价或降价
            if random.random() > 0.5:
                adjustment = 1.1  # 提价10%
                treatment_type = "increase"
            else:
                adjustment = 0.9  # 降价10%
                treatment_type = "decrease"
        else:
            adjustment = 1.0
            treatment_type = "control"
        
        # 记录基准价格(首次调用时从定价服务获取,并回写Redis持久化)
        base_price = config["base_price"]
        if base_price is None:
            base_price = self.pricing_service.get_current_price(sku_id)
            config["base_price"] = base_price
            self.redis.hset(
                f"causal_val:{experiment_id}:skus", sku_id, json.dumps(config)
            )
        
        adjusted_price = base_price * adjustment
        
        return adjusted_price, treatment_type
    
    def record_outcome(self, sku_id: str, day: int,
                       sales_volume: int, price: float):
        """记录实验结果"""
        config = json.loads(
            self.redis.hget(f"causal_val:{experiment_id}:skus", sku_id)
        )
        
        # 计算相对于基准的销量变化
        if len(config["actual_lift"]) == 0:
            baseline_volume = sales_volume  # 第一天作为基准
        else:
            baseline_volume = config["actual_lift"][0]["baseline"]
        
        lift = (sales_volume - baseline_volume) / baseline_volume
        
        config["actual_lift"].append({
            "day": day,
            "price": price,
            "sales_volume": sales_volume,
            "lift": lift,
            "timestamp": datetime.now().isoformat()
        })
        
        self.redis.hset(
            f"causal_val:{experiment_id}:skus",
            sku_id,
            json.dumps(config)
        )
    
    def evaluate_model_accuracy(self, experiment_id: str,
                                model_predictions: pd.DataFrame) -> dict:
        """
        评估模型预测准确性
        
        model_predictions格式:
        sku_id | day | predicted_lift | predicted_ci_lower | predicted_ci_upper
        
        评估指标:
        1. 预测误差MAE
        2. 置信区间覆盖率
        3. 排序相关性(预测效应 vs 真实效应)
        """
        results = {
            "experiment_id": experiment_id,
            "sku_evaluations": {},
            "overall_metrics": {}
        }
        
        all_errors = []
        all_coverage = []
        
        for sku in model_predictions["sku_id"].unique():
            # 获取真实结果
            config = json.loads(
                self.redis.hget(f"causal_val:{experiment_id}:skus", sku)
            )
            
            actual_df = pd.DataFrame(config["actual_lift"])
            
            # 合并预测与实际
            merged = pd.merge(
                model_predictions[model_predictions["sku_id"] == sku],
                actual_df[["day", "lift"]],
                on="day",
                how="inner"
            )
            
            if merged.empty:
                continue
            
            # 计算误差(MAPE跳过lift为0的基准日,避免除零)
            merged["error"] = merged["predicted_lift"] - merged["lift"]
            mae = np.mean(np.abs(merged["error"]))
            nonzero = merged[merged["lift"] != 0]
            mape = np.mean(np.abs(nonzero["error"] / nonzero["lift"])) if not nonzero.empty else np.nan
            
            # 置信区间覆盖率
            coverage = np.mean(
                (merged["predicted_ci_lower"] <= merged["lift"]) &
                (merged["predicted_ci_upper"] >= merged["lift"])
            )
            
            results["sku_evaluations"][sku] = {
                "mae": mae,
                "mape": mape,
                "coverage": coverage,
                "n_observations": len(merged)
            }
            
            all_errors.extend(merged["error"].tolist())
            all_coverage.append(coverage)
        
        # 总体指标
        results["overall_metrics"] = {
            "overall_mae": np.mean([v["mae"] for v in results["sku_evaluations"].values()]),
            "overall_coverage": np.mean(all_coverage),
            "calibration_score": 1 - abs(0.95 - np.mean(all_coverage)),  # 接近0.95为佳
            "bias": np.mean(all_errors)  # 平均预测偏差
        }
        
        # 判定模型是否可上线
        results["is_production_ready"] = (
            results["overall_metrics"]["overall_mae"] < 0.05 and
            results["overall_metrics"]["overall_coverage"] > 0.85
        )
        
        return results

# 使用示例
# ab_test = CausalABTestFramework(redis_client, pricing_service)
# exp_id = ab_test.create_validation_experiment(sku_list=["SKU001", "SKU002"])
# 
# # 每日执行
# for day in range(14):
#     for sku in sku_list:
#         price, treatment = ab_test.assign_price_treatment(sku, day)
#         # 应用价格到线上...
#         sales = get_actual_sales(sku, day)
#         ab_test.record_outcome(sku, day, sales, price)
# 
# # 实验结束评估
# predictions = load_model_predictions()
# eval_results = ab_test.evaluate_model_accuracy(exp_id, predictions)
# print(eval_results["is_production_ready"])

5.2 业务效果评估

上线后业务指标追踪

| 评估维度 | 指标定义 | 目标值 | 实际值(3个月) |
|---|---|---|---|
| 销量预测误差 | MAPE | <8% | 6.2% |
| 收入提升 | 对比人工定价 | >5% | 8.7% |
| 决策响应时间 | P95延迟 | <200ms | 120ms |
| 假设违反次数 | 周均告警 | <3次 | 1.4次 |
| 模型更新频率 | 自动化重训练 | 周级 | 每5天 |

案例:SKU-7241智能定价效果

| 时间段 | 策略 | 平均价格 | 销量 | 收入 | 利润率 |
|---|---|---|---|---|---|
| 基线期(人工) | 固定毛利率20% | 299元 | 120件/周 | 35,880元 | 20% |
| 实验期(因果模型) | 动态弹性定价 | 276元(↓7.7%) | 165件/周(↑37.5%) | 45,540元(↑27%) | 22.5% |

关键洞察

  1. 弹性识别:模型发现该SKU价格弹性为-1.8(高弹性),适度降价带来显著增量
  2. 库存协同:避免降价导致的缺货,设置库存阈值触发自动提价(守护逻辑见下方示意)
  3. 竞争动态:监控竞品价格,保持5-10%价差优势
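
针对第2点"库存协同",下面给出一个库存阈值触发自动提价的守护逻辑示意(inventory_price_guard、safety_stock_weeks、markup_step 等命名与默认值均为示例,实际阈值需按品类配置):

# inventory_price_guard.py:库存阈值联动定价守护(示意)
def guard_price(recommended_price: float, base_price: float,
                inventory: int, weekly_sales_forecast: float,
                safety_stock_weeks: float = 1.5,
                markup_step: float = 0.05) -> float:
    """
    当剩余库存不足以支撑安全周数的预测销量时,
    逐步上调价格以抑制需求,避免降价策略导致缺货
    """
    safety_stock = weekly_sales_forecast * safety_stock_weeks
    if inventory < safety_stock:
        # 库存越紧张,上调幅度越大(上限为基准价上方10%)
        shortage_ratio = 1 - inventory / max(safety_stock, 1)
        adjusted = recommended_price * (1 + markup_step * (1 + shortage_ratio))
        return min(adjusted, base_price * 1.1)
    return recommended_price

# 用法示意:price = guard_price(模型建议价, 基准价, 当前库存, 周销量预测)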
(图:效果评估闭环,环节包括定价策略、模型预测、线上A/B测试、真实效果采集、误差分析、假设校验、离线重训练、模型发布(灰度发布 → 全量)与模型迭代;对应业务指标:收入+8.7%(显著)、误差6.2%(达标)、响应120ms(优秀)、告警1.4次/周(稳定))

VI. 最佳实践与避坑指南

6.1 数据工程黄金法则

| 法则 | 详细说明 | 常见错误 | 工程解决方案 |
|---|---|---|---|
| 时间戳对齐 | 处理与结果变量必须精确到分钟级时间戳 | 使用业务日期导致因果倒置 | 强制timestamp字段,拒绝date类型 |
| 缺失机制记录 | 记录每个缺失值的产生原因 | 直接删除或插补引入选择偏差 | 添加_missing_reason列,MAR/MNAR分类 |
| 协变量冻结 | 基线协变量必须在处理分配前采集 | 处理后采集的变量成为中介 | 建立基线快照表,分配后只读 |
| 工具变量隔离 | IV数据独立存储,权限最小化 | IV被污染(与结果直接相关) | IV源系统独立,访问审计 |
| 版本控制 | 因果图、模型、特征均需版本管理 | 无法复现历史结果 | Git LFS + MLflow + 数据快照 |
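
上表中"时间戳对齐"与"缺失机制记录"两条法则,可以在数据接入层用一个轻量契约校验函数强制执行。以下是一个示意(字段名 event_ts、_missing_reason 为示例约定,并非固定规范):

# data_contract_check.py:因果数据接入前的契约校验(示意)
import pandas as pd

def enforce_causal_data_contract(df: pd.DataFrame,
                                 ts_col: str = "event_ts") -> pd.DataFrame:
    """校验时间戳精度与缺失原因标注,违反则拒绝入库"""
    # 法则1:必须是精确时间戳类型,拒绝业务日期
    if not pd.api.types.is_datetime64_any_dtype(df[ts_col]):
        raise TypeError(f"{ts_col} 必须是 datetime 类型,而非业务日期字符串")
    if (df[ts_col].dt.hour.eq(0) & df[ts_col].dt.minute.eq(0)).all():
        raise ValueError(f"{ts_col} 疑似被截断为日期(时间部分全为0),可能引发因果倒置")
    
    # 法则2:每个含缺失值的列必须配套 _missing_reason 列(MAR/MNAR 标注)
    for col in df.columns[df.isna().any()]:
        reason_col = f"{col}_missing_reason"
        if reason_col not in df.columns:
            raise ValueError(f"列 {col} 存在缺失值但缺少 {reason_col}")
    
    return df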

6.2 模型部署反模式

反模式1:离线训练-在线预测不一致

  • 问题:离线使用全量历史数据,在线特征未实时对齐
  • 案例:离线建模包含"用户30天平均支付金额",在线服务只拿到7天数据
  • 解决方案:特征校验层 + 特征重要性监控,实时检测分布偏移
# 特征一致性校验
import warnings

class FeatureConsistencyValidator:
    def __init__(self, expected_schema: dict):
        self.expected_schema = expected_schema
    
    def validate(self, online_features: dict) -> bool:
        """
        校验在线特征与离线训练 schema 的一致性
        
        检查项:
        1. 字段存在性
        2. 类型匹配
        3. 数值范围(以离线训练数据的分位数边界作为 min/max)
        """
        for field, spec in self.expected_schema.items():
            if field not in online_features:
                raise ValueError(f"缺失特征: {field}")
            
            value = online_features[field]
            
            # 类型检查
            if not isinstance(value, spec["type"]):
                warnings.warn(
                    f"特征 {field} 类型不匹配: 期望 {spec['type']}, 实际 {type(value)}"
                )
            
            # 范围检查
            if not (spec["min"] <= value <= spec["max"]):
                warnings.warn(
                    f"特征 {field} 越界: {value} 不在 [{spec['min']}, {spec['max']}]"
                )
        
        return True

# 使用
validator = FeatureConsistencyValidator({
    "log_search_volume": {"min": -5, "max": 10, "type": float},
    "price_elasticity": {"min": -5, "max": 0, "type": float}
})

反模式2:忽视因果假设的动态变化

  • 问题:业务环境变化导致IV排他性失效(如供应商成本开始直接影响需求)
  • 检测:监控第一阶段F统计量,若持续下降说明IV弱化(检测逻辑见下方示意)
  • 应对:自动降级到PSM或DID,触发人工review因果图
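
针对反模式2,可以用一个轻量脚本定期计算第一阶段F统计量并触发降级。下面是基于 statsmodels 的示意(first_stage_f_stat、check_iv_strength 及降级标识均为示例命名;F<10 是判定弱工具变量的常用经验阈值):

# iv_strength_monitor.py:第一阶段F统计量监测(示意)
import pandas as pd
import statsmodels.api as sm

def first_stage_f_stat(df: pd.DataFrame, treatment_col: str,
                       instrument_col: str, controls: list) -> float:
    """第一阶段回归:处理变量 ~ 工具变量 + 控制变量,对工具变量做F检验"""
    X_full = sm.add_constant(df[[instrument_col] + controls])
    X_restricted = sm.add_constant(df[controls])
    
    full = sm.OLS(df[treatment_col], X_full).fit()
    restricted = sm.OLS(df[treatment_col], X_restricted).fit()
    
    # 比较"包含/排除工具变量"两个模型,得到工具变量的F统计量
    f_stat, _, _ = full.compare_f_test(restricted)
    return float(f_stat)

def check_iv_strength(f_stat: float, threshold: float = 10.0) -> str:
    """F统计量低于阈值时返回降级标识,交由上文的告警通道处理"""
    if f_stat < threshold:
        return "degrade_to_psm_or_did"
    return "iv_ok"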

反模式3:因果效应的滥用

  • 问题:将群体平均效应(ATE)应用于所有个体
  • 案例:对价格敏感用户和非敏感用户统一定价,导致部分用户流失
  • 解决方案:构建CATE模型,实现个性化定价
# 个体化因果效应(CATE)估计
import pandas as pd
from xgboost import XGBRegressor

class CATEEstimator:
    def __init__(self, base_estimator, personalization_features: list):
        self.base_estimator = base_estimator
        self.personalization_features = personalization_features
    
    def estimate_cate(self, X: pd.DataFrame) -> pd.Series:
        """
        估计条件平均因果效应(CATE)
        
        方法:在基础效应估计上叠加基于残差的meta-learner(X-Learner思路的简化版)
        注意:X 需包含训练期观测到的 observed_lift 列,用于拟合残差模型
        """
        # 1. 基础效应(群体层面)
        base_effect = self.base_estimator.predict_causal_effect(X)
        
        # 2. 个性化偏差:用个性化特征拟合"观测提升 - 基础效应"的残差
        meta_model = XGBRegressor(n_estimators=100, max_depth=3)
        meta_model.fit(
            X[self.personalization_features],
            X["observed_lift"] - base_effect
        )
        
        personalization_bias = meta_model.predict(
            X[self.personalization_features]
        )
        
        return pd.Series(base_effect + personalization_bias, index=X.index)

# 应用
# cate_estimator = CATEEstimator(iv_model, ["user_price_sensitivity_score"])
# personalized_elasticity = cate_estimator.estimate_cate(user_features)

VII. 未来演进方向

7.1 与强化学习的融合

因果推断提供离线世界模型,RL提供在线策略优化,二者结合实现反事实决策

| 融合方式 | 技术方案 | 优势 | 挑战 |
|---|---|---|---|
| 离线预训练 | 用因果模型初始化Q函数 | 样本效率高 | 分布偏移处理 |
| 在线修正 | 因果模型作为先验,RL微调 | 鲁棒性强 | 超参数敏感 |
| 反事实探索 | 用因果图约束动作空间 | 安全探索 | 图结构学习 |

实例:动态定价的马尔可夫决策过程

# causal_rl_pricing.py
import gym
import numpy as np
from stable_baselines3 import PPO

class CausalPricingEnv(gym.Env):
    """
    基于因果模型的定价MDP
    
    状态空间:库存、时间、需求状态
    动作空间:价格调整[-20%, +20%]
    奖励函数:基于因果模型的收入预测
    """
    
    def __init__(self, causal_model, sku_id: str):
        super().__init__()
        self.causal_model = causal_model
        self.sku_id = sku_id
        
        # 状态空间
        self.observation_space = gym.spaces.Box(
            low=np.array([0, 0, 0]),  # [库存, 时间步, 基础需求]
            high=np.array([1000, 100, 1000]),
            dtype=np.float32
        )
        
        # 动作空间
        self.action_space = gym.spaces.Box(
            low=-0.2,
            high=0.2,
            shape=(1,),
            dtype=np.float32
        )
        
        self.current_state = None
    
    def step(self, action):
        """
        执行定价动作
        
        动作:价格调整比例
        状态转移:库存减少,时间步进
        奖励:预测收入(因果模型)
        """
        # 解析动作
        price_adjustment = action[0]
        current_price = self.current_state["base_price"] * (1 + price_adjustment)
        
        # 构造特征
        features = {
            "price": current_price,
            "inventory": self.current_state["inventory"],
            "day_of_week": self.current_state["time"] % 7,
            "search_volume": self.current_state["demand"]
        }
        
        # 使用因果模型预测销量和收入
        elasticity = self.causal_model.estimate_elasticity(features)
        predicted_sales = self.current_state["base_sales"] * \
                         (1 + elasticity * price_adjustment)
        
        # 计算奖励(收入)
        revenue = current_price * predicted_sales
        
        # 状态转移
        next_state = {
            "inventory": max(0, self.current_state["inventory"] - predicted_sales),
            "time": self.current_state["time"] + 1,
            "demand": self._update_demand(self.current_state["demand"]),
            "base_price": self.current_state["base_price"],
            "base_sales": self.current_state["base_sales"]
        }
        
        self.current_state = next_state
        
        done = next_state["inventory"] <= 0 or next_state["time"] >= 30
        
        # gym接口要求返回observation_space中定义的向量观测
        return self._get_obs(), revenue, done, {"predicted_sales": predicted_sales}
    
    def reset(self):
        self.current_state = {
            "inventory": 500,
            "time": 0,
            "demand": 100,
            "base_price": 299.0,
            "base_sales": 50
        }
        return self._get_obs()
    
    def _get_obs(self) -> np.ndarray:
        """将内部状态字典映射为observation_space定义的[库存, 时间步, 需求]向量"""
        return np.array([
            self.current_state["inventory"],
            self.current_state["time"],
            self.current_state["demand"]
        ], dtype=np.float32)
    
    def _update_demand(self, current_demand):
        # 简单自回归模型模拟需求变化
        return max(10, current_demand + np.random.normal(0, 5))

# 训练
# env = CausalPricingEnv(causal_model, "SKU-001")
# model = PPO("MlpPolicy", env, verbose=1)
# model.learn(total_timesteps=10000)

# 部署:模型每7天用新数据重新预训练,RL在线适应

7.2 因果发现(Causal Discovery)自动化

从数据自动学习因果图,减少专家依赖:

# causal_discovery.py
import warnings
from typing import Optional

import networkx as nx
import numpy as np
import pandas as pd
from causallearn.search.ConstraintBased.PC import pc
from causallearn.utils.cit import fisherz

class AutomatedCausalDiscovery:
    """
    自动化因果发现
    
    生产级挑战:
    1. 计算复杂度高(O(p²))
    2. 需要条件独立性检验
    3. 等价类问题(无法区分方向)
    
    工程方案:
    - 限制节点数(<50)
    - 使用先验知识约束搜索空间
    - 多算法融合(PC、GES、NOTEARS)
    """
    
    def __init__(self, alpha: float = 0.05, max_nodes: int = 30):
        self.alpha = alpha  # 独立性检验显著性阈值
        self.max_nodes = max_nodes
    
    def discover_from_data(self, df: pd.DataFrame,
                          prior_knowledge: Optional[dict] = None) -> nx.DiGraph:
        """
        从数据发现因果图
        
        参数:
            df: 数据框(列名为变量名)
            prior_knowledge: 先验约束
                {
                    "must_have_edge": [("A", "B")],
                    "forbidden_edge": [("C", "D")],
                    "temporal_order": [["A", "B"], ["C", "D", "E"]]  # 时间顺序
                }
        """
        # 数据维度检查
        if df.shape[1] > self.max_nodes:
            warnings.warn(
                f"变量数 {df.shape[1]} 超过最大限制 {self.max_nodes},"
                f"将使用前{self.max_nodes}个变量"
            )
            df = df.iloc[:, :self.max_nodes]
        
        # 执行PC算法(causallearn实现)
        cg = pc(
            data=df.values,
            alpha=self.alpha,
            indep_test=fisherz,
            uc_rule=0,       # 无屏蔽碰撞结构(unshielded collider)的定向规则
            uc_priority=2    # 碰撞结构定向发生冲突时的处理优先级
        )
        
        # 转换为NetworkX图
        # causal-learn约定:G.graph[i, j] == -1 且 G.graph[j, i] == 1 表示有向边 i -> j
        graph = nx.DiGraph()
        graph.add_nodes_from(df.columns)
        
        adj = cg.G.graph
        for i in range(adj.shape[0]):
            for j in range(adj.shape[1]):
                if adj[i, j] == -1 and adj[j, i] == 1:
                    graph.add_edge(df.columns[i], df.columns[j])
        
        # 应用先验知识约束
        if prior_knowledge:
            graph = self._apply_prior_knowledge(graph, prior_knowledge)
        
        return graph
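    
    def _apply_prior_knowledge(self, graph: nx.DiGraph,
                               prior_knowledge: dict) -> nx.DiGraph:
        """
        将先验约束应用到发现的因果图上
        (最小实现示意,约束格式沿用discover_from_data的docstring)
        """
        # 必须存在的边:直接补充
        for u, v in prior_knowledge.get("must_have_edge", []):
            graph.add_edge(u, v)
        
        # 禁止的边:直接删除
        for u, v in prior_knowledge.get("forbidden_edge", []):
            if graph.has_edge(u, v):
                graph.remove_edge(u, v)
        
        # 时间顺序:后发生的变量不能指向先发生的变量
        for ordered_group in prior_knowledge.get("temporal_order", []):
            rank = {var: idx for idx, var in enumerate(ordered_group)}
            for u, v in list(graph.edges()):
                if u in rank and v in rank and rank[u] > rank[v]:
                    graph.remove_edge(u, v)
        
        return graph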
    
    def validate_discovered_graph(self, graph: nx.DiGraph,
                                 validation_data: pd.DataFrame) -> dict:
        """
        验证发现图的可靠性
        
        方法:
        1. 自助法稳定性:多次子采样,检查边稳定性
        2. 预测能力:基于图的预测误差
        3. 专家评估:抽样边进行人工review
        """
        n_bootstrap = 100
        edge_stability = {}
        
        for edge in graph.edges():
            edge_stability[edge] = 0
        
        # 自助抽样
        for _ in range(n_bootstrap):
            sample_df = validation_data.sample(frac=0.8, replace=True)
            bootstrap_graph = self.discover_from_data(sample_df)
            
            for edge in graph.edges():
                if bootstrap_graph.has_edge(*edge):
                    edge_stability[edge] += 1
        
        # 计算稳定性分数
        stability_scores = {
            edge: count / n_bootstrap
            for edge, count in edge_stability.items()
        }
        
        # 整体稳定性
        overall_stability = np.mean(list(stability_scores.values()))
        
        return {
            "overall_stability": overall_stability,
            "edge_stability": stability_scores,
            "recommendation": "accept" if overall_stability > 0.7 \
                              else "review" if overall_stability > 0.5 \
                              else "reject"
        }

# 生产级限制:仅用于辅助,不自动修改生产因果图
# 需人工review稳定性>0.7的边
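
在人工review环节,可以直接基于 validate_discovered_graph 的返回结果筛选候选边。以下为示意用法(train_df、validation_df 与 0.7 的阈值均为假设):

# 人工review辅助:筛选稳定性达标的候选边(示意)
# discovery = AutomatedCausalDiscovery()
# graph = discovery.discover_from_data(train_df)
# report = discovery.validate_discovered_graph(graph, validation_df)
# candidate_edges = [
#     edge for edge, score in report["edge_stability"].items() if score > 0.7
# ]
# print(f"整体稳定性: {report['overall_stability']:.2f}, 待review边: {candidate_edges}")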