分布外检测的Score-Based方法:能量函数与似然比的校准理论
分布外检测的Score-Based方法:能量函数与似然比的校准理论
引言:分布外检测的现实挑战与理论演进
在现实世界的机器学习部署中,一个长期存在的挑战是模型在面对与训练数据分布不同的样本时的行为不可预测性。这种分布外(Out-of-Distribution, OOD)检测问题在安全关键领域如医疗诊断、自动驾驶和金融风控中尤为重要。传统的基于softmax置信度的方法已被证明在OOD检测上存在严重缺陷,因为它们往往会对分布外样本给出过度自信的预测。
近年来,基于生成模型的score-based方法为OOD检测提供了新的理论框架。其中,能量函数与似然比校准理论成为了这一领域的核心突破。本文将深入探讨这一理论框架,并通过详细的代码实例展示其实际应用。
理论基础:从生成模型到能量函数
能量函数的数学定义
能量函数的概念来源于能量基模型(Energy-Based Models, EBMs),它将输入样本x映射到一个标量能量值E(x)上,能量越低表示样本越可能来自训练分布。在深度学习中,我们可以将神经网络的输出与能量函数建立联系:
其中f(x)是分类器在softmax前的logits输出,T是温度参数。
似然比校准的理论框架
似然比校准的核心思想是通过对比训练分布与参考分布的似然比来检测OOD样本。假设我们有两个概率分布:p_in(x)表示训练数据分布,p_ref(x)表示一个广泛的参考分布。那么似然比可以表示为:
在实际应用中,我们通过能量函数来近似这个似然比:
方法实现:能量函数的构建与校准
基础能量函数的实现
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from typing import Optional, Tuple
class EnergyBasedOODDetector:
"""基于能量的OOD检测器"""
def __init__(self, model: nn.Module, temperature: float = 1.0):
"""
初始化能量检测器
参数:
model: 预训练的分类模型
temperature: 温度参数,用于校准能量尺度
"""
self.model = model
self.temperature = temperature
self.model.eval() # 设置为评估模式
def compute_energy_score(self, x: torch.Tensor) -> torch.Tensor:
"""
计算输入样本的能量分数
能量函数定义: E(x) = -T * logsumexp(f(x)/T)
参数:
x: 输入张量,形状为(batch_size, ...)
返回:
能量分数,形状为(batch_size,)
"""
with torch.no_grad():
# 获取模型logits输出
logits = self.model(x)
# 计算能量分数: E(x) = -T * logsumexp(f(x)/T)
energy = -self.temperature * torch.logsumexp(logits / self.temperature, dim=1)
return energy
def compute_ood_score(self, x: torch.Tensor) -> torch.Tensor:
"""
计算OOD检测分数(能量分数的负值)
分数越高,越可能是OOD样本
"""
energy = self.compute_energy_score(x)
return -energy # 负能量作为OOD分数
改进的能量函数:加入扰动分析
class PerturbedEnergyDetector(EnergyBasedOODDetector):
"""加入随机扰动的能量检测器"""
def __init__(self, model: nn.Module, temperature: float = 1.0,
noise_scale: float = 0.05, num_perturbations: int = 10):
super().__init__(model, temperature)
self.noise_scale = noise_scale
self.num_perturbations = num_perturbations
def compute_perturbed_energy(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
"""
计算扰动后的能量统计量
参数:
x: 输入张量
返回:
energy_mean: 平均能量,形状为(batch_size,)
energy_std: 能量标准差,形状为(batch_size,)
"""
batch_size = x.shape[0]
device = x.device
# 存储多次扰动的能量
all_energies = []
for _ in range(self.num_perturbations):
# 添加随机扰动
noise = torch.randn_like(x) * self.noise_scale
x_perturbed = x + noise
# 计算扰动样本的能量
energy = self.compute_energy_score(x_perturbed)
all_energies.append(energy)
# 统计能量的均值和标准差
all_energies = torch.stack(all_energies, dim=0) # [num_perturbations, batch_size]
energy_mean = torch.mean(all_energies, dim=0)
energy_std = torch.std(all_energies, dim=0)
return energy_mean, energy_std
def compute_ood_score_with_uncertainty(self, x: torch.Tensor) -> torch.Tensor:
"""
结合不确定性的OOD检测分数
理论依据: OOD样本通常对扰动更敏感
"""
energy_mean, energy_std = self.compute_perturbed_energy(x)
# 结合能量均值和不确定性
# 分数 = -能量均值 + λ * 能量标准差
lambda_uncertainty = 0.1 # 不确定性权重
ood_score = -energy_mean + lambda_uncertainty * energy_std
return ood_score
似然比校准的深度实现
参考分布的构建与似然比计算
class LikelihoodRatioCalibrator:
"""似然比校准器"""
def __init__(self, in_distribution_model: nn.Module,
reference_model: Optional[nn.Module] = None,
calibration_data: Optional[torch.Tensor] = None):
"""
初始化似然比校准器
参数:
in_distribution_model: 训练分布模型
reference_model: 参考分布模型(如背景模型)
calibration_data: 校准数据,用于拟合后处理函数
"""
self.in_model = in_distribution_model
self.ref_model = reference_model
self.calibration_data = calibration_data
# 校准参数
self.bias = 0.0
self.scale = 1.0
if calibration_data is not None:
self._calibrate_parameters()
def _calibrate_parameters(self):
"""使用校准数据拟合似然比的后处理参数"""
print("校准似然比参数...")
# 计算校准数据的似然比
log_likelihood_in = self._compute_log_likelihood(self.in_model, self.calibration_data)
if self.ref_model is not None:
log_likelihood_ref = self._compute_log_likelihood(self.ref_model, self.calibration_data)
log_likelihood_ratio = log_likelihood_in - log_likelihood_ref
else:
# 如果没有参考模型,使用均匀分布作为参考
num_classes = self.in_model(self.calibration_data[:1]).shape[1]
log_likelihood_ref = torch.log(torch.tensor(1.0 / num_classes))
log_likelihood_ratio = log_likelihood_in - log_likelihood_ref
# 使用简单的统计校准
# 目标是使ID样本的似然比集中在某个值附近
self.bias = -torch.mean(log_likelihood_ratio).item()
self.scale = 1.0 / torch.std(log_likelihood_ratio).item()
print(f"校准完成: bias={self.bias:.4f}, scale={self.scale:.4f}")
def _compute_log_likelihood(self, model: nn.Module, x: torch.Tensor) -> torch.Tensor:
"""计算对数似然"""
model.eval()
with torch.no_grad():
logits = model(x)
# 使用logsumexp计算对数似然
log_likelihood = torch.logsumexp(logits, dim=1)
return log_likelihood
def compute_calibrated_likelihood_ratio(self, x: torch.Tensor) -> torch.Tensor:
"""
计算校准后的似然比
理论: log(p_in(x)/p_ref(x)) ≈ log p_in(x) - log p_ref(x)
"""
# 计算训练分布的对数似然
log_likelihood_in = self._compute_log_likelihood(self.in_model, x)
if self.ref_model is not None:
# 计算参考分布的对数似然
log_likelihood_ref = self._compute_log_likelihood(self.ref_model, x)
log_likelihood_ratio = log_likelihood_in - log_likelihood_ref
else:
# 使用均匀分布作为参考
num_classes = self.in_model(x[:1]).shape[1]
log_likelihood_ref = torch.log(torch.tensor(1.0 / num_classes))
log_likelihood_ratio = log_likelihood_in - log_likelihood_ref
# 应用校准: y = scale * (x + bias)
calibrated_ratio = self.scale * (log_likelihood_ratio + self.bias)
return calibrated_ratio
def detect_ood(self, x: torch.Tensor, threshold: float = 0.0) -> Tuple[torch.Tensor, torch.Tensor]:
"""
OOD检测
参数:
x: 输入样本
threshold: 决策阈值
返回:
scores: OOD分数
is_ood: 布尔张量,True表示OOD样本
"""
# 计算校准后的似然比
lr_scores = self.compute_calibrated_likelihood_ratio(x)
# OOD分数为似然比的负值(似然比越低,越可能是OOD)
ood_scores = -lr_scores
# 决策
is_ood = ood_scores > threshold
return ood_scores, is_ood
高级方法:混合密度与分数匹配
基于分数匹配的能量函数优化
class ScoreMatchingEnergyDetector:
"""基于分数匹配的能量检测器"""
def __init__(self, model: nn.Module, sigma_list: list = None):
"""
初始化分数匹配检测器
参数:
model: 用于学习分数函数的模型
sigma_list: 噪声尺度列表,用于多尺度分数匹配
"""
self.model = model
self.sigma_list = sigma_list or [0.1, 0.3, 0.5, 0.7, 1.0]
def denoising_score_matching_loss(self, x: torch.Tensor) -> torch.Tensor:
"""
去噪分数匹配损失
目标: 学习∇_x log p(x) ≈ s_θ(x)
理论: L(θ) = E_{σ∼p(σ)} E_{x∼p_data} E_{x̃∼N(x,σ²)}[||s_θ(x̃) + (x̃-x)/σ²||²]
"""
batch_size = x.shape[0]
device = x.device
# 随机选择噪声尺度
sigma_idx = torch.randint(0, len(self.sigma_list), (batch_size,))
sigma = torch.tensor([self.sigma_list[i] for i in sigma_idx]).to(device).view(-1, 1, 1, 1)
# 添加噪声
noise = torch.randn_like(x)
x_noisy = x + sigma * noise
# 计算模型预测的分数
score_pred = self.model(x_noisy)
# 目标分数: -噪声/σ²
score_target = -noise / (sigma ** 2 + 1e-8)
# 分数匹配损失
loss = torch.mean(torch.sum((score_pred - score_target) ** 2, dim=[1, 2, 3]))
return loss
def compute_energy_from_score(self, x: torch.Tensor, num_steps: int = 100) -> torch.Tensor:
"""
通过分数函数计算能量
能量可以通过分数的路径积分近似:
E(x) ≈ -∫ s_θ(x(t)) dx(t) + 常数
参数:
x: 输入样本
num_steps: 积分步数
返回:
相对能量(未归一化)
"""
# 使用朗之万动力学从参考分布采样
x_samples = self._langevin_dynamics(x.shape, num_steps=num_steps)
# 计算从参考点到样本点的路径积分
energies = []
for i in range(x.shape[0]):
# 简化计算:使用梯形法则近似路径积分
x_i = x[i:i+1]
x_ref = x_samples[i:i+1]
# 线性插值路径
t_values = torch.linspace(0, 1, 10).to(x.device)
path_integral = 0.0
for j in range(len(t_values)-1):
t1, t2 = t_values[j], t_values[j+1]
x1 = (1-t1) * x_ref + t1 * x_i
x2 = (1-t2) * x_ref + t2 * x_i
# 计算分数
s1 = self.model(x1)
s2 = self.model(x2)
# 路径积分增量
dx = x2 - x1
path_integral += torch.sum(0.5 * (s1 + s2) * dx)
energies.append(-path_integral) # 能量 = -log p(x) ≈ -路径积分
return torch.stack(energies).squeeze()
def _langevin_dynamics(self, shape: torch.Size, num_steps: int = 100) -> torch.Tensor:
"""朗之万动力学采样"""
device = next(self.model.parameters()).device
# 从随机噪声开始
x = torch.randn(shape).to(device) * 2.0
# 朗之万更新
for step in range(num_steps):
# 添加噪声
noise = torch.randn_like(x) * 0.01
# 计算分数
with torch.no_grad():
score = self.model(x)
# 更新样本
x = x + 0.01 * score + noise
# 偶尔添加大噪声避免陷入局部极小
if step % 20 == 0:
x = x + torch.randn_like(x) * 0.1
return x
实验评估与结果分析
完整实验流程示例
class OODDetectionExperiment:
"""OOD检测实验框架"""
def __init__(self, in_dataset, ood_datasets, model):
"""
初始化实验
参数:
in_dataset: 训练分布数据集
ood_datasets: OOD测试数据集列表
model: 预训练模型
"""
self.in_dataset = in_dataset
self.ood_datasets = ood_datasets
self.model = model
# 初始化检测器
self.energy_detector = EnergyBasedOODDetector(model, temperature=1.0)
self.lr_calibrator = LikelihoodRatioCalibrator(model)
def evaluate_detector(self, detector, threshold_type: str = "95%TPR"):
"""
评估检测器性能
参数:
detector: OOD检测器
threshold_type: 阈值选择方法
返回:
评估指标字典
"""
results = {}
# 计算ID样本的分数
id_scores = self._compute_scores(detector, self.in_dataset, is_id=True)
for ood_name, ood_dataset in self.ood_datasets.items():
# 计算OOD样本的分数
ood_scores = self._compute_scores(detector, ood_dataset, is_id=False)
# 计算评估指标
metrics = self._compute_metrics(id_scores, ood_scores, threshold_type)
results[ood_name] = metrics
print(f"OOD数据集: {ood_name}")
print(f" AUROC: {metrics['auroc']:.4f}")
print(f" FPR@95TPR: {metrics['fpr_at_95tpr']:.4f}")
print(f" 检测准确率: {metrics['detection_accuracy']:.4f}")
return results
def _compute_scores(self, detector, dataset, is_id: bool = True):
"""计算数据集上的OOD检测分数"""
scores = []
for batch_idx, (data, _) in enumerate(dataset):
if batch_idx > 100: # 限制评估样本数
break
# 根据检测器类型调用不同方法
if isinstance(detector, EnergyBasedOODDetector):
batch_scores = detector.compute_ood_score(data)
elif isinstance(detector, LikelihoodRatioCalibrator):
batch_scores, _ = detector.detect_ood(data)
else:
raise ValueError(f"不支持的检测器类型: {type(detector)}")
scores.append(batch_scores.cpu().numpy())
return np.concatenate(scores)
def _compute_metrics(self, id_scores, ood_scores, threshold_type: str):
"""计算OOD检测指标"""
from sklearn.metrics import roc_auc_score, roc_curve
# 合并分数和标签
scores = np.concatenate([id_scores, ood_scores])
labels = np.concatenate([np.zeros_like(id_scores), np.ones_like(ood_scores)])
# 计算AUROC
auroc = roc_auc_score(labels, scores)
# 计算FPR@95%TPR
fpr, tpr, thresholds = roc_curve(labels, scores)
fpr_at_95tpr = fpr[np.argmax(tpr >= 0.95)]
# 计算检测准确率
if threshold_type == "95%TPR":
threshold = thresholds[np.argmax(tpr >= 0.95)]
else:
# 默认使用最佳阈值
threshold = thresholds[np.argmax(tpr - fpr)]
predictions = (scores > threshold).astype(int)
detection_accuracy = np.mean(predictions == labels)
return {
"auroc": auroc,
"fpr_at_95tpr": fpr_at_95tpr,
"detection_accuracy": detection_accuracy,
"threshold": threshold
}
def run_comparison_experiment(self):
"""运行比较实验"""
print("=" * 60)
print("OOD检测方法比较实验")
print("=" * 60)
# 评估不同方法
methods = {
"Energy-Based": self.energy_detector,
"Likelihood Ratio": self.lr_calibrator,
}
all_results = {}
for method_name, detector in methods.items():
print(f"\n评估方法: {method_name}")
print("-" * 40)
results = self.evaluate_detector(detector)
all_results[method_name] = results
# 分析比较结果
self._analyze_comparison_results(all_results)
return all_results
def _analyze_comparison_results(self, all_results):
"""分析比较结果"""
print("\n" + "=" * 60)
print("实验结果分析")
print("=" * 60)
# 计算平均性能
for method_name, results in all_results.items():
avg_auroc = np.mean([r["auroc"] for r in results.values()])
avg_fpr = np.mean([r["fpr_at_95tpr"] for r in results.values()])
print(f"{method_name}:")
print(f" 平均AUROC: {avg_auroc:.4f}")
print(f" 平均FPR@95TPR: {avg_fpr:.4f}")
理论深度分析:能量函数与似然比校准的数学基础
能量函数的统计解释
从统计力学角度看,能量函数与概率分布的关系为:
其中Z是配分函数。在OOD检测中,我们实际上关心的是能量的相对值而非绝对值。这解释了为什么简单的能量阈值化在某些情况下比softmax置信度更有效。
似然比校准的理论保证
似然比检验在统计学中具有最优性保证(Neyman-Pearson引理)。对于简单假设检验问题,似然比检验在所有给定第一类错误率(误报率)的检验中,具有最小的第二类错误率(漏报率)。在OOD检测中:
- H0: 样本来自训练分布p_in(x)
- H1: 样本来自OOD分布p_out(x)
似然比检验统计量为:
在实际中,p_out(x)未知,我们使用一个参考分布p_ref(x)来近似。
温度缩放与能量校准
温度参数T在能量函数中起着关键作用,它实际上在进行概率校准:
当T>1时,分布变得更平坦;当T<1时,分布变得更尖锐。在OOD检测中,适当选择T可以改善ID和OOD样本的能量分离度。
实际应用建议与未来方向
应用建议
- 数据预处理的重要性:确保ID和OOD样本的预处理一致
- 参考分布的选择:均匀分布、背景数据集或生成模型都是合理选择
- 阈值设定的策略:使用验证集上的FPR目标来设定阈值
- 不确定性量化的结合:将能量分数与预测不确定性结合可提高鲁棒性
未来研究方向
- 无监督参考分布学习:从训练数据中自动学习参考分布
- 多模态OOD检测:处理复杂、多模态的OOD分布
- 在线自适应校准:在部署中持续校准检测器
- 理论泛化边界:建立能量基OOD检测的泛化理论
结论
基于能量函数与似然比校准的OOD检测方法代表了这一领域的重要理论进展。通过将生成模型的思想与判别模型相结合,这些方法能够更可靠地区分ID和OOD样本。本文提供的代码框架和理论分析为实际应用和进一步研究提供了坚实的基础。
值得注意的是,没有单一的OOD检测方法在所有场景下都是最优的。实际部署时应根据具体应用场景、数据特性和安全要求,选择合适的检测策略并可能结合多种方法。
- 点赞
- 收藏
- 关注作者
评论(0)