手写西瓜书bp神经网络 mnist10 c#版本

举报
UE5技术哥 发表于 2024/06/01 20:24:13 2024/06/01
【摘要】 本文根据西瓜书第五章中给出的公式编写,书中给出了全连接神经网络的实现逻辑,本文在此基础上编写了Mnist10手搓10个数字的案例,网上也有一些其他手搓的例子参考。demo并没有用UE而是使用unity进行编写,方便且易于查错。该案例仅作为学习,博主也只是业余时间自学一些机器学习知识,欢迎各位路过的大佬留下建议~测试效果:源码下载地址:https://download.csdn.net/dow...

本文根据西瓜书第五章中给出的公式编写,书中给出了全连接神经网络的实现逻辑,本文在此基础上编写了Mnist10手搓10个数字的案例,网上也有一些其他手搓的例子参考。demo并没有用UE而是使用unity进行编写,方便且易于查错。
该案例仅作为学习,博主也只是业余时间自学一些机器学习知识,欢迎各位路过的大佬留下建议~

测试效果:

q1.jpeg


源码下载地址:
https://download.csdn.net/download/grayrail/87802798

1.符号的意义


首先理顺西瓜书第五章中的各符号的意义:

w1.png


2.正向传播(ForwardPropagation)


西瓜书第五章直接讲了反向传播,所以在这之前简单讲一下正向传播。

q2.png


以上图为例,输入层的维度是[3],隐层的维度是[n,3],输出层的维度是[4,n],因此最终输出维度是[4]。输入层通常就是原始输入的信息,隐层用于超参数与中间环节计算,隐层的维度是n * m,m是输入层的数据自身维度,n可以理解为n种可能性(博主自己的理解),例如隐层的第二个维度是50,那么就是假设了50种可能性进行训练。

基于此,那么正向传播的流程如下:

  1. 初始化隐层的连接权重v,维度1是输入数据的长度,维度2是有多少种可能性,维度2可以自己填一个适合的值。
  2. 以隐层的长度q进行循环,对每个h维度下的集合进行点乘求和,然后减去阈值γ并传入激活函数,写入b集合。
  3. 以输出层的长度l进行循环,每个l维度下的存放着所有可能性的点乘结果,然后减去阈值θ并传入激活函数,写入yhats集合。
  4. 可以再对yhats加一个softmax操作,筛选集合中的最大值返回下标索引,即输出结果。


3.反向传播(BackPropagation)


反向传播的难点之一是链式求导,西瓜书中已经帮我们把求导过程写好了,这里我先讲tips,再梳理反向传播流程。

3.1 关于损失函数E

q3.png

书中对损失函数E只提了一次,后续操作中用yhat-y直接减去的值带入sigmoid的偏导数公式,看的有点让人糊涂。
后来经过查询ChatGTP和别人的一些文章,了解到直接求差就是1/2 MSE求偏导数之后的结果,如果不用1/2 MSE作为损失函数,就把yhat-y换成其他公式。

3.2 公式梳理

书中的公式有点乱,下面给出按照顺序的梳理图:

3.2.1 W对于E的偏导数

q4.jpeg


3.2.2 V对于E的偏导数

q5.jpeg

这一部分应该是精髓所在,不是非常了解,不多加评论。

3.2.3 流程梳理


基于此,那么反向传播的流程如下:

  1. 单独对g项求值并存入数组,用真实值y yy和预测值带入Sigmoid的导数公式
  2. 单独对e项求值并存入数组
  3. 计算阈值(bias偏置)θ θθ的delta量并赋值
  4. 计算w的delta量并赋值
  5. 计算阈值(bias偏置)γ γγ的delta量并赋值
  6. 计算v vv的delta量并赋值

3.3 优化器


后来看书时,发现神经网络除了正向传播和反向传播,还有2个比较重要的东西损失函数、优化器,损失函数计算了每个超参数的梯度,优化器决定如何应用这些梯度。

例如本文末尾给的c#代码就用到了动量,也是一种优化器,像Tensorflow一些demo中常用的Adam优化器就是RMSProp+动量的做法,大致指动量的幅度是动态变化的,甚至学习率中还加了一个超参数去控制。

了解RMSProp优化器可以看一下《动手学深度学习》7.6.1章节,有相关实现。

4.代码实现


以Mnist案例为例,该案例使用神经网络识别28x28像素内图片的0-9个手写数字,接下来给出C#版本的Mnist代码实现,脚本挂载后有3种模式:

q6.png


Draw Image Mode 用于绘制0-9个数字,绘制后点OnGUI中的按钮储存,默认储存在Assets目录内,需手动右键刷新Project面板,然后自行根据规则改名
User Mode 使用已经训练好的神经网络进行数字识别(没有做缓存的功能,需要手动先训练几次)
Train Mode 训练模式,DataPath中填入图片路径,图片格式首先取前缀,例如:3_04,表明这个图片真实值数字是3,是第4张备选图片
该案例在西瓜书的基础上又加入了momentum动量、softmax、Dropout、初始随机值范围修改(-1,1),softmax使用《深度学习入门 基于PYTHON的理论与实现》一书中提供的公式。经过一些轮次训练后的运行结果:

q7.png


注意:下述代码的训练集需要自己绘制大概20-30个数字,当然也可以用我之前上传的mnist csv版本训练集,第一列是0-9的具体数字,后面是28x28的像素数据:
https://download.csdn.net/download/grayrail/88159671

C#代码如下:

using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using UnityEngine;

public class TestMnist10 : MonoBehaviour
{
    public enum EMode { Train, DrawImage, User }

    const float kDropoutProb = 0.4f;

    /// <summary>
    /// d个输入神经元
    /// </summary>
    int d;

    /// <summary>
    /// q个隐层神经元
    /// </summary>
    int q;

    /// <summary>
    /// l个输出神经元
    /// </summary>
    int l;

    /// <summary>
    /// 输入层原始值
    /// </summary>
    float[] x;

    /// <summary>
    /// 输入层到隐层神经元的连接权
    /// </summary>
    float[][] v;

    /// <summary>
    /// 缓存上一次的v权值
    /// </summary>
    float[][] lastVMomentum;

    /// <summary>
    /// 隐层神经元到输出层神经元的连接权
    /// </summary>
    float[][] w;

    /// <summary>
    /// 缓存上一次的w权值
    /// </summary>
    float[][] lastWMomentum;

    float[] wDropout;

    /// <summary>
    /// 反向传播g项
    /// </summary>
    float[] g;

    /// <summary>
    /// 反向传播e项
    /// </summary>
    float[] e;

    /// <summary>
    /// 隐层接收到的输入(通常List长度是隐层长度)
    /// </summary>
    List<float> b;

    /// <summary>
    /// 输出层接收到的输入(通常List长度是输出层长度)
    /// </summary>
    List<float> yhats;

    /// <summary>
    /// 输出层神经元的阈值
    /// </summary>
    float[] theta;

    /// <summary>
    /// 隐层神经元的阈值
    /// </summary>
    float[] gamma;


    public void Init(int inputLayerCount, int hiddenLayerCount, int outputLayerCount)
    {
        d = inputLayerCount;
        q = hiddenLayerCount;
        l = outputLayerCount;

        x = new float[inputLayerCount];
        b = new List<float>(1024);
        yhats = new List<float>(1024);

        e = new float[hiddenLayerCount];
        g = new float[outputLayerCount];

        v = GenDimsArray(typeof(float), new int[] { q, d }, 0, () => UnityEngine.Random.Range(-1f, 1f)) as float[][];
        w = GenDimsArray(typeof(float), new int[] { l, q }, 0, () => UnityEngine.Random.Range(-1f, 1f)) as float[][];
        wDropout = GenDimsArray(typeof(float), new int[] { l }, 0, null) as float[];

        lastVMomentum = GenDimsArray(typeof(float), new int[] { q, d }, 0, null) as float[][];
        lastWMomentum = GenDimsArray(typeof(float), new int[] { l, q }, 0, null) as float[][];

        theta = GenDimsArray(typeof(float), new int[] { l }, 0, () => UnityEngine.Random.Range(-1f, 1f)) as float[];
        gamma = GenDimsArray(typeof(float), new int[] { q }, 0, () => UnityEngine.Random.Range(-1f, 1f)) as float[];
    }
    public void ForwardPropagation(float[] input, out int output)
    {
        x = input;

        for (int jIndex = 0; jIndex < l; ++jIndex)
        {
            var r = UnityEngine.Random.value < kDropoutProb ? 1f : 0f;
            wDropout[jIndex] = r;
        }

        b.Clear();
        for (int hIndex = 0; hIndex < q; ++hIndex)
        {
            var sum = 0f;
            for (int iIndex = 0; iIndex < d; ++iIndex)
            {
                var u = input[iIndex] * v[hIndex][iIndex];
                sum += u;
            }
            var alpha = sum - gamma[hIndex];

            var r = Sigmoid(alpha);

            b.Add(r);
        }

        yhats.Clear();
        for (int jIndex = 0; jIndex < l; ++jIndex)
        {
            var sum = 0f;
            for (int hIndex = 0; hIndex < q; ++hIndex)
            {
                var u = b[hIndex] * w[jIndex][hIndex];
                sum += u;
            }
            var beta = sum - theta[jIndex];

            var r = Sigmoid(beta);

            //实际使用时关闭Dropout,训练时打开
            if (_EnableDropout)
            {
                r *= wDropout[jIndex];
                r /= kDropoutProb;
            }

            yhats.Add(r);
        }

        var softmaxResult = Softmax(yhats.ToArray());
        for (int i = 0; i < yhats.Count; i++)
        {
            yhats[i] = softmaxResult[i];
        }

        output = ArgMax(yhats);
    }
    public void BackPropagation(float[] correct)
    {
        const float kEta1 = 0.03f;
        const float kEta2 = 0.01f;

        const float kMomentum = 0.3f;

        for (int jIndex = 0; jIndex < l; ++jIndex)
        {
            var yhat = this.yhats[jIndex];
            var y = correct[jIndex];
            g[jIndex] = yhat * (1f - yhat) * (y - yhat);
        }

        for (int hIndex = 0; hIndex < q; ++hIndex)
        {
            var bh = b[hIndex];
            var sum = 0f;
            //这个for循环的内容就是通过矩阵将梯度反向传递。
            for (int jIndex = 0; jIndex < l; ++jIndex)
                sum += w[jIndex][hIndex] * g[jIndex];
            e[hIndex] = bh * (1f - bh) * sum;
        }

        for (int jIndex = 0; jIndex < l; ++jIndex)
        {
            theta[jIndex] += -kEta1 * g[jIndex];
        }

        for (int hIndex = 0; hIndex < q; ++hIndex)
        {
            for (int jIndex = 0; jIndex < l; ++jIndex)
            {
                var bh = b[hIndex];
                var delta = kMomentum * lastWMomentum[jIndex][hIndex] + kEta1 * g[jIndex] * bh;

                //实际使用时关闭Dropout,训练时打开
                if (_EnableDropout)
                {
                    var dropout = wDropout[jIndex];
                    delta *= dropout;
                    delta /= kDropoutProb;
                }

                w[jIndex][hIndex] += delta;
                lastWMomentum[jIndex][hIndex] = delta;
            }
        }

        for (int hIndex = 0; hIndex < q; ++hIndex)
        {
            gamma[hIndex] += -kEta2 * e[hIndex];
        }

        for (int hIndex = 0; hIndex < q; ++hIndex)
        {
            for (int iIndex = 0; iIndex < d; ++iIndex)
            {
                var delta = kMomentum * lastVMomentum[hIndex][iIndex] + kEta2 * e[hIndex] * x[iIndex];

                v[hIndex][iIndex] += delta;
                lastVMomentum[hIndex][iIndex] = delta;
            }
        }
    }
    void Start()
    {
        Init(784, 64, 10);
    }

    EMode _Mode;
    int[] _DrawNumberImage;
    bool _EnableDropout;
    string _DataPath;

    float Sigmoid(float val)
    {
        return 1f / (1f + Mathf.Exp(-val));
    }
    float[] Softmax(float[] inputs)
    {
        float[] outputs = new float[inputs.Length];
        float maxInput = inputs.Max();

        for (int i = 0; i < inputs.Length; i++)
        {
            outputs[i] = Mathf.Exp(inputs[i] - maxInput);
        }

        float expSum = outputs.Sum();
        for (int i = 0; i < outputs.Length; i++)
        {
            outputs[i] /= expSum;
        }

        return outputs;
    }
    float[] GetOneHot(string input)
    {
        if (input.StartsWith("0"))
            return new float[] { 1, 0, 0, 0, 0, 0, 0, 0, 0, 0 };
        if (input.StartsWith("1"))
            return new float[] { 0, 1, 0, 0, 0, 0, 0, 0, 0, 0 };
        if (input.StartsWith("2"))
            return new float[] { 0, 0, 1, 0, 0, 0, 0, 0, 0, 0 };
        if (input.StartsWith("3"))
            return new float[] { 0, 0, 0, 1, 0, 0, 0, 0, 0, 0 };
        if (input.StartsWith("4"))
            return new float[] { 0, 0, 0, 0, 1, 0, 0, 0, 0, 0 };
        if (input.StartsWith("5"))
            return new float[] { 0, 0, 0, 0, 0, 1, 0, 0, 0, 0 };
        if (input.StartsWith("6"))
            return new float[] { 0, 0, 0, 0, 0, 0, 1, 0, 0, 0 };
        if (input.StartsWith("7"))
            return new float[] { 0, 0, 0, 0, 0, 0, 0, 1, 0, 0 };
        if (input.StartsWith("8"))
            return new float[] { 0, 0, 0, 0, 0, 0, 0, 0, 1, 0 };
        else
            return new float[] { 0, 0, 0, 0, 0, 0, 0, 0, 0, 1 };
    }
    int ArgMax(List<float> yhats)
    {
        int index = 0;
        float maxValue = yhats[0];
        for (int jIndex = 0; jIndex < l; ++jIndex)
        {
            if (yhats[jIndex] > maxValue)
            {
                maxValue = yhats[jIndex];
                index = jIndex;
            }
        }
        return index;
    }
    void Shuffle<T>(List<T> cardList)
    {
        int tempIndex = 0;
        T temp = default;
        for (int i = 0; i < cardList.Count; ++i)
        {
            tempIndex = UnityEngine.Random.Range(0, cardList.Count);
            temp = cardList[tempIndex];
            cardList[tempIndex] = cardList[i];
            cardList[i] = temp;
        }
    }
    /// <summary>
    /// 快速得到多维数组
    /// </summary>
    Array GenDimsArray(Type type, int[] dims, int deepIndex, Func<object> initFunc = null)
    {
        if (deepIndex < dims.Length - 1)
        {
            var sub_template = GenDimsArray(type, dims, deepIndex + 1, null);
            var current = Array.CreateInstance(sub_template.GetType(), dims[deepIndex]);

            for (int i = 0; i < dims[deepIndex]; ++i)
            {
                var sub = GenDimsArray(type, dims, deepIndex + 1, initFunc);
                current.SetValue(sub, i);
            }

            return current;
        }
        else
        {
            var arr = Array.CreateInstance(type, dims[deepIndex]);
            if (initFunc != null)
            {
                for (int i = 0; i < arr.Length; ++i)
                    arr.SetValue(initFunc(), i);
            }
            return arr;
        }
    }
    void OnGUI()
    {
        if (_DrawNumberImage == null)
            _DrawNumberImage = new int[784];

        GUILayout.BeginHorizontal();
        if (GUILayout.Button("Draw Image Mode"))
        {
            _Mode = EMode.DrawImage;

            Array.Clear(_DrawNumberImage, 0, _DrawNumberImage.Length);
        }
        if (GUILayout.Button("User Mode"))
        {
            _Mode = EMode.User;

            Array.Clear(_DrawNumberImage, 0, _DrawNumberImage.Length);
        }
        if (GUILayout.Button("Train Mode"))
        {
            _Mode = EMode.Train;
            _DataPath = Directory.GetCurrentDirectory() + "/TrainData";
        }
        GUILayout.EndHorizontal();

        var lastRect = GUILayoutUtility.GetLastRect();

        switch (_Mode)
        {
            case EMode.Train:
                {
                    GUILayout.BeginHorizontal();
                    GUILayout.Label("Data Path: ");
                    _DataPath = GUILayout.TextField(_DataPath);
                    GUILayout.EndHorizontal();

                    _EnableDropout = GUILayout.Button("dropout(" + (_EnableDropout ? "True" : "False") + ")")
                        ? !_EnableDropout : _EnableDropout;

                    if (GUILayout.Button("Train 10"))
                    {
                        var files = Directory.GetFiles(_DataPath);
                        List<(string, float[])> datas = new(512);
                        for (int i = 0; i < files.Length; ++i)
                        {
                            var strArr = File.ReadAllText(files[i]).Split(',');
                            datas.Add((Path.GetFileNameWithoutExtension(files[i]), Array.ConvertAll(strArr, m => float.Parse(m))));
                        }

                        for (int s = 0; s < 10; ++s)
                        {
                            Shuffle(datas);

                            for (int i = 0; i < datas.Count; ++i)
                            {
                                ForwardPropagation(datas[i].Item2, out int output);
                                UnityEngine.Debug.Log("<color=#00ff00> Input Number: " + datas[i].Item1 + " output: " + output + "</color>");
                                BackPropagation(GetOneHot(datas[i].Item1));
                                //break;
                            }
                        }
                    }
                }
                break;
            case EMode.DrawImage:
                {
                    lastRect.y += 50f;
                    var size = 20f;
                    var spacing = 2f;
                    var mousePosition = Event.current.mousePosition;
                    var mouseLeftIsPress = Input.GetMouseButton(0);
                    var mouseRightIsPress = Input.GetMouseButton(1);
                    var containSpacingSize = size + spacing;

                    for (int y = 0, i = 0; y < 28; ++y)
                    {
                        for (int x = 0; x < 28; ++x)
                        {
                            var rect = new Rect(lastRect.x + x * containSpacingSize, lastRect.y + y * containSpacingSize, size, size);
                            GUI.DrawTexture(rect, _DrawNumberImage[i] == 1 ? Texture2D.blackTexture : Texture2D.whiteTexture);

                            if (rect.Contains(mousePosition))
                            {
                                if (mouseLeftIsPress)
                                    _DrawNumberImage[i] = 1;
                                else if (mouseRightIsPress)
                                    _DrawNumberImage[i] = 0;
                            }

                            ++i;
                        }
                    }
                    if (GUILayout.Button("Save"))
                    {
                        File.WriteAllText(Directory.GetCurrentDirectory() + "/Assets/tmp.txt", string.Join(",", _DrawNumberImage));
                    }
                }
                break;
            case EMode.User:
                {
                    lastRect.y += 150f;
                    var size = 20f;
                    var spacing = 2f;
                    var mousePosition = Event.current.mousePosition;
                    var mouseLeftIsPress = Input.GetMouseButton(0);
                    var mouseRightIsPress = Input.GetMouseButton(1);
                    var containSpacingSize = size + spacing;

                    for (int y = 0, i = 0; y < 28; ++y)
                    {
                        for (int x = 0; x < 28; ++x)
                        {
                            var rect = new Rect(lastRect.x + x * containSpacingSize, lastRect.y + y * containSpacingSize, size, size);
                            GUI.DrawTexture(rect, _DrawNumberImage[i] == 1 ? Texture2D.blackTexture : Texture2D.whiteTexture);

                            if (rect.Contains(mousePosition))
                            {
                                if (mouseLeftIsPress)
                                    _DrawNumberImage[i] = 1;
                                else if (mouseRightIsPress)
                                    _DrawNumberImage[i] = 0;
                            }

                            ++i;
                        }
                    }
                    if (GUILayout.Button("Recognize"))
                    {
                        ForwardPropagation(Array.ConvertAll(_DrawNumberImage, m => (float)m), out int output);
                        Debug.Log("output: " + output);
                    }
                    break;
                }
        }
    }
}



参考文章

Java实现BP神经网络MNIST手写数字识别https://www.cnblogs.com/baby7/p/java_bp_neural_network_number_identification.html
反向传播算法对照 https://zhuanlan.zhihu.com/p/605765790

【版权声明】本文为华为云社区用户原创内容,转载时必须标注文章的来源(华为云社区)、文章链接、文章作者等基本信息, 否则作者和本社区有权追究责任。如果您发现本社区中有涉嫌抄袭的内容,欢迎发送邮件进行举报,并提供相关证据,一经查实,本社区将立刻删除涉嫌侵权内容,举报邮箱: cloudbbs@huaweicloud.com
  • 点赞
  • 收藏
  • 关注作者

评论(0

0/1000
抱歉,系统识别当前为高风险访问,暂不支持该操作

全部回复

上滑加载中

设置昵称

在此一键设置昵称,即可参与社区互动!

*长度不超过10个汉字或20个英文字符,设置后3个月内不可修改。

*长度不超过10个汉字或20个英文字符,设置后3个月内不可修改。