深度残差收缩网络的Python实现与抗噪性能分析

举报
伊尔参思 发表于 2026/03/01 15:23:31 2026/03/01
【摘要】 在复杂的工业现场,旋转机械的早期故障信号经常被强烈的随机噪声淹没,传统的深度学习模型在低信噪比环境下的诊断精度可能会不尽如人意。针对这一难题,论文“Deep Residual Shrinkage Networks for Fault Diagnosis”提出了深度残差收缩网络。该研究通过在残差网络(Residual Network, ResNet)中引入可学习的软阈值算子,实现了特征层面的自...

在复杂的工业现场,旋转机械的早期故障信号经常被强烈的随机噪声淹没,传统的深度学习模型在低信噪比环境下的诊断精度可能会不尽如人意。针对这一难题,论文“Deep Residual Shrinkage Networks for Fault Diagnosis”提出了深度残差收缩网络。该研究通过在残差网络(Residual Network, ResNet)中引入可学习的软阈值算子,实现了特征层面的自适应降噪,为强干扰背景下的设备状态监测提供了新的解决方案。

一、自适应软阈值化

深度残差收缩网络(Deep Residual Shrinkage Network, DRSN)的核心贡献在于将信号处理中的软阈值化(Soft Thresholding)集成到了深度残差网络的非线性层中。DRSN 通过引入收缩单元,将接近于零的无用特征(通常由噪声引起)直接置零,而保留较大的有用特征。其数学表达非常直白:y = sign(x) * max(|x| - τ, 0),其中τ是阈值。不同于传统方法需要人工根据经验设定阈值,DRSN通过子网络根据输入信号的状态自动学习合适的阈值,实现了“一信号一阈值”的定制化过滤。

二、RSBU-CW单元详解

本次复现采用的是性能较优的“具有逐通道阈值的残差收缩构建单元”(Residual Shrinkage Building Unit with Channel-wise thresholds, RSBU-CW)。如图所示,该单元在标准残差块的基础上增加了阈值估计路径。

首先,输入特征图通过绝对值处理和全局平均池化(Global Average Pooling, GAP)将空间维度压缩为一维向量。然后,该向量通过两层全连接(Fully Connected, FC)层进行特征压缩与重构,并经由Sigmoid激活函数输出一个缩放系数α。最后,阈值τ等于α乘以特征图绝对值的平均值。这种设计使得模型能够对不同的通道应用不同的阈值,充分考虑到振动信号在不同频段、不同传感器通道上的噪声分布差异


图1.png


三、数据集准备与代码实现

实验采用了经典的美国西储大学(Case Western Reserve University, CWRU)轴承数据集。如表所示,选取了10类典型状态,涵盖了正常状态(Normal)以及内圈(Inner Race)、滚动体(Ball)、外圈(Outer Race)在不同损伤尺寸(0.007、0.014、0.021 英寸)下的故障数据。


图2.png



复现代码基于TensorFlow 2.x框架编写,集成了数据增强技术,包括随机时域位移和注入加性高斯白噪声(Additive White Gaussian Noise, AWGN)。以下是完整的算法实现代码:

"""
本程序实现了Zhao等提出的深度残差收缩网络(DRSN)用于振动信号的故障诊断。
该方法通过在残差结构中集成可学习的软阈值收缩单元,
旨在提升模型在强干扰背景下的特征学习能力和分类精度。
代码集成了数据集加载、Z-Score归一化、模型编译、
数据增强以及在模拟工业噪声环境下的模型性能评估全流程。

参考文献:
Zhao M, Zhong S, Fu X, Tang B, Pecht M.
Deep residual shrinkage networks for fault diagnosis.
IEEE Transactions on Industrial Informatics, 2020, 16(7): 4681–4690.
"""

import os
import sys
import logging
import numpy as np
import scipy.io as sio
import tensorflow as tf
from tensorflow.keras import layers, models, regularizers
from sklearn.model_selection import train_test_split

# =============================================================================
# 第一部分:运行环境与底层支撑配置
# =============================================================================

def setup_env():
    """
    初始化计算资源与硬件加速策略。
    通过环境变量控制日志等级,并配置物理显存优化方案。
    """
    logging.basicConfig(level=logging.INFO, format='[%(levelname)s] %(message)s')
    
    # 抑制冗余的算子库日志输出
    os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
    
    # 检测系统可用显卡并激活显存动态按需分配机制。
    gpus = tf.config.list_physical_devices('GPU')
    if gpus:
        try:
            for device in gpus:
                # 配置显存增长策略,防止模型初始化时独占物理显存。
                tf.config.experimental.set_memory_growth(device, True)
            logging.info("GPU 硬件加速已就绪,显存按需分配机制已激活。")
        except RuntimeError as err:
            logging.warning("尝试配置 GPU 显存策略时发生异常: %s", err)
    else:
        logging.info("未发现可用 GPU 设备,当前运算任务将回退至中央处理器(CPU)。")

# 立即执行环境配置
setup_env()

# =============================================================================
# 第二部分:数据加载组件
# =============================================================================

class CWRULoader:
    """
    振动信号数据加载器。
    实现对CWRU加速度信号的读取、序列切片及标签映射。
    """
    
    def __init__(self, data_dir, window_size=1024):
        """
        初始化集成器。
        :param data_dir: 数据仓库的根路径。
        :param window_size: 信号切片的长度(采样点数量)。
        """
        self.data_path = os.path.abspath(data_dir)
        self.window_size = window_size
        self.stride = window_size  # 采用无重叠的滑动步长

    def _load_mat(self, file_path):
        """
        解析 MATLAB 格式的振动记录文件。
        从Matlab结构体中检索驱动端(DE)振动时间序列数据。
        """
        try:
            raw_content = sio.loadmat(file_path)
            for identifier in raw_content.keys():
                if 'DE_time' in identifier:
                    return raw_content[identifier].flatten()
        except Exception:
            return None
        return None

    def load_data(self, fault_config_registry):
        """
        基于预定义的分类映射构建结构化的特征矩阵与标签向量。
        """
        x_collection = []
        y_collection = []
        is_accessible = False
        
        for class_id, filenames in fault_config_registry.items():
            for target_name in filenames:
                absolute_path = os.path.join(self.data_path, "{}.mat".format(target_name))
                if not os.path.exists(absolute_path):
                    continue
                
                signal = self._load_mat(absolute_path)
                if signal is None:
                    continue
                
                is_accessible = True
                # 执行信号序列的截断与窗口化处理
                max_offset = len(signal) - self.window_size + 1
                for offset in range(0, max_offset, self.stride):
                    sample = signal[offset : offset + self.window_size]
                    x_collection.append(sample)
                    y_collection.append(class_id)
        
        if not is_accessible:
            raise FileNotFoundError("指定的路径 '{}' 未包含任何可识别的数据源。".format(self.data_path))
            
        return np.array(x_collection, dtype='float32'), np.array(y_collection, dtype='int32')

def add_awgn(x_batch, snr):
    """
    根据功率谱密度计算并注入加性高斯白噪声(AWGN)。
    基于信号平均功率计算特定SNR的噪声方差并进行叠加。
    """
    x_batch = np.array(x_batch)
    rng = np.random.default_rng()
    
    # 处理固定或随机范围的信噪比
    snr_val = snr if isinstance(snr, (int, float)) else rng.uniform(snr[0], snr[1])
    
    # 基于 P_noise = P_signal / (10^(SNR/10)) 公式计算噪声功率
    signal_power = np.mean(np.square(x_batch), axis=1, keepdims=True)
    noise_power = signal_power / (10 ** (snr_val / 10.0))
    noise_vectors = rng.normal(0, np.sqrt(noise_power), x_batch.shape)
    
    return (x_batch + noise_vectors).astype('float32')


# =============================================================================
# 第三部分:深度残差收缩网络 (DRSN) 架构实现
# =============================================================================

class SoftThresholding(layers.Layer):
    """
    深度残差收缩网络中的核心“软阈值”层。
    实现非线性收缩函数:y = sign(x) * ReLU(|x| - threshold)。
    """
    def __init__(self, **kwargs):
        super(SoftThresholding, self).__init__(**kwargs)

    def call(self, inputs):
        """
        执行逐通道的自适应阈值过滤。
        """
        x, threshold = inputs
        # 将一维阈值向量扩展至与特征图对齐的维度
        threshold_expanded = tf.expand_dims(threshold, axis=1)
        return tf.sign(x) * tf.maximum(tf.abs(x) - threshold_expanded, 0.0)

class RSBU_CW(layers.Layer):
    """
    集成通道注意力机制的残差收缩单元 (RSBU-CW)。
    作为深度残差收缩网络的基本构建模块,具备特征提取、噪声估计与软阈值降噪的闭环能力。
    """
    def __init__(self, channels, kernel_size, strides=1, **kwargs):
        super(RSBU_CW, self).__init__(**kwargs)
        self.channels = channels
        self.strides = strides
        self.kernel_size = kernel_size
        self.weight_decay = regularizers.l2(1e-4)

        self.shortcut = None
        
        # 定义主干卷积路径
        self.bn1 = layers.BatchNormalization()
        self.relu1 = layers.Activation('relu')
        self.conv1 = layers.Conv1D(channels, kernel_size, strides=strides, padding='same', 
                                  kernel_initializer='he_normal', kernel_regularizer=self.weight_decay)
        
        self.bn2 = layers.BatchNormalization()
        self.relu2 = layers.Activation('relu')
        self.conv2 = layers.Conv1D(channels, kernel_size, strides=1, padding='same', 
                                  kernel_initializer='he_normal', kernel_regularizer=self.weight_decay)
        
        # 阈值预测子网络架构 (Subnetwork for threshold)
        self.gap = layers.GlobalAveragePooling1D()
        self.fc1 = layers.Dense(channels, kernel_initializer='he_normal')
        self.bn_fc1 = layers.BatchNormalization()
        self.relu_fc1 = layers.Activation('relu')
        self.fc2 = layers.Dense(channels, activation='sigmoid') # Scaling parameter alpha
        self.soft_thresh = SoftThresholding()

    def build(self, input_shape):
        """
        根据输入维度动态调整恒等映射路径。
        """
        if self.strides != 1 or input_shape[-1] != self.channels:
            self.shortcut = models.Sequential([
                layers.Conv1D(self.channels, 1, strides=self.strides, padding='same'),
            ])
        super(RSBU_CW, self).build(input_shape)

    def call(self, inputs):
        """
        RSBU-CW 逻辑:特征变换 -> 统计量感知 -> 阈值预测 -> 降噪。
        """
        identity = inputs
        if self.shortcut:
            identity = self.shortcut(inputs)

        # 常规路径
        x = self.bn1(inputs)
        x = self.relu1(x)
        x = self.conv1(x)
        x = self.bn2(x)
        x = self.relu2(x)
        x = self.conv2(x)

        # 阈值设置路径
        abs_x = tf.abs(x)
        abs_mean = self.gap(abs_x)
        
        z = self.fc1(abs_mean)
        z = self.bn_fc1(z)
        z = self.relu_fc1(z)
        alpha = self.fc2(z) # Scaling parameter
        
        tau = tf.multiply(alpha, abs_mean) # Threshold
        
        # 应用降噪算子并完成跳跃连接
        filtered_x = self.soft_thresh([x, tau])
        return layers.Add()([filtered_x, identity])

class DRSN_CW(models.Model):
    """
    深度残差收缩网络分类器。
    堆叠多个RSBU-CW模块,在复杂噪声下提取稳健的故障特征。
    """
    def __init__(self, num_classes):
        super(DRSN_CW, self).__init__(name="DRSN_CW")
        
        # 初始特征映射层
        self.conv1 = layers.Conv1D(32, 15, strides=2, padding='same', kernel_initializer='he_normal')
        self.bn1 = layers.BatchNormalization()
        self.relu1 = layers.Activation('relu')
        
        # 深度残差收缩层堆叠
        self.blocks = [
            RSBU_CW(32, 5, strides=2),
            RSBU_CW(32, 5, strides=1),
            RSBU_CW(64, 5, strides=2),
            RSBU_CW(64, 5, strides=1),
            RSBU_CW(128, 5, strides=2),
            RSBU_CW(128, 5, strides=1)
        ]
        
        # 输出分类模块
        self.bn_last = layers.BatchNormalization()
        self.relu_last = layers.Activation('relu')
        self.gap = layers.GlobalAveragePooling1D()
        self.fc_out = layers.Dense(num_classes, activation='softmax')

    def call(self, inputs):
        """
        全流程前向推理。
        """
        x = self.conv1(inputs)
        x = self.bn1(x)
        x = self.relu1(x)
        
        for block in self.blocks:
            x = block(x)
            
        x = self.bn_last(x)
        x = self.relu_last(x)
        x = self.gap(x)
        return self.fc_out(x)


# =============================================================================
# 第四部分:自动化诊断流水线与评估逻辑
# =============================================================================

def run_diagnosis_workflow(data_path, window_size=1024):
    """
    端到端训练评估:集成数据切分、标准化、训练与抗噪测试。
    """
    
    # 构建故障类别映射表
    fault_config_registry = {
        0: ['Normal_0', 'Normal_1', 'Normal_2', 'Normal_3'],
        1: ['IR007_0', 'IR007_1', 'IR007_2', 'IR007_3'],
        2: ['IR014_0', 'IR014_1', 'IR014_2', 'IR014_3'],
        3: ['IR021_0', 'IR021_1', 'IR021_2', 'IR021_3'],
        4: ['B007_0', 'B007_1', 'B007_2', 'B007_3'],
        5: ['B014_0', 'B014_1', 'B014_2', 'B014_3'],
        6: ['B021_0', 'B021_1', 'B021_2', 'B021_3'],
        7: ['OR007@6_0', 'OR007@6_1', 'OR007@6_2', 'OR007@6_3'],
        8: ['OR014@6_0', 'OR014@6_1', 'OR014@6_2', 'OR014@6_3'],
        9: ['OR021@6_0', 'OR021@6_1', 'OR021@6_2', 'OR021@6_3']
    }
    
    data_loader = CWRULoader(data_dir=data_path, window_size=window_size)
    
    try:
        x_raw, y_raw = data_loader.load_data(fault_config_registry)
    except Exception as failure:
        logging.error("数据准备环节发生严重错误: %s", failure)
        return

    # 数据集分层抽样划分
    x_train, x_temp, y_train, y_temp = train_test_split(x_raw, y_raw, test_size=0.3, random_state=42)
    x_val, x_test, y_val, y_test = train_test_split(x_temp, y_temp, test_size=0.5, random_state=42)
    
    # 统计标准化处理
    mean_val, std_val = np.mean(x_train), np.std(x_train)
    
    def normalize(data_array):
        """应用 Z-Score 归一化并调整张量秩"""
        return ((data_array - mean_val) / std_val).reshape(-1, window_size, 1)

    x_train = normalize(x_train)
    x_val = normalize(x_val)
    x_test = normalize(x_test)
    
    # 标签向量独热化处理
    num_classes = len(fault_config_registry)
    y_train = tf.keras.utils.to_categorical(y_train, num_classes).astype('float32')
    y_val = tf.keras.utils.to_categorical(y_val, num_classes).astype('float32')
    y_test = tf.keras.utils.to_categorical(y_test, num_classes).astype('float32')

    # 预置-8dB极低信噪比环境,用于评估模型在强噪声下的泛化力。
    x_val_noisy = add_awgn(x_val, snr=-8)
    x_test_noisy = add_awgn(x_test, snr=-8)

    def augment_data(x_feats, y_targets):
        """
        数据增强:通过循环移位、随机脉冲冲击与动态SNR提升泛化。
        """
        rng = np.random.default_rng()
        x_aug = x_feats.copy()
        
        batch_n, seq_n, _ = x_aug.shape

        # 随机时域位移
        for i in range(batch_n):
            time_shift = rng.integers(0, seq_n)
            x_aug[i, :, 0] = np.roll(x_aug[i, :, 0], time_shift)

        # 概率触发模拟脉冲冲击
        if rng.random() > 0.9: 
            for i in range(batch_n):
                if rng.random() > 0.5: 
                    num_impacts = rng.integers(1, 3) 
                    indices = rng.integers(0, seq_n, num_impacts)
                    impulse_scale = np.std(x_aug[i]) * rng.uniform(1.5, 2.5) 
                    x_aug[i, indices, 0] += impulse_scale * rng.choice([-1, 1], size=num_impacts)

        # 随机信噪比扰动
        if rng.random() > 0.5: 
            x_aug = add_awgn(x_aug, snr=(-8, 8))

        return x_aug.astype(np.float32), y_targets.astype(np.float32)

    def _enforce_tensor_spec(f_tensor, l_tensor):
        """辅助编译器明确计算图中的张量形状规格"""
        f_tensor.set_shape([None, window_size, 1])
        l_tensor.set_shape([None, num_classes])
        return f_tensor, l_tensor

    # 封装高性能数据分发管道
    train_ds = tf.data.Dataset.from_tensor_slices((x_train.astype('float32'), y_train))
    train_ds = train_ds.shuffle(len(x_train)).batch(64)
    train_ds = train_ds.map(
        lambda x, y: tf.numpy_function(augment_data, [x, y], [tf.float32, tf.float32]),
        num_parallel_calls=tf.data.AUTOTUNE
    ).map(_enforce_tensor_spec).prefetch(tf.data.AUTOTUNE)

    # 深度模型编译
    model = DRSN_CW(num_classes=num_classes)
    model.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate=1e-3), 
        loss=tf.keras.losses.CategoricalCrossentropy(label_smoothing=0.0),
        metrics=['accuracy']
    )

    logging.info("诊断系统初始化成功。类别总数: {}, 输入维度: {}".format(num_classes, window_size))
    
    # 回调策略:监控验证集损失,实现学习率衰减与早停保护。
    callbacks = [
        tf.keras.callbacks.ReduceLROnPlateau(monitor='val_loss', factor=0.5, patience=7, min_lr=1e-6, verbose=1),
        tf.keras.callbacks.EarlyStopping(monitor='val_loss', patience=20, restore_best_weights=True)
    ]

    # 启动拟合流程
    model.fit(
        train_ds,
        epochs=100,
        validation_data=(x_val_noisy, y_val),
        callbacks=callbacks,
        verbose=2
    )

    # 评估模型在模拟极恶劣工况(-8dB SNR)下的故障分类精度
    evaluation_metrics = model.evaluate(x_test_noisy, y_test, verbose=0)
    print("\n[评估报告] 在 -8dB SNR 强干扰环境下,深度残差收缩网络的诊断准确率为: {:.2f}%".format(evaluation_metrics[1]*100))

# =============================================================================
# 第五部分:主程序入口点
# =============================================================================

if __name__ == "__main__":
    # 定义预置的数据检索路径
    DATA_PATH = os.path.join(os.getcwd(), 'data_path')
    
    if not os.path.exists(DATA_PATH):
        logging.info("未在默认位置发现数据集。")
        interactive_path = input("请输入 .mat 数据资源夹的完整物理路径: ").strip()
        if interactive_path:
            DATA_PATH = interactive_path
        else:
            logging.error("路径无效,诊断程序无法继续运行。")
            sys.exit(0)

    # 触发端到端诊断逻辑
    run_diagnosis_workflow(DATA_PATH, window_size=1024)

四、实验结果与工业价值分析

根据实验结果截图显示,即使在-8dB的信噪比(Signal-to-Noise Ratio, SNR)这种恶劣的模拟环境下,DRSN依然展现出了较高的鲁棒性,诊断准确率稳定在90%以上。DRSN不仅学习了如何“识别故障”,还学习了如何“过滤噪声”。这种端到端的学习方式有助于降低深度学习模型在实际工业场景中的落地门槛。


图3.png



论文标题: Deep residual shrinkage networks for fault diagnosis

出版期刊: IEEE Transactions on Industrial Informatics. 2020, 16(7): 4681-4690.

DOI: 10.1109/TII.2019.2943898


【声明】本内容来自华为云开发者社区博主,不代表华为云及华为云开发者社区的观点和立场。转载时必须标注文章的来源(华为云社区)、文章链接、文章作者等基本信息,否则作者和本社区有权追究责任。如果您发现本社区中有涉嫌抄袭的内容,欢迎发送邮件进行举报,并提供相关证据,一经查实,本社区将立刻删除涉嫌侵权内容,举报邮箱: cloudbbs@huaweicloud.com
  • 点赞
  • 收藏
  • 关注作者

评论(0

0/1000
抱歉,系统识别当前为高风险访问,暂不支持该操作

全部回复

上滑加载中

设置昵称

在此一键设置昵称,即可参与社区互动!

*长度不超过10个汉字或20个英文字符,设置后3个月内不可修改。

*长度不超过10个汉字或20个英文字符,设置后3个月内不可修改。