Character Sequence Recognition: The CRNN Model

Posted by HWCloudAI on 2022/12/05 14:50:40

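In this case study, we look at OCR (Optical Character Recognition) with deep learning. OCR was one of the first areas of computer vision to adopt deep learning and has produced many strong models, which makes it a good subject for learning deep-learning-based OCR. Deep-learning OCR pipelines usually split text recognition into two stages: text region detection and character recognition. The CRNN model covered here handles the second stage: it is a character recognition model that reads the text contained in a text image.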

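Notes:

  1. Framework used in this case study: TensorFlow-1.13.1

  2. Hardware used in this case study: 8 vCPU + 64 GiB + 1 x Tesla V100-PCIE-32GB

  3. Opening the runtime environment: click this link to go to AI Gallery, then click the Run in ModelArts button to open the ModelArts runtime environment. To use a GPU, switch to it in the workspace panel on the right side of the ModelArts JupyterLab interface.

  4. Running the code: click the triangular Run button in the menu bar at the top of the page, or press Ctrl+Enter, to run the code in each cell.

  5. Detailed JupyterLab usage: see the ModelArts JupyterLab User Guide.

  6. Troubleshooting: see the ModelArts JupyterLab FAQ.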

1. Download the data and code

Run the code below to download the archive containing the data and code (it is extracted in the next step).

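from modelarts.session import Session
sess = Session()

# Pick the OBS bucket that matches the current region
if sess.region_name == 'cn-north-1':
    bucket_path = "modelarts-labs/notebook/DL_ocr_crnn_sequence_recognition/crnn.tar"
elif sess.region_name == 'cn-north-4':
    bucket_path = "modelarts-labs-bj4/notebook/DL_ocr_crnn_sequence_recognition/crnn.tar"
else:
    # Stop here: otherwise bucket_path is undefined and download_data fails with a NameError
    raise ValueError("Please switch your region to Beijing1 (cn-north-1) or Beijing4 (cn-north-4)")

sess.download_data(bucket_path=bucket_path, path="./crnn.tar")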
Successfully download file modelarts-labs-bj4/notebook/DL_ocr_crnn_sequence_recognition/crnn.tar from OBS to local ./crnn.tar

2. Extract the files and install the dependencies

!tar -xf crnn.tar
!pip install torch==1.3.0
!pip install torchvision==0.4.1
!pip install keras==2.1.6
!pip install keras_applications==1.0.5
!pip install opencv-python==4.1.0.25
Collecting torch==1.3.0
  Downloading http://repo.myhuaweicloud.com/repository/pypi/packages/ae/05/50a05de5337f7a924bb8bd70c6936230642233e424d6a9747ef1cfbde353/torch-1.3.0-cp36-cp36m-manylinux1_x86_64.whl (773.1MB)
    100% |████████████████████████████████| 773.1MB 16.8MB/s
Requirement already satisfied: numpy in /home/ma-user/anaconda3/envs/TensorFlow-1.13.1/lib/python3.6/site-packages (from torch==1.3.0)
Installing collected packages: torch
Successfully installed torch-1.3.0
You are using pip version 9.0.1, however version 20.3.3 is available.
You should consider upgrading via the 'pip install --upgrade pip' command.
Collecting torchvision==0.4.1
  Downloading http://repo.myhuaweicloud.com/repository/pypi/packages/fc/23/d418c9102d4054d19d57ccf0aca18b7c1c1f34cc0a136760b493f78ddb06/torchvision-0.4.1-cp36-cp36m-manylinux1_x86_64.whl (10.1MB)
    100% |████████████████████████████████| 10.1MB 124.6MB/s
Requirement already satisfied: torch==1.3.0 in /home/ma-user/anaconda3/envs/TensorFlow-1.13.1/lib/python3.6/site-packages (from torchvision==0.4.1)
Requirement already satisfied: numpy in /home/ma-user/anaconda3/envs/TensorFlow-1.13.1/lib/python3.6/site-packages (from torchvision==0.4.1)
Requirement already satisfied: pillow>=4.1.1 in /home/ma-user/anaconda3/envs/TensorFlow-1.13.1/lib/python3.6/site-packages (from torchvision==0.4.1)
Requirement already satisfied: six in /home/ma-user/anaconda3/envs/TensorFlow-1.13.1/lib/python3.6/site-packages (from torchvision==0.4.1)
Installing collected packages: torchvision
Successfully installed torchvision-0.4.1
Collecting keras==2.1.6
  Downloading http://repo.myhuaweicloud.com/repository/pypi/packages/54/e8/eaff7a09349ae9bd40d3ebaf028b49f5e2392c771f294910f75bb608b241/Keras-2.1.6-py2.py3-none-any.whl (339kB)
    100% |████████████████████████████████| 348kB 20.2MB/s
Requirement already satisfied: numpy>=1.9.1 in /home/ma-user/anaconda3/envs/TensorFlow-1.13.1/lib/python3.6/site-packages (from keras==2.1.6)
Requirement already satisfied: pyyaml in /home/ma-user/anaconda3/envs/TensorFlow-1.13.1/lib/python3.6/site-packages (from keras==2.1.6)
Requirement already satisfied: h5py in /home/ma-user/anaconda3/envs/TensorFlow-1.13.1/lib/python3.6/site-packages (from keras==2.1.6)
Requirement already satisfied: six>=1.9.0 in /home/ma-user/anaconda3/envs/TensorFlow-1.13.1/lib/python3.6/site-packages (from keras==2.1.6)
Requirement already satisfied: scipy>=0.14 in /home/ma-user/anaconda3/envs/TensorFlow-1.13.1/lib/python3.6/site-packages (from keras==2.1.6)
Installing collected packages: keras
  Found existing installation: Keras 2.2.4
    Uninstalling Keras-2.2.4:
      Successfully uninstalled Keras-2.2.4
Successfully installed keras-2.1.6
Collecting keras_applications==1.0.5
  Downloading http://repo.myhuaweicloud.com/repository/pypi/packages/3f/9c/6e9393ead970fd97be0cfde912697dafec5800d9191f5ba25352fa537d72/Keras_Applications-1.0.5-py2.py3-none-any.whl (44kB)
    100% |████████████████████████████████| 51kB 10.8MB/s
Requirement already satisfied: numpy>=1.9.1 in /home/ma-user/anaconda3/envs/TensorFlow-1.13.1/lib/python3.6/site-packages (from keras_applications==1.0.5)
Requirement already satisfied: h5py in /home/ma-user/anaconda3/envs/TensorFlow-1.13.1/lib/python3.6/site-packages (from keras_applications==1.0.5)
Requirement already satisfied: keras>=2.1.6 in /home/ma-user/anaconda3/envs/TensorFlow-1.13.1/lib/python3.6/site-packages (from keras_applications==1.0.5)
Requirement already satisfied: six in /home/ma-user/anaconda3/envs/TensorFlow-1.13.1/lib/python3.6/site-packages (from h5py->keras_applications==1.0.5)
Requirement already satisfied: scipy>=0.14 in /home/ma-user/anaconda3/envs/TensorFlow-1.13.1/lib/python3.6/site-packages (from keras>=2.1.6->keras_applications==1.0.5)
Requirement already satisfied: pyyaml in /home/ma-user/anaconda3/envs/TensorFlow-1.13.1/lib/python3.6/site-packages (from keras>=2.1.6->keras_applications==1.0.5)
Installing collected packages: keras-applications
  Found existing installation: Keras-Applications 1.0.8
    Uninstalling Keras-Applications-1.0.8:
      Successfully uninstalled Keras-Applications-1.0.8
Successfully installed keras-applications-1.0.5
Collecting opencv-python==4.1.0.25
  Downloading http://repo.myhuaweicloud.com/repository/pypi/packages/7b/d2/a2dbf83d4553ca6b3701d91d75e42fe50aea97acdc00652dca515749fb5d/opencv_python-4.1.0.25-cp36-cp36m-manylinux1_x86_64.whl (26.6MB)
    100% |████████████████████████████████| 26.6MB 115.4MB/s
Requirement already satisfied: numpy>=1.11.3 in /home/ma-user/anaconda3/envs/TensorFlow-1.13.1/lib/python3.6/site-packages (from opencv-python==4.1.0.25)
Installing collected packages: opencv-python
  Found existing installation: opencv-python 3.4.1.15
    Uninstalling opencv-python-3.4.1.15:
      Successfully uninstalled opencv-python-3.4.1.15
Successfully installed opencv-python-4.1.0.25

from tensorflow import ConfigProto
from tensorflow import InteractiveSession
# Let TensorFlow allocate GPU memory on demand instead of reserving the whole GPU up front
config = ConfigProto()
config.gpu_options.allow_growth = True
session = InteractiveSession(config=config)
/home/ma-user/anaconda3/envs/TensorFlow-1.13.1/lib/python3.6/site-packages/tensorflow/python/framework/dtypes.py:526: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.

  _np_qint8 = np.dtype([("qint8", np.int8, 1)])

/home/ma-user/anaconda3/envs/TensorFlow-1.13.1/lib/python3.6/site-packages/tensorflow/python/framework/dtypes.py:527: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.

  _np_quint8 = np.dtype([("quint8", np.uint8, 1)])

/home/ma-user/anaconda3/envs/TensorFlow-1.13.1/lib/python3.6/site-packages/tensorflow/python/framework/dtypes.py:528: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.

  _np_qint16 = np.dtype([("qint16", np.int16, 1)])

/home/ma-user/anaconda3/envs/TensorFlow-1.13.1/lib/python3.6/site-packages/tensorflow/python/framework/dtypes.py:529: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.

  _np_quint16 = np.dtype([("quint16", np.uint16, 1)])

/home/ma-user/anaconda3/envs/TensorFlow-1.13.1/lib/python3.6/site-packages/tensorflow/python/framework/dtypes.py:530: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.

  _np_qint32 = np.dtype([("qint32", np.int32, 1)])

/home/ma-user/anaconda3/envs/TensorFlow-1.13.1/lib/python3.6/site-packages/tensorflow/python/framework/dtypes.py:535: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.

  np_resource = np.dtype([("resource", np.ubyte, 1)])

3. Import the required libraries and modules

import numpy as np
import data.dataset as dataset   # local LMDB dataset helpers (from the extracted crnn.tar)
import keys                      # local module that defines the character alphabet
import torch                     # used here only for its DataLoader

from keras.layers import Flatten, BatchNormalization, Permute, TimeDistributed, Dense, Bidirectional, GRU
from keras.layers import Input, Conv2D, MaxPooling2D, ZeroPadding2D, Lambda
from keras.models import Model
from keras.optimizers import SGD
from keras import backend as K
Using TensorFlow backend.

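The function ctc_lambda_func computes the CTC (Connectionist Temporal Classification) loss.

def ctc_lambda_func(args):
    y_pred, labels, input_length, label_length = args
    # Drop the first two time steps, as in the Keras image_ocr example,
    # whose comment notes that the first couple of RNN outputs tend to be garbage
    y_pred = y_pred[:, 2:, :]
    return K.ctc_batch_cost(labels, y_pred, input_length, label_length)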
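K.ctc_batch_cost delegates to TensorFlow's CTC loss, which sums the probabilities of every frame-level path that collapses (merge repeats, then remove blanks) to the target label. As a rough illustration of that computation, here is a minimal NumPy sketch of the CTC forward (alpha) recursion for a single sequence. It is not the TensorFlow implementation: it assumes a (T, C) matrix of softmax probabilities, places the blank at the last class index (the convention TensorFlow 1.x's ctc_loss uses), and works in plain probability space, so it would underflow on long sequences where real implementations work in log space. The name ctc_neg_log_likelihood is hypothetical.

```python
import numpy as np


def ctc_neg_log_likelihood(probs, label, blank):
    """-log p(label | probs) for one sequence (illustrative sketch).

    probs: (T, C) array of per-frame softmax probabilities.
    label: sequence of class indices, without blanks.
    blank: index of the blank class (assumed to be C - 1 here).
    """
    T = probs.shape[0]
    # Extended label: blank, l1, blank, l2, ..., blank
    ext = [blank]
    for c in label:
        ext += [c, blank]
    S = len(ext)

    alpha = np.zeros((T, S))
    # A valid path starts with either the leading blank or the first label
    alpha[0, 0] = probs[0, ext[0]]
    if S > 1:
        alpha[0, 1] = probs[0, ext[1]]

    for t in range(1, T):
        for s in range(S):
            total = alpha[t - 1, s]              # stay on the same symbol
            if s >= 1:
                total += alpha[t - 1, s - 1]     # advance by one symbol
            # Skipping a blank is only allowed between two different labels
            if s >= 2 and ext[s] != blank and ext[s] != ext[s - 2]:
                total += alpha[t - 1, s - 2]
            alpha[t, s] = total * probs[t, ext[s]]

    # A valid path ends on the last label or on the trailing blank
    p = alpha[T - 1, S - 1] + (alpha[T - 1, S - 2] if S > 1 else 0.0)
    return -np.log(p)


# Toy case: 3 frames, classes {0, 1} plus blank = 2, uniform probabilities.
# Exactly 5 of the 27 paths collapse to [0, 1]: 001, 011, 01-, 0-1, -01
probs = np.full((3, 3), 1.0 / 3)
print(ctc_neg_log_likelihood(probs, [0, 1], blank=2))  # equals -log(5/27)
```

The skip rule is why a repeated character such as "aa" needs a blank frame between the two a's: with 3 frames only the path a-blank-a survives, so its loss is -log(1/27).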

Define the number of character classes

characters = keys.alphabet[:]
nclass = len(characters) + 1  # one extra class for the CTC blank label

4. Build the network

input = Input(shape=(32, None, 1), name='the_input')  # grayscale images: height 32, variable width
# CNN (convolutional) layers
m = Conv2D(64, kernel_size=(3, 3), activation='relu', padding='same', name='conv1')(input)
m = MaxPooling2D(pool_size=(2, 2), strides=(2, 2), name='pool1')(m)
m = Conv2D(128, kernel_size=(3, 3), activation='relu', padding='same', name='conv2')(m)
m = MaxPooling2D(pool_size=(2, 2), strides=(2, 2), name='pool2')(m)
m = Conv2D(256, kernel_size=(3, 3), activation='relu', padding='same', name='conv3')(m)
m = Conv2D(256, kernel_size=(3, 3), activation='relu', padding='same', name='conv4')(m)

# Pad the width by 1 on each side, then pool with stride 2 in height but only 1 in width
m = ZeroPadding2D(padding=(0, 1))(m)
m = MaxPooling2D(pool_size=(2, 2), strides=(2, 1), padding='valid', name='pool3')(m)

m = Conv2D(512, kernel_size=(3, 3), activation='relu', padding='same', name='conv5')(m)
# Note: on channels-last tensors axis=1 is the height axis, not the channel axis, so each
# BatchNormalization layer has only 16 parameters. The pretrained weights loaded later
# match this setting, so changing the axis would break load_weights.
m = BatchNormalization(axis=1)(m)
m = Conv2D(512, kernel_size=(3, 3), activation='relu', padding='same', name='conv6')(m)
m = BatchNormalization(axis=1)(m)
m = ZeroPadding2D(padding=(0, 1))(m)
m = MaxPooling2D(pool_size=(2, 2), strides=(2, 1), padding='valid', name='pool4')(m)
m = Conv2D(512, kernel_size=(2, 2), activation='relu', padding='valid', name='conv7')(m)

# (batch, 1, width, 512) -> (batch, width, 1, 512) -> (batch, width, 512):
# one 512-dim feature vector per horizontal position (time step)
m = Permute((2, 1, 3), name='permute')(m)
m = TimeDistributed(Flatten(), name='timedistrib')(m)
# RNN (recurrent) layers: bidirectional GRUs, despite the "blstm" layer names
m = Bidirectional(GRU(256, return_sequences=True), name='blstm1')(m)
m = Dense(256, name='blstm1_out', activation='linear')(m)
m = Bidirectional(GRU(256, return_sequences=True), name='blstm2')(m)
y_pred = Dense(nclass, name='blstm2_out', activation='softmax')(m)  # per-time-step class probabilities

basemodel = Model(inputs=input, outputs=y_pred)  # inference model: CNN + RNN
# Transcription layer: CTC loss
labels = Input(name='the_labels', shape=[None, ], dtype='float32')
input_length = Input(name='input_length', shape=[1], dtype='int64')
label_length = Input(name='label_length', shape=[1], dtype='int64')
loss_out = Lambda(ctc_lambda_func, output_shape=(1,), name='ctc')([y_pred, labels, input_length, label_length])
# Training model: image, labels and both lengths as inputs, CTC loss as the output
model = Model(inputs=[input, labels, input_length, label_length], outputs=[loss_out])
# Optimizer: SGD with Nesterov momentum and gradient-norm clipping
sgd = SGD(lr=0.0001, decay=1e-6, momentum=0.9, nesterov=True, clipnorm=5)

# The 'ctc' output already is the loss, so the compiled loss simply passes it through
model.compile(loss={'ctc': lambda y_true, y_pred: y_pred}, optimizer=sgd)
WARNING:tensorflow:From /home/ma-user/anaconda3/envs/TensorFlow-1.13.1/lib/python3.6/site-packages/tensorflow/python/framework/op_def_library.py:263: colocate_with (from tensorflow.python.framework.ops) is deprecated and will be removed in a future version.

Instructions for updating:

Colocations handled automatically by placer.

WARNING:tensorflow:From /home/ma-user/anaconda3/envs/TensorFlow-1.13.1/lib/python3.6/site-packages/keras/backend/tensorflow_backend.py:3948: to_int32 (from tensorflow.python.ops.math_ops) is deprecated and will be removed in a future version.

Instructions for updating:

Use tf.cast instead.

WARNING:tensorflow:From /home/ma-user/anaconda3/envs/TensorFlow-1.13.1/lib/python3.6/site-packages/keras/backend/tensorflow_backend.py:3928: to_int64 (from tensorflow.python.ops.math_ops) is deprecated and will be removed in a future version.

Instructions for updating:

Use tf.cast instead.
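The input height is fixed at 32 while the width is variable, so the number of time steps the RNN sees depends on the image width. The sketch below traces height and width through the layers defined above, using the usual output-size rule for 'valid' windows, floor((w - k) / s) + 1 ('same' convolutions leave the size unchanged). For a 256-pixel-wide training image it gives 65 time steps; after ctc_lambda_func drops the first two, 63 remain, which agrees with Length = int(256 / 4) - 1 in the data generator further down. The helper names are hypothetical.

```python
def valid_out(size, kernel, stride):
    """Output length of a 'valid' pooling or convolution window along one axis."""
    return (size - kernel) // stride + 1


def crnn_height(height=32):
    """Trace the height axis through the CNN (conv1-conv6 use 'same' padding)."""
    h = valid_out(height, 2, 2)    # pool1
    h = valid_out(h, 2, 2)         # pool2
    h = valid_out(h, 2, 2)         # pool3 (stride 2 along height)
    h = valid_out(h, 2, 2)         # pool4
    return valid_out(h, 2, 1)      # conv7: 2x2 kernel, 'valid'


def crnn_time_steps(width):
    """Trace the width axis; the result is the RNN sequence length."""
    w = valid_out(width, 2, 2)     # pool1
    w = valid_out(w, 2, 2)         # pool2
    w = valid_out(w + 2, 2, 1)     # ZeroPadding2D((0, 1)) adds 2, then pool3 (stride 1 in width)
    w = valid_out(w + 2, 2, 1)     # ZeroPadding2D((0, 1)) adds 2, then pool4
    return valid_out(w, 2, 1)      # conv7


def ctc_input_length(width):
    """Time steps left after ctc_lambda_func slices off the first two outputs."""
    return crnn_time_steps(width) - 2


print(crnn_height(), crnn_time_steps(256), ctc_input_length(256))  # 1 65 63
```

For any width divisible by 4 the result simplifies to width / 4 - 1, which is why the generator can hard-code the length for its fixed 256-pixel inputs, while inference on arbitrary widths gets a sequence length that follows the image.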

5. model is the training model; it adds the CTC loss computation on top of the network

model.summary()

(The output is too long and is omitted here.)

basemodel is used for inference; it contains only the CNN and RNN layers

basemodel.summary()
_________________________________________________________________

Layer (type)                 Output Shape              Param #   

=================================================================

the_input (InputLayer)       (None, 32, None, 1)       0         

_________________________________________________________________

conv1 (Conv2D)               (None, 32, None, 64)      640       

_________________________________________________________________

pool1 (MaxPooling2D)         (None, 16, None, 64)      0         

_________________________________________________________________

conv2 (Conv2D)               (None, 16, None, 128)     73856     

_________________________________________________________________

pool2 (MaxPooling2D)         (None, 8, None, 128)      0         

_________________________________________________________________

conv3 (Conv2D)               (None, 8, None, 256)      295168    

_________________________________________________________________

conv4 (Conv2D)               (None, 8, None, 256)      590080    

_________________________________________________________________

zero_padding2d_1 (ZeroPaddin (None, 8, None, 256)      0         

_________________________________________________________________

pool3 (MaxPooling2D)         (None, 4, None, 256)      0         

_________________________________________________________________

conv5 (Conv2D)               (None, 4, None, 512)      1180160   

_________________________________________________________________

batch_normalization_1 (Batch (None, 4, None, 512)      16        

_________________________________________________________________

conv6 (Conv2D)               (None, 4, None, 512)      2359808   

_________________________________________________________________

batch_normalization_2 (Batch (None, 4, None, 512)      16        

_________________________________________________________________

zero_padding2d_2 (ZeroPaddin (None, 4, None, 512)      0         

_________________________________________________________________

pool4 (MaxPooling2D)         (None, 2, None, 512)      0         

_________________________________________________________________

conv7 (Conv2D)               (None, 1, None, 512)      1049088   

_________________________________________________________________

permute (Permute)            (None, None, 1, 512)      0         

_________________________________________________________________

timedistrib (TimeDistributed (None, None, 512)         0         

_________________________________________________________________

blstm1 (Bidirectional)       (None, None, 512)         1181184   

_________________________________________________________________

blstm1_out (Dense)           (None, None, 256)         131328    

_________________________________________________________________

blstm2 (Bidirectional)       (None, None, 512)         787968    

_________________________________________________________________

blstm2_out (Dense)           (None, None, 5531)        2837403   

=================================================================

Total params: 10,486,715

Trainable params: 10,486,699

Non-trainable params: 16

_________________________________________________________________
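The parameter counts above can be reproduced by hand, which is a quick way to confirm the architecture and to see the effect of BatchNormalization(axis=1). The sketch below uses the standard formulas (convolution and dense: weights plus biases; GRU: three gates, each with an input kernel, a recurrent kernel and a bias; batch normalization: four values per normalized feature). The helper names are hypothetical, and nclass = 5531 is read off the blstm2_out row of the summary.

```python
def conv2d_params(kernel_h, kernel_w, in_ch, out_ch):
    return kernel_h * kernel_w * in_ch * out_ch + out_ch   # weights + biases


def dense_params(in_dim, out_dim):
    return in_dim * out_dim + out_dim


def gru_params(in_dim, units):
    # 3 gates x (input kernel + recurrent kernel + bias), one bias per gate
    return 3 * (in_dim * units + units * units + units)


def batchnorm_params(axis_size):
    # gamma and beta (trainable) + moving mean and moving variance (non-trainable)
    return 4 * axis_size


nclass = 5531
layers = {
    'conv1': conv2d_params(3, 3, 1, 64),
    'conv2': conv2d_params(3, 3, 64, 128),
    'conv3': conv2d_params(3, 3, 128, 256),
    'conv4': conv2d_params(3, 3, 256, 256),
    'conv5': conv2d_params(3, 3, 256, 512),
    'bn1': batchnorm_params(4),        # axis=1 is the height axis, size 4 here
    'conv6': conv2d_params(3, 3, 512, 512),
    'bn2': batchnorm_params(4),
    'conv7': conv2d_params(2, 2, 512, 512),
    'blstm1': 2 * gru_params(512, 256),   # Bidirectional doubles the count
    'blstm1_out': dense_params(512, 256),
    'blstm2': 2 * gru_params(256, 256),
    'blstm2_out': dense_params(512, nclass),
}
print(sum(layers.values()))   # 10486715, the total reported above
print(batchnorm_params(512))  # 2048: what one BN layer would have on the channel axis
```

The 16 non-trainable parameters in the summary are exactly the moving means and variances of the two height-axis BN layers (2 x 2 x 4).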

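The function one_hot converts a text label into a fixed-length array of character indices (despite its name, it does not produce one-hot vectors)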

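def one_hot(text, length, characters=characters):
    # Returns class indices (not one-hot vectors), zero-padded up to `length`
    label = np.zeros(length)
    for i, char in enumerate(text):
        index = characters.find(char)
        if index == -1:
            # Characters missing from the alphabet are mapped to the space character
            index = characters.find(u' ')
        label[i] = index
    return label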
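A quick illustration of what one_hot produces, using a tiny hypothetical alphabet 'abc ' in place of keys.alphabet: each character becomes its index in the alphabet, characters not in the alphabet fall back to the index of the space (this assumes the alphabet contains a space, as the toy one does), and unused positions stay 0. Since 0 is also the index of the first alphabet character, the padding is indistinguishable from a real character, which is why the true label length has to be passed to the CTC loss separately as label_length.

```python
import numpy as np


def one_hot(text, length, characters):
    """Same logic as the one_hot function above, with the alphabet passed explicitly."""
    label = np.zeros(length)
    for i, char in enumerate(text):
        index = characters.find(char)
        if index == -1:                    # character not in the alphabet
            index = characters.find(u' ')  # fall back to the space character
        label[i] = index
    return label


toy_alphabet = 'abc '                 # hypothetical 4-character alphabet
toy_nclass = len(toy_alphabet) + 1    # 4 characters + 1 CTC blank (index 4)
print(one_hot('cab', 5, toy_alphabet))  # [2. 0. 1. 0. 0.]
print(one_hot('ax', 4, toy_alphabet))   # [0. 3. 0. 0.]
```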

Load the data

trainroot = './data/'
# Read the LMDB datasets (note: the test set reads the same ./data/ directory as the training set)
train_dataset = dataset.lmdbDataset(root=trainroot, target_transform=one_hot)
test_dataset = dataset.lmdbDataset(
    root=trainroot,
    transform=dataset.resizeNormalize((256, 32)),
    target_transform=one_hot)
# Wrap the datasets in PyTorch DataLoaders
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=1,
    shuffle=True,
    sampler=None,
    num_workers=4,
    collate_fn=dataset.alignCollate(
        imgH=32, imgW=256,))
test_loader = torch.utils.data.DataLoader(
    test_dataset, shuffle=True, batch_size=1, num_workers=4)
nSamples:738

nSamples:738

Data generator, which feeds the PyTorch batches to the Keras training model

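def gen(loader, flag='train'):
    """Yield Keras training batches forever by cycling over a PyTorch DataLoader."""
    while True:  # start a new pass over the DataLoader after each full pass
        for X, Y in loader:
            # Images as (batch, 32, 256, 1), channels-last as the Keras model expects
            X = X.numpy().reshape((-1, 32, 256, 1))
            if flag == 'test':
                Y = Y.numpy()  # the test loader returns labels as a tensor
            Y = np.array(Y)
            # CTC input length: time steps left after ctc_lambda_func drops the first two
            # (256 / 4 - 1 = 63 for a 256-pixel-wide image)
            Length = int(256 / 4) - 1
            batchs = X.shape[0]

            yield [
                X, Y,
                np.ones(batchs) * int(Length),  # input_length
                np.ones(batchs) * int(len(Y))   # label_length
            ], np.ones(batchs)  # dummy targets: the model output already is the loss

# Load the pretrained weights, then fine-tune for 2 epochs of 100 steps each
modelPath = './model_crnn.h5'
model.load_weights(modelPath)
model.fit_generator(
    gen(train_loader, flag='train'),
    steps_per_epoch=100,
    epochs=2,
    validation_data=gen(test_loader, flag='test'),
    validation_steps=10)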
WARNING:tensorflow:From /home/ma-user/anaconda3/envs/TensorFlow-1.13.1/lib/python3.6/site-packages/tensorflow/python/ops/math_grad.py:102: div (from tensorflow.python.ops.math_ops) is deprecated and will be removed in a future version.

Instructions for updating:

Deprecated in favor of operator or tf.math.divide.

WARNING:tensorflow:Variable *= will be deprecated. Use `var.assign(var * other)` if you want assignment to the variable value or `x = x * y` if you want a new python Tensor object.


Epoch 1/2

100/100 [==============================] - 30s 297ms/step - loss: 42.6744 - val_loss: 46.7768

Epoch 2/2

100/100 [==============================] - 23s 230ms/step - loss: 29.9630 - val_loss: 44.9445





<keras.callbacks.History at 0x7fb2387f3c88>
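The training model takes four inputs and outputs the CTC loss itself, which is why gen yields a list of four arrays plus a dummy target of ones. The following sketch builds one such batch from random data, which can help check shapes without the LMDB files. The function name fake_ctc_batch, the max_label_len value and the random labels are assumptions for illustration only; nclass = 5531 is taken from the model summary, with the last index reserved for the CTC blank.

```python
import numpy as np


def fake_ctc_batch(batch_size=2, img_h=32, img_w=256, max_label_len=10, nclass=5531, seed=0):
    """Build one random batch in the same layout that gen() yields (hypothetical helper)."""
    rng = np.random.RandomState(seed)
    X = rng.rand(batch_size, img_h, img_w, 1).astype(np.float32)       # images
    label_length = rng.randint(1, max_label_len + 1, size=batch_size)   # true label lengths
    labels = np.zeros((batch_size, max_label_len), dtype=np.float32)    # zero-padded labels
    for b in range(batch_size):
        # Real character indices are 0 .. nclass - 2; nclass - 1 is the CTC blank
        labels[b, :label_length[b]] = rng.randint(0, nclass - 1, size=label_length[b])
    input_length = np.ones(batch_size) * (img_w // 4 - 1)               # 63 for width 256
    inputs = [X, labels, input_length, label_length.astype(np.float64)]
    targets = np.ones(batch_size)  # ignored: the compiled loss just returns the model output
    return inputs, targets


inputs, targets = fake_ctc_batch()
print([a.shape for a in inputs], targets.shape)
# [(2, 32, 256, 1), (2, 10), (2,), (2,)] (2,)
```

With the notebook's model in scope, a batch like this could be passed to model.train_on_batch(inputs, targets) as a smoke test before running fit_generator on the real data.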
model.save_weights('./CRNN_results.h5')

6. Testing

basemodel.load_weights('./CRNN_results.h5')  # load the weights just trained

7. Load the original image to be recognized

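from PIL import Image
from CRNN_model import decode  # local helper that turns the network output into text
img = Image.open('./img_0.png')
print('Original image to be recognized:')
img
Original image to be recognized: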

8. Define the CRNN character recognition function

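def crnn_ocr(img):
    """
    CRNN character recognition function
    :param img: PIL image to run character recognition on
    :return: ocr_result: the recognized text, as a string
    """
    img = img.convert('L')  # convert to grayscale

    scale = img.size[1] * 1.0 / 32  # resize so the height becomes 32, keeping the aspect ratio
    w = img.size[0] / scale
    w = int(w)
    img = img.resize((w, 32))
    img = np.array(img).astype(np.float32) / 255.0  # scale pixel values to [0, 1]
    X = img.reshape((32, w, 1))
    X = np.array([X])  # add the batch dimension: (1, 32, w, 1)
    y_pred = basemodel.predict(X)  # predict per-time-step class probabilities
    ocr_result = decode(y_pred)  # decode the predictions into text

    return ocr_result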
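The decode function comes from the local CRNN_model module, and its source is not shown here. For orientation, the following is a sketch of standard best-path (greedy) CTC decoding, which may differ from that implementation: take the most likely class at each time step, merge consecutive repeats, then drop blanks. It assumes the blank is the last class, index len(characters), consistent with nclass = len(characters) + 1 and the Keras/TensorFlow CTC convention; whether the real decode also skips the first two time steps, as ctc_lambda_func does during training, cannot be told from this notebook. greedy_ctc_decode is a hypothetical name.

```python
import numpy as np


def greedy_ctc_decode(y_pred, characters):
    """Best-path CTC decoding of the first sample in a basemodel output batch.

    y_pred: array of shape (batch, T, len(characters) + 1) with softmax outputs.
    The blank is assumed to be the last class index, len(characters).
    """
    blank = len(characters)
    best = np.argmax(y_pred[0], axis=-1)  # most likely class at each time step
    chars = []
    prev = None
    for k in best:
        # Emit a class only if it differs from the previous step and is not the blank
        if k != prev and k != blank:
            chars.append(characters[k])
        prev = k
    return ''.join(chars)


# Toy check with a hypothetical alphabet "ab" (blank index 2):
# best path a, a, blank, b, b collapses to "ab"
toy = np.eye(3)[[0, 0, 2, 1, 1]][np.newaxis]  # shape (1, 5, 3)
print(greedy_ctc_decode(toy, 'ab'))           # ab
```

Greedy decoding is cheap and usually adequate for CRNN-style models; beam search over the same probabilities can find a more likely label sequence at a higher cost.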

9. Call the function to get the character recognition result

ocr_result = crnn_ocr(img)
print('Recognition result:', ocr_result)
Recognition result: 文字识别模型文字识别。
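[Copyright notice] This article is original content by a Huawei Cloud community user. When reposting, you must credit the source (Huawei Cloud community), the article link, the author and other basic information; otherwise the author and the community reserve the right to pursue liability. If you find suspected plagiarism in this community, please report it by email with supporting evidence; once verified, the community will promptly remove the infringing content. Report email: cloudbbs@huaweicloud.com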