重新思考神经网络的输入:DWT-CNN

举报
李长安 发表于 2023/02/16 16:22:16 2023/02/16
【摘要】 受量子物理中波粒二象性的启发,以及最近 复现的八度卷积的设计思想启发,在本项目中设计了DWT-CNN,将神经网络的输入进行信号分解,分别送入网络的不同分支,最后通过一个全连接层进行分类。目前,通过简单的实验发现,所提出的网络结构能够更快收敛,其他下游计算机视觉任务有待实验证明。

重新思考神经网络的输入:DWT-CNN

0、前言

  受量子物理中波粒二象性的启发,以及最近
复现的八度卷积的设计思想启发,在本项目中设计了DWT-CNN,将神经网络的输入进行信号分解,分别送入网络的不同分支,最后通过一个全连接层进行分类。目前,通过简单的实验发现,所提出的网络结构能够更快收敛,其他下游计算机视觉任务有待实验证明。

1、网络设计思想

  如上图所示,本项目提出了一种双分支网络,首先將图像信号进行离散小波变换,将图像信号拆解为两个函数,分别入对应的网络通路,网络整体设计思想参考了双线性卷积神经网络的结构。

  首先将图像送入到一个卷积层中,对该卷积层的输出进行离散小波变换操作,然后将分解出来的两组信号分别送入到不同的分支网络中,经过一系列卷积操作后送入到全连接层进行分类。

1.1 离散小波变换

  小波分解的意义就在于能够在不同尺度上对信号进行分解,而且对不同尺度的选择可以根据不同的目标来确定。对于许多信号,低频成分相当重要,它常常蕴含着信号的特征,而高频成分则给出信号的细节或差别。人的话音如果去掉高频成分,听起来与以前可能不同,但仍能知道所说的内容;如果去掉足够的低频成分,则听到的是一些没有意义的声音。在小波分析中经常用到近似与细节。近似表示信号的高尺度,即低频信息;细节表示信号的高尺度,即高频信息。因此,原始信号通过两个相互滤波器产生两个信号。

  • 参考资料

离散小波变换(DWT)

2、网络搭建与可视化

!pip install PyWavelets

import math
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
import pywt
from paddle.nn import Linear, Dropout, ReLU
from paddle.nn import Conv2D, MaxPool2D
from paddle.nn.initializer import Uniform
from paddle.fluid.param_attr import ParamAttr
from paddle.utils.download import get_weights_path_from_url

class MyNet_dwt(nn.Layer):

    def __init__(self, num_classes=1000):
        super(MyNet_dwt, self).__init__()
        self.num_classes = num_classes
        
        self._conv1 = Conv2D(
            3,
            128,
            3,
            stride=2,
            padding=1,
            )      

        self._conv2_1 = Conv2D(
            128,
            256,
            3,
            stride=2,
            padding=1,
            )  
        self._conv3_1 = Conv2D(
            256,
            512,
            3,
            stride=2,
            padding=1,
            )
        self._conv4_1 = Conv2D(
            512,
            256,
            3,
            stride=2,
            padding=1,
            )

        self._conv2_2 = Conv2D(
            128,
            256,
            3,
            stride=2,
            padding=1,
            )  
        self._conv3_2 = Conv2D(
            256,
            512,
            3,
            stride=2,
            padding=1,
            )
        
        self._conv4_2 = Conv2D(
            512,
            256,
            3,
            stride=2,
            padding=1,
            )

        self._fc8 = Linear(
            in_features=50176,
            out_features=num_classes,
            )

    def forward(self, inputs):
        x = self._conv1(inputs)
        x = paddle.to_tensor(pywt.dwt(x.numpy(), 'haar'), dtype='float32')
        x1,x2 = x.split(2)

        x1 = x1.squeeze(axis=0)
        x2 = x2.squeeze(axis=0)
        x1 = self._conv2_1(x1)
        x1 = self._conv3_1(x1)
        x1 = F.relu(x1)
        x1 = self._conv4_1(x1)
        x1 = F.relu(x1)

        x2 = self._conv2_2(x2)
        x2 = self._conv3_2(x2)
        x2 = F.relu(x2)
        x2 = self._conv4_2(x2)
        x2 = F.relu(x2)      
        
        x = paddle.concat(x = [x1,x2], axis=2)
        x = paddle.flatten(x, start_axis=1, stop_axis=-1)
        x = self._fc8(x)
        return x

model_res = MyNet_dwt(num_classes=95)

paddle.summary(model_res,(1,3,224,224))

3 对比实验

  本项目中设置一组对比实验,实验数据为Cifar10,对比模型为AlexNet。分别迭代5轮,最后给出对比实验结果。

import paddle

from paddle.metric import Accuracy
from paddle.vision.transforms import Compose, Normalize, Resize, Transpose, ToTensor

callback = paddle.callbacks.VisualDL(log_dir='visualdl_log_dir_alexnetdwt')

normalize = Normalize(mean=[0.5, 0.5, 0.5],
                    std=[0.5, 0.5, 0.5],
                    data_format='HWC')
transform = Compose([ToTensor(), Normalize(), Resize(size=(224,224))])

cifar10_train = paddle.vision.datasets.Cifar10(mode='train',
                                               transform=transform)
cifar10_test = paddle.vision.datasets.Cifar10(mode='test',
                                              transform=transform)

# 构建训练集数据加载器
train_loader = paddle.io.DataLoader(cifar10_train, batch_size=256, shuffle=True, drop_last=True)

# 构建测试集数据加载器
test_loader = paddle.io.DataLoader(cifar10_test, batch_size=256, shuffle=True, drop_last=True)

dwt_net = paddle.Model(MyNet_dwt(num_classes=10))
optim = paddle.optimizer.Adam(learning_rate=3e-4, parameters=dwt_net.parameters())

dwt_net.prepare(
    optim,
    paddle.nn.CrossEntropyLoss(),
    Accuracy()
    )

dwt_net.fit(train_data=train_loader,
        eval_data=test_loader,
        epochs=5,
        callbacks=callback,
        verbose=1
        )
from paddle.vision.models import AlexNet
import paddle

alexnet = AlexNet(num_classes=10)

paddle.summary(alexnet,(1,3,224,224))
import paddle

from paddle.metric import Accuracy
from paddle.vision.transforms import Compose, Normalize, Resize, Transpose, ToTensor

callback = paddle.callbacks.VisualDL(log_dir='visualdl_log_dir_alexnet')

normalize = Normalize(mean=[0.5, 0.5, 0.5],
                    std=[0.5, 0.5, 0.5],
                    data_format='HWC')
transform = Compose([ToTensor(), Normalize(), Resize(size=(224,224))])

cifar10_train = paddle.vision.datasets.Cifar10(mode='train',
                                               transform=transform)
cifar10_test = paddle.vision.datasets.Cifar10(mode='test',
                                              transform=transform)

# 构建训练集数据加载器
train_loader = paddle.io.DataLoader(cifar10_train, batch_size=768, shuffle=True, drop_last=True)

# 构建测试集数据加载器
test_loader = paddle.io.DataLoader(cifar10_test, batch_size=768, shuffle=True, drop_last=True)

alexnet = paddle.Model(AlexNet(num_classes=10))
optim = paddle.optimizer.Adam(learning_rate=0.001, parameters=alexnet.parameters())

alexnet.prepare(
    optim,
    paddle.nn.CrossEntropyLoss(),
    Accuracy()
    )

alexnet.fit(train_data=train_loader,
        eval_data=test_loader,
        epochs=5,
        callbacks=callback,
        verbose=1
        )

4、实验结果

Model Train Acc Eval Acc
DWT-CNN 0.8369 0.6809
AlexNet 0.5286 0.5742

5、总结

  通过目前的简单实验对比,我们可以得出如下结论:

1、收敛更快,

2、训练速度相较于AlexNet来说更慢,

  总体来看,模型在精度方面应该会有不错的表现,但是在速度上面还有提升空间。在本项目中的分支结构信号分解大家也可以尝试一些其他的方式,比如快速傅立叶变换等,本项目中主要为大家提供一种网络设计思想,大家不要拘泥于信号分解的形式,可以多做尝试,没准能写一篇Paper。


【版权声明】本文为华为云社区用户原创内容,转载时必须标注文章的来源(华为云社区)、文章链接、文章作者等基本信息, 否则作者和本社区有权追究责任。如果您发现本社区中有涉嫌抄袭的内容,欢迎发送邮件进行举报,并提供相关证据,一经查实,本社区将立刻删除涉嫌侵权内容,举报邮箱: cloudbbs@huaweicloud.com
  • 点赞
  • 收藏
  • 关注作者

评论(0

0/1000
抱歉,系统识别当前为高风险访问,暂不支持该操作

全部回复

上滑加载中

设置昵称

在此一键设置昵称,即可参与社区互动!

*长度不超过10个汉字或20个英文字符,设置后3个月内不可修改。

*长度不超过10个汉字或20个英文字符,设置后3个月内不可修改。