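Guangdong-Hong Kong-Macao Greater Bay Area (Huangpu) International Algorithm Case Competition: Introduction to the Ancient Document Image Recognition and Analysis Contest
I. Guangdong-Hong Kong-Macao Greater Bay Area (Huangpu) International Algorithm Case Competition: Introduction to the Ancient Document Image Recognition and Analysis Contest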
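1. Background and Significance
- China's several-thousand-year civilization has left behind a vast body of ancient books and documents, and these written records carry rich historical information and cultural heritage. To meet the national strategic need to protect ancient books as cultural heritage and to digitize, promote, and apply them, to pass on China's outstanding traditional culture, and to mine the knowledge these documents contain, thorough digitization of ancient books is imperative.
- Ancient document images pose many technical challenges: complex layouts, large differences in carving and handwriting styles across dynasties, missing characters, stains, ink contamination, blur, interference from seal stamps, and a great many rare and variant characters. Recognizing and understanding ancient document images therefore remains a highly challenging and far-from-solved problem.
- To address the digitization of China's vast collection of ancient books, this competition solicits advanced AI algorithms for high-accuracy ancient text detection, text-line recognition, and end-to-end ancient document recognition, with the aim of advancing ancient-book OCR and providing AI methods for the digital preservation, collation, and use of ancient books.
Figure 1: Example of an ancient document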
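2. Task Description
Task: ancient document image analysis and recognition
Input: page-level ancient document images
Output: using physical and logical layout analysis, text detection, text recognition, and reading-order understanding, output structured text-line coordinates together with their recognized content, with the detection result and transcription of every text line listed in reading order. The model outputs detection and recognition results for the main body text only; non-structured content such as the center strip (版心) and volume numbers is ignored.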
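For vertical, right-to-left pages like Figure 1, a crude baseline for reading order is to sort text-line boxes by their horizontal center, rightmost column first. The sketch below is an assumption, not the competition's definition of reading order: it ignores multi-column layouts, headings, and interlinear notes. For example, in image_0.jpg the line "弗吾止𬼘室十" is followed by a small box "有" sitting slightly to its right and then by "二年…", so sorting purely by x would put "有" first. `box_center` and `sort_vertical_rtl` are hypothetical helpers.

```python
from typing import List, Sequence, Tuple


def box_center(points: Sequence[float]) -> Tuple[float, float]:
    """Center of a flat polygon [x1, y1, x2, y2, ..., xn, yn]."""
    xs, ys = points[0::2], points[1::2]
    return sum(xs) / len(xs), sum(ys) / len(ys)


def sort_vertical_rtl(boxes: List[Sequence[float]]) -> List[Sequence[float]]:
    """Naive reading order for vertical text read right to left.

    Boxes are ordered by center x, rightmost first; boxes with the same
    center x are ordered top to bottom. A baseline heuristic only.
    """
    def key(points: Sequence[float]) -> Tuple[float, float]:
        cx, cy = box_center(points)
        return -cx, cy

    return sorted(boxes, key=key)


# Two text lines from image_0.jpg, given here in the wrong order
boxes = [
    [1249, 57, 1286, 59, 1298, 851, 1251, 851],
    [1286, 59, 1326, 59, 1331, 851, 1290, 851],
]
print(sort_vertical_rtl(boxes)[0][:2])  # [1286, 59]
```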
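Character dictionary (码表):
The competition provides a character dictionary (download: https://pan.baidu.com/s/16wUeSZ4JKD6f1Pj9ZhlKww, extraction code: i53n) that contains the character classes appearing in the preliminary-round training set, validation set (preliminary leaderboard A), and test set (preliminary leaderboard B). **(Note: because the competition includes a zero-shot recognition setting, the character classes in the training set do not fully cover the dictionary. The dictionary published so far fully covers every character in the preliminary training set and the leaderboard A test set; the leaderboard B dictionary may be adjusted slightly and will be published later, so watch the official competition website for announcements.)**
Leaderboard B (preliminary round) dictionary released:
Download link: https://pan.baidu.com/s/1gaNlKHk6lh5FxC2QP4UuDg
Extraction code: umzz (published September 8, 2022)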
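Because of the zero-shot setting, it can be useful to measure how the dictionary and the training labels overlap. The sketch below assumes the dictionary file uses PaddleOCR's `character_dict_path` format (one character per line) and the label.json layout described in the next section; the helper names are hypothetical. Note that the "#" used for illegible characters will be reported as missing unless the dictionary contains it.

```python
import json
from typing import Dict, Iterable, List, Set


def charset_from_lines(lines: Iterable[str]) -> Set[str]:
    """One character per line (assumed PaddleOCR dictionary format)."""
    return {line.rstrip("\r\n") for line in lines if line.rstrip("\r\n")}


def chars_in_labels(labels: Dict[str, List[dict]]) -> Set[str]:
    """Every character used in the transcriptions of a label.json dict."""
    chars: Set[str] = set()
    for entries in labels.values():
        for entry in entries:
            chars.update(entry["transcription"])
    return chars


def coverage_report(dict_chars: Set[str], label_chars: Set[str]) -> Dict[str, object]:
    return {
        # Characters in the labels that the dictionary cannot represent
        "missing_from_dict": sorted(label_chars - dict_chars),
        # Dictionary entries never seen in the labels (zero-shot classes)
        "unseen_in_labels": len(dict_chars - label_chars),
    }


def check_files(dict_path: str, label_path: str) -> Dict[str, object]:
    """Load both files from disk and build the coverage report."""
    with open(dict_path, encoding="utf-8") as f:
        dict_chars = charset_from_lines(f)
    with open(label_path, encoding="utf-8") as f:
        label_chars = chars_in_labels(json.load(f))
    return coverage_report(dict_chars, label_chars)
```

Usage might look like `check_files("/home/aistudio/mb.txt", "dataset/train/label.json")`, assuming the dictionary was saved as mb.txt in the home directory, which is what the `../mb.txt` paths in the training configs below imply.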
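3. Dataset Description
- **Preliminary round dataset:** the training, validation, and test sets each contain 1,000 ancient document images (3,000 images in total), drawn from sources such as the Siku Quanshu (四库全书), rare editions of ancient books from successive dynasties, and the Qianlong Tripitaka (乾隆大藏经). The task considers only the main body text of each document and ignores content outside the frame, such as the center strip and volume numbers.
- **Final round dataset:** the final uses an **arena (擂台赛)** format. Besides the original preliminary dataset and the final-round data provided by the organizers, finalist teams may apply to become arena hosts (擂主) and contribute their own datasets for the other finalists to train and test on. A contributed training set must contain at least 1,000 images and a test set at most 1,000 images, annotated in the same format as the organizers' data.
Annotation format:
The text lines of each image and their content are annotated in reading order, and all annotations are stored in a single JSON file with the following format: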
{
  "image_name_1": [{"points": [x1, y1, x2, y2, …, xn, yn], "transcription": text},
                   {"points": [x1, y1, x2, y2, …, xn, yn], "transcription": text},
                   …],
  "image_name_2": [{"points": [x1, y1, x2, y2, …, xn, yn], "transcription": text},
                   {"points": [x1, y1, x2, y2, …, xn, yn], "transcription": text},
                   …],
  ……
}
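- x1, y1, x2, y2, …, xn, yn are the vertices of the text box.
- For quadrilateral text, n = 4. The dataset also contains a small amount of irregular text, which is annotated with n = 16 (8 points along each long side).
- transcription is the content of each text line; characters too blurred to recognize are labeled #.
- The detection and recognition labels of the text lines are given in the correct reading order. The end-to-end recognition content is annotated in reading order and covers only the main body text, ignoring content outside the frame such as the center strip and volume numbers.
- The reading-order arrangement is shown in Figure 2.
Figure 2: Visualization of reading-order annotation for end-to-end structured recognition and understanding of ancient document images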
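The bullet points above translate into a simple sanity check on label.json. The sketch below turns a flat coordinate list into point pairs and flags entries whose point count is not n = 4 or n = 16. The allowed counts come from the description above; note that the conversion code later in this notebook also accepts 17-point polygons (34 values), so the data may not follow the rule strictly. The helper names are hypothetical.

```python
from typing import Dict, List, Sequence, Tuple


def to_point_pairs(flat: Sequence[int]) -> List[List[int]]:
    """[x1, y1, ..., xn, yn] -> [[x1, y1], ..., [xn, yn]]."""
    if len(flat) % 2:
        raise ValueError(f"odd number of coordinates: {len(flat)}")
    return [[flat[i], flat[i + 1]] for i in range(0, len(flat), 2)]


def check_entry(entry: dict, allowed_n: Tuple[int, ...] = (4, 16)) -> List[str]:
    """List the problems in one {"points": [...], "transcription": ...} entry."""
    problems = []
    points = entry.get("points", [])
    if len(points) % 2 or len(points) // 2 not in allowed_n:
        problems.append(f"unexpected number of coordinates: {len(points)}")
    if not isinstance(entry.get("transcription"), str):
        problems.append("transcription is missing or not a string")
    return problems


def check_labels(labels: Dict[str, List[dict]]) -> Dict[str, list]:
    """Map image name -> [(line index, problems), ...] for the bad entries."""
    report: Dict[str, list] = {}
    for name, entries in labels.items():
        for index, entry in enumerate(entries):
            problems = check_entry(entry)
            if problems:
                report.setdefault(name, []).append((index, problems))
    return report
```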
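4. Submission
[Preliminary round, leaderboard A]:
- **Submission format:** a zip archive of CSV files, each named after its test image
- Submission content: one CSV file per image, containing the detection-box coordinates of each text line and its recognition result, with the lines ordered by the predicted reading order.
Each CSV file is formatted as follows: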
x1, y1, x2, y2, x3, y3,…, xn, yn, transcription_1
x1, y1, x2, y2, x3, y3,…, xn, yn, transcription_2
…
x1, y1, x2, y2, x3, y3,…, xn, yn, transcription_n
(where each xn, yn pair is a coordinate, the points are listed in clockwise order, and transcription_n is the recognized text)
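The sketch below formats one image's results in that layout with Python's csv module and zips a directory of them. Several details are assumptions not stated in the rules: no header row, UTF-8 encoding, "\n" line endings, files named after the image stem (image_0.jpg becomes image_0.csv), and csv-style quoting for any transcription containing a comma. Check them against the sample submission linked below. The function names are hypothetical.

```python
import csv
import io
import os
import shutil
from typing import Iterable, Sequence, Tuple

Line = Tuple[Sequence[int], str]  # (flat clockwise coordinates, transcription)


def format_submission(lines: Iterable[Line]) -> str:
    """CSV text for one image: one row per text line, already in reading order.

    The csv module only quotes a transcription if it contains a comma or quote.
    """
    buf = io.StringIO()
    writer = csv.writer(buf, lineterminator="\n")
    for points, text in lines:
        writer.writerow([*points, text])
    return buf.getvalue()


def write_submission(out_dir: str, image_name: str, lines: Iterable[Line]) -> str:
    """Write <image stem>.csv into out_dir and return its path."""
    os.makedirs(out_dir, exist_ok=True)
    path = os.path.join(out_dir, os.path.splitext(image_name)[0] + ".csv")
    with open(path, "w", encoding="utf-8", newline="") as f:
        f.write(format_submission(lines))
    return path


def zip_submission(out_dir: str, archive_base: str) -> str:
    """Zip the contents of out_dir into <archive_base>.zip; returns the archive path."""
    return shutil.make_archive(archive_base, "zip", root_dir=out_dir)
```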
- Sample submission:
Link: https://pan.baidu.com/s/1h9smrGBwfJ78IP3WUlkEYQ
Extraction code: suzi
- Submission limit: once per day
- Submissions open: September 15
II. Dataset Processing
1. Unzip the dataset
# !unzip -qoa data/data167941/dataset.zip
2. Inspect the data
!head -n30 dataset/train/label.json
{
"image_0.jpg": [
{
"points": [
1286,
59,
1326,
59,
1331,
851,
1290,
851
],
"transcription": "\u53ef\ud878\udcce\u4e45\u4e4e\u820e\u5229\u5f17\u563f\u7136\u4e0d\u8345\u25cf\u4e94\u8eab\u5b50\u81ea\u601d\u89e7\u8131\u7121\u4e45\u8fd1\u6545\u9ed9\u5929\u66f0\u5982\u4f55\ud859\udcbf\ud85b\udf94\u5927\u667a"
},
{
"points": [
1249,
57,
1286,
59,
1298,
851,
1251,
851
],
"transcription": "\u800c\u563f\u25cb\u516d\u5929\u554f\ud86e\udc26\u4ee5\u8087\u66f0\u4e94\u767e\u82d0\u5b50\u4ec1\u8005\u4f55\u667a\u6075\u82d0\u4e00\u563f\u7136\u4f55\u8036\u8345\u66f0\u89e7\u8131\u8005\u65e0\ud86e\udc26\u8a00\u8aaa"
},
{
"points": [
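3. Data format conversion
For the PaddleOCR detection task, each line of the label file has the following format:
" image file name\tjson.dumps-encoded image annotations"
ch4_test_images/img_61.jpg [{"transcription": "MASA", "points": [[310, 104], [416, 141], [418, 216], [312, 179]], ...}]
The competition labels therefore need to be converted into this format.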
import json

# Read the original competition label.json
with open('dataset/train/label.json', 'r', encoding='utf-8') as f:
    y = json.load(f)

# Number of images (1000)
print(len(y))
# Annotations of one image
print(y["image_0.jpg"])
# Number of text lines in that image
print(len(y["image_0.jpg"]))
1000
[{'points': [1286, 59, 1326, 59, 1331, 851, 1290, 851], 'transcription': '可𮃎久乎舎利弗嘿然不荅●五身子自思觧脱無久近故黙天曰如何𦒿𦾔大智'}, {'points': [1249, 57, 1286, 59, 1298, 851, 1251, 851], 'transcription': '而嘿○六天問𫠦以肇曰五百苐子仁者何智恵苐一嘿然何耶荅曰觧脱者无𫠦言說'}, {'points': [1213, 60, 1252, 60, 1252, 784, 1213, 784], 'transcription': '故吾扵是不知𫠦云○七身子已離三𭻃惑得心觧脱永绝言𢿘故言不知𫠦云'}, {'points': [1173, 62, 1214, 62, 1224, 845, 1183, 845], 'transcription': '天曰言説文字皆觧脱𢪷●八天辨不思議觧脱即文字也文三𬼘𥘉摽文字即觧脱'}, {'points': [1135, 61, 1179, 61, 1184, 848, 1140, 848], 'transcription': '肇曰舎利弗以言文為失故黙然无言謂順真觧未䏻悟黙齊𦤺觸𢪱无礙故'}, {'points': [1099, 59, 1143, 59, 1149, 848, 1106, 848], 'transcription': '天說䒭觧以暁其意𫠦以者何觧脱者不内不外不在兩間文字亦不内不外不在兩'}, {'points': [1069, 61, 1111, 61, 1110, 854, 1065, 852], 'transcription': '間是故舎利弗无離文字説觧脱也𬼘二觧𥼶𫠦以肇曰法之𫠦在極扵三⺀𠁅⺀求文字'}, {'points': [1022, 61, 1066, 61, 1066, 851, 1022, 851], 'transcription': '觧脱俱不可淂如何欲離文字别説觧脱乎𫠦以者何一𭃄諸法是觧脱相○三明'}, {'points': [984, 60, 1025, 60, 1021, 850, 980, 850], 'transcription': '諸法䒭觧肇曰万法雖殊无非觧𢪷豈文字之獨異也舎利弗言不復以離媱怒'}, {'points': [946, 60, 985, 60, 978, 850, 938, 850], 'transcription': '𪪧為觧脱乎○𬼘下二明𣂾不𣂾别文二𥘉問也肇曰二乘结𥁞為觧脱聞上䒭觧乖'}, {'points': [905, 59, 951, 59, 942, 850, 895, 850], 'transcription': '其夲趣故𦤺斯問天日仏為増上𢢔人説離媱怒癡為觧脱耳𠰥无上𢢔者佛説'}, {'points': [860, 63, 909, 63, 902, 849, 852, 849], 'transcription': '媱怒癡性即是觧脱二荅也増上𢢔者未淂謂淂也身子𢴃小乘𫠦證非増上𢢔'}, {'points': [822, 62, 865, 62, 862, 850, 819, 850], 'transcription': '自謂共佛同㘴觧脱床名増上𢢔也既未悟缚解平䒭故為説離缚為觧𠰥大士'}, {'points': [779, 63, 822, 63, 822, 848, 779, 848], 'transcription': '非増上𢢔者為説即縛性脱性脱入不二門也舎利弗言善⺀㦲⺀天女汝何𫠦淂以何'}, {'points': [735, 62, 782, 62, 781, 846, 734, 846], 'transcription': '為證辨乃如是○三明證不證別文二𬼘𥘉也肇曰善其𫠦説非已𫠦及故問淂何道證'}, {'points': [693, 60, 736, 60, 745, 848, 703, 848], 'transcription': '阿果辨乃如是乎天曰我无淂无證故辨如是○荅文二𬼘𥘉正荅二乘捨缚求脱'}, {'points': [650, 62, 696, 62, 709, 852, 662, 852], 'transcription': '故有淂證大士悟縛脱平䒭非縛非脱故无淂无證既智窮不二之門故辨無礙'}, {'points': [619, 61, 658, 61, 664, 850, 626, 850], 'transcription': '也𫠦以者何𠰥有淂有證者則扵仏法為増上𢢔○二反厈肇曰𠰥見己有淂必見他'}, {'points': [576, 62, 617, 62, 631, 850, 591, 850], 'transcription': 
'不淂𬼘扵佛平䒭之法猶為増上𢢔人何䏻𦤺无礙之辨乎舎利弗問天汝扵三'}, {'points': [539, 63, 579, 63, 588, 845, 548, 845], 'transcription': '乘為何𢖽求𬼘下三约教明乘无乘別也小乘有法執故有差别乘大乘不二平'}, {'points': [497, 63, 539, 63, 550, 849, 508, 849], 'transcription': '䒭故无乘之乘文二𬼘𥘉問也肇曰上云无淂无證未知何乘故𣸪問也天曰以聲'}, {'points': [459, 63, 502, 63, 509, 853, 467, 853], 'transcription': '聞法化衆生故我為聲聞以因𦄘法化衆生故我為𮝻支仏以大悲法化衆生故我'}, {'points': [422, 65, 462, 65, 466, 851, 426, 851], 'transcription': '為大乘○二荅文二一惣约化𦄘荅二别约𫝆𦄘荅𬼘𥘉也肇曰大乘之道无乘之乘'}, {'points': [379, 65, 423, 65, 430, 827, 386, 827], 'transcription': '爲彼而乘吾何乘也生曰随彼為之我无㝎也又觧法花方便説三意同𬼘也'}, {'points': [342, 65, 382, 65, 396, 851, 356, 851], 'transcription': '舎利弗如人入瞻蔔林唯嗅瞻蔔不嗅餘香如是𠰥入𬼘室但聞仏㓛徳之香不樂'}, {'points': [300, 67, 343, 67, 360, 849, 318, 849], 'transcription': '聞聲聞𮝻支仏㓛徳香也○𬼘二约𫝆𦄘文四一明𫝆𦄘唯一二𫠦化樂大三室无小法四'}, {'points': [263, 64, 302, 64, 323, 849, 284, 849], 'transcription': '约室顕法𬼘𥘉也肇曰元乘不乘乃為大乘故以香林為喻明浄名之室不離二'}, {'points': [226, 64, 268, 64, 286, 849, 243, 849], 'transcription': '乘之香止乘止𬼘室者豈他嗅㢤舎利弗有其四𥼶梵四天王諸天龍神鬼√䒭入'}, {'points': [186, 63, 229, 63, 248, 855, 205, 855], 'transcription': '𬼘室者聞斯上人講说正法𣅜樂佛㓛徳之香𤼲心而出二明𫠦化皆樂大也舎利'}, {'points': [158, 65, 193, 63, 191, 204, 159, 207], 'transcription': '弗吾止𬼘室十'}, {'points': [183, 198, 200, 197, 200, 222, 183, 222], 'transcription': '有'}, {'points': [161, 207, 191, 205, 204, 856, 167, 859], 'transcription': '二年𥘉不聞説聲聞𮝻支仏法但聞菩薩大慈大悲不可思議諸'}, {'points': [121, 62, 169, 62, 172, 855, 125, 855], 'transcription': '佛之法三明深肇曰大乘之法𣅜不可思議上問止室久近欲生淪端故答'}, {'points': [80, 63, 122, 63, 131, 853, 90, 853], 'transcription': '以觧脱𫝆言實𭘾以明𫠦聞之不𮦀也生曰諸天鬼神蹔入室尚无不𤼲大意而出'}, {'points': [44, 62, 84, 62, 100, 849, 60, 849], 'transcription': '㦲况我久聞妙法乎然則不䏻不為大悲䏻為大矣舎利弗𬼘室常現八未曽有'}, {'points': [2, 60, 45, 60, 62, 848, 19, 848], 'transcription': '難淂之法𬼘四明未曽有室不说二乘之法也文三標𥼶结𬼘𥘉標也何謂為八'}]
36
# Convert to PaddleOCR's detection label format
image_info_lists = {}
with open("dataset/train/label.txt", 'w', encoding='utf-8') as ff:
    for i in range(1000):
        name = f"image_{i}.jpg"
        new_info = []
        for item in y[name]:
            points = item["points"]
            # Keep quadrilaterals (8 values) and irregular polygons (32 or 34 values);
            # entries of any other length are skipped, as in the original code
            if len(points) not in (8, 32, 34):
                continue
            new_info.append({
                "transcription": item["transcription"],
                # [x1, y1, x2, y2, ...] -> [[x1, y1], [x2, y2], ...]
                "points": [[points[k], points[k + 1]] for k in range(0, len(points), 2)],
            })
        image_info_lists[name] = new_info
        ff.write(name + "\t" + json.dumps(new_info) + "\n")

# Inspect the converted data
print(image_info_lists["image_0.jpg"][0])
{'transcription': '可𮃎久乎舎利弗嘿然不荅●五身子自思觧脱無久近故黙天曰如何𦒿𦾔大智', 'points': [[1286, 59], [1326, 59], [1331, 851], [1290, 851]]}
!head -n1 dataset/train/label.txt
4. Split the dataset
The first 800 images form the training set and the last 200 the evaluation set.
%cd ~
!wc -l dataset/train/label.txt
/home/aistudio
1000 dataset/train/label.txt
!head -800 dataset/train/label.txt >dataset/train/train.txt
!tail -200 dataset/train/label.txt >dataset/train/eval.txt
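The same split can be done in Python, which avoids depending on shell tools and makes an optional seeded shuffle easy. `split_lines` and `split_label_file` are hypothetical helpers: with `seed=None` they reproduce the head/tail split above, while passing a seed gives a shuffled split, which is not what the original notebook does.

```python
import random
from typing import List, Optional, Tuple


def split_lines(lines: List[str], n_train: int,
                seed: Optional[int] = None) -> Tuple[List[str], List[str]]:
    """First n_train lines for training, the rest for evaluation.

    If a seed is given, the lines are shuffled deterministically first.
    """
    lines = list(lines)
    if seed is not None:
        random.Random(seed).shuffle(lines)
    return lines[:n_train], lines[n_train:]


def split_label_file(src: str, train_out: str, eval_out: str,
                     n_train: int = 800, seed: Optional[int] = None) -> None:
    """Split a PaddleOCR label file into train/eval label files."""
    with open(src, encoding="utf-8") as f:
        lines = f.readlines()
    train, evaluation = split_lines(lines, n_train, seed)
    for path, chunk in ((train_out, train), (eval_out, evaluation)):
        with open(path, "w", encoding="utf-8") as f:
            f.writelines(chunk)
```

For example, `split_label_file("dataset/train/label.txt", "dataset/train/train.txt", "dataset/train/eval.txt")` should produce the same files as the two shell commands above.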
III. PaddleOCR Environment Setup
1. Download PaddleOCR
# !git clone https://gitee.com/paddlepaddle/PaddleOCR.git --depth=1
2. Install PaddleOCR
%cd ~/PaddleOCR/
!python -m pip install -q -U pip --user
!pip install -q -r requirements.txt
/home/aistudio/PaddleOCR
# !mkdir pretrain_models/
# %cd pretrain_models
# !wget https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/ch_PP-OCRv3_det_distill_train.tar
# !tar -xvf ch_PP-OCRv3_det_distill_train.tar
IV. Detection Model Training
!pip list|grep opencv
opencv-contrib-python 4.6.0.66
opencv-python 4.2.0.32
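1. Downgrade OpenCV
Two OpenCV packages are installed side by side, opencv-python 4.2.0.32 and opencv-contrib-python 4.6.0.66, and both provide the cv2 module. With this mix, training fails with the error below, so uninstall both and reinstall opencv-python 4.2.0.32 on its own.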
Traceback (most recent call last):
File "tools/train.py", line 30, in <module>
from ppocr.data import build_dataloader
File "/home/aistudio/PaddleOCR/ppocr/data/__init__.py", line 35, in <module>
from ppocr.data.imaug import transform, create_operators
File "/home/aistudio/PaddleOCR/ppocr/data/imaug/__init__.py", line 19, in <module>
from .iaa_augment import IaaAugment
File "/home/aistudio/PaddleOCR/ppocr/data/imaug/iaa_augment.py", line 24, in <module>
import imgaug
File "/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/imgaug/__init__.py", line 7, in <module>
from imgaug.imgaug import * # pylint: disable=redefined-builtin
File "/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/imgaug/imgaug.py", line 18, in <module>
import cv2
File "/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/cv2/__init__.py", line 181, in <module>
bootstrap()
File "/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/cv2/__init__.py", line 175, in bootstrap
if __load_extra_py_code_for_module("cv2", submodule, DEBUG):
File "/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/cv2/__init__.py", line 28, in __load_extra_py_code_for_module
py_module = importlib.import_module(module_name)
File "/opt/conda/envs/python35-paddle120-env/lib/python3.7/importlib/__init__.py", line 127, in import_module
return _bootstrap._gcd_import(name[level:], package, level)
File "/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/cv2/mat_wrapper/__init__.py", line 33, in <module>
cv._registerMatType(Mat)
AttributeError: module 'cv2' has no attribute '_registerMatType'
!pip uninstall opencv-python -y
!pip uninstall opencv-contrib-python -y
!pip install opencv-python==4.2.0.32
Found existing installation: opencv-python 4.2.0.32
Uninstalling opencv-python-4.2.0.32:
Successfully uninstalled opencv-python-4.2.0.32
Found existing installation: opencv-contrib-python 4.6.0.66
Uninstalling opencv-contrib-python-4.6.0.66:
Successfully uninstalled opencv-contrib-python-4.6.0.66
Looking in indexes: https://pypi.tuna.tsinghua.edu.cn/simple
Collecting opencv-python==4.2.0.32
Using cached https://pypi.tuna.tsinghua.edu.cn/packages/34/a3/403dbaef909fee9f9f6a8eaff51d44085a14e5bb1a1ff7257117d744986a/opencv_python-4.2.0.32-cp37-cp37m-manylinux1_x86_64.whl (28.2 MB)
Requirement already satisfied: numpy>=1.14.5 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from opencv-python==4.2.0.32) (1.19.5)
Installing collected packages: opencv-python
Successfully installed opencv-python-4.2.0.32
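To confirm that only one OpenCV distribution remains, it can help to list every installed distribution whose name starts with "opencv". The sketch uses importlib.metadata, which exists from Python 3.8; the Python 3.7 environment in the traceback above would need the importlib-metadata backport. `opencv_distributions` is a hypothetical helper.

```python
from typing import Dict

try:
    from importlib import metadata  # Python 3.8+
except ImportError:  # Python 3.7: pip install importlib-metadata
    import importlib_metadata as metadata


def opencv_distributions() -> Dict[str, str]:
    """Map every installed opencv* distribution name to its version."""
    found = {}
    for dist in metadata.distributions():
        meta = dist.metadata
        name = (meta.get("Name") if meta is not None else None) or ""
        if name.lower().startswith("opencv"):
            found[name] = dist.version
    return found


if __name__ == "__main__":
    dists = opencv_distributions()
    print(dists)
    if len(dists) > 1:
        print("Several OpenCV packages provide cv2; keep only one of them.")
```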
2. Training configuration
ch_PP-OCRv3_det_cml.yml
Global:
  character_dict_path: ../mb.txt # custom dictionary
  debug: false
  use_gpu: true
  epoch_num: 500
  log_smooth_window: 20
  print_batch_step: 10
  save_model_dir: ./output/ch_PP-OCR_v3_det/
  save_epoch_step: 100
  eval_batch_step:
  - 0
  - 400
  cal_metric_during_train: false
  # null: no weights are loaded, so the frozen Teacher below (freeze_params: true)
  # appears to stay randomly initialized; the commented-out
  # ch_PP-OCRv3_det_distill_train download in section III is one source of weights
  pretrained_model: null
  checkpoints: null
  save_inference_dir: null
  use_visualdl: false
  infer_img: doc/imgs_en/img_10.jpg
  save_res_path: ./checkpoints/det_db/predicts_db.txt
  distributed: true
Architecture:
  name: DistillationModel
  algorithm: Distillation
  model_type: det
  Models:
    Student:
      pretrained:
      model_type: det
      algorithm: DB
      Transform: null
      Backbone:
        name: MobileNetV3
        scale: 0.5
        model_name: large
        disable_se: true
      Neck:
        name: RSEFPN
        out_channels: 96
        shortcut: True
      Head:
        name: DBHead
        k: 50
    Student2:
      pretrained:
      model_type: det
      algorithm: DB
      Transform: null
      Backbone:
        name: MobileNetV3
        scale: 0.5
        model_name: large
        disable_se: true
      Neck:
        name: RSEFPN
        out_channels: 96
        shortcut: True
      Head:
        name: DBHead
        k: 50
    Teacher:
      pretrained:
      freeze_params: true
      return_all_feats: false
      model_type: det
      algorithm: DB
      Backbone:
        name: ResNet_vd
        in_channels: 3
        layers: 50
      Neck:
        name: LKPAN
        out_channels: 256
      Head:
        name: DBHead
        kernel_list: [7,2,2]
        k: 50
Loss:
  name: CombinedLoss
  loss_config_list:
  - DistillationDilaDBLoss:
      weight: 1.0
      model_name_pairs:
      - ["Student", "Teacher"]
      - ["Student2", "Teacher"]
      key: maps
      balance_loss: true
      main_loss_type: DiceLoss
      alpha: 5
      beta: 10
      ohem_ratio: 3
  - DistillationDMLLoss:
      model_name_pairs:
      - ["Student", "Student2"]
      maps_name: "thrink_maps"
      weight: 1.0
      model_name_pairs: ["Student", "Student2"]
      key: maps
  - DistillationDBLoss:
      weight: 1.0
      model_name_list: ["Student", "Student2"]
      balance_loss: true
      main_loss_type: DiceLoss
      alpha: 5
      beta: 10
      ohem_ratio: 3
Optimizer:
  name: Adam
  beta1: 0.9
  beta2: 0.999
  lr:
    name: Cosine
    learning_rate: 0.001
    warmup_epoch: 2
  regularizer:
    name: L2
    factor: 5.0e-05
PostProcess:
  name: DistillationDBPostProcess
  model_name: ["Student"]
  key: head_out
  thresh: 0.3
  box_thresh: 0.6
  max_candidates: 1000
  unclip_ratio: 1.5
Metric:
  name: DistillationMetric
  base_metric_name: DetMetric
  main_indicator: hmean
  key: "Student"
# dataset
Train:
  dataset:
    name: SimpleDataSet
    data_dir: /home/aistudio/dataset/train/image
    label_file_list:
    - /home/aistudio/dataset/train/train.txt
    ratio_list: [1.0]
    transforms:
    - DecodeImage:
        img_mode: BGR
        channel_first: false
    - DetLabelEncode: null
    - CopyPaste:
    - IaaAugment:
        augmenter_args:
        - type: Fliplr
          args:
            p: 0.5
        - type: Affine
          args:
            rotate:
            - -10
            - 10
        - type: Resize
          args:
            size:
            - 0.5
            - 3
    - EastRandomCropData:
        size:
        - 960
        - 960
        max_tries: 50
        keep_ratio: true
    - MakeBorderMap:
        shrink_ratio: 0.4
        thresh_min: 0.3
        thresh_max: 0.7
    - MakeShrinkMap:
        shrink_ratio: 0.4
        min_text_size: 8
    - NormalizeImage:
        scale: 1./255.
        mean:
        - 0.485
        - 0.456
        - 0.406
        std:
        - 0.229
        - 0.224
        - 0.225
        order: hwc
    - ToCHWImage: null
    - KeepKeys:
        keep_keys:
        - image
        - threshold_map
        - threshold_mask
        - shrink_map
        - shrink_mask
  loader:
    shuffle: true
    drop_last: false
    batch_size_per_card: 12
    num_workers: 4
# dataset
Eval:
  dataset:
    name: SimpleDataSet
    data_dir: /home/aistudio/dataset/train/image
    label_file_list:
    - /home/aistudio/dataset/train/eval.txt
    transforms:
    - DecodeImage: # load image
        img_mode: BGR
        channel_first: False
    - DetLabelEncode: # Class handling label
    - DetResizeForTest:
    - NormalizeImage:
        scale: 1./255.
        mean: [0.485, 0.456, 0.406]
        std: [0.229, 0.224, 0.225]
        order: 'hwc'
    - ToCHWImage:
    - KeepKeys:
        keep_keys: ['image', 'shape', 'polys', 'ignore_tags']
  loader:
    shuffle: False
    drop_last: False
    batch_size_per_card: 1 # must be 1
    num_workers: 2
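Two settings in this config, `shrink_ratio: 0.4` (MakeShrinkMap/MakeBorderMap) and `unclip_ratio: 1.5` (PostProcess), both control a polygon offset. Following the DB paper, the shrink offset is D = A(1 - r²)/L and the unclip offset applied to a detected region is D' = A'·r'/L', where A is the area and L the perimeter. The sketch computes both for an idealized 40 x 800 px column, roughly the size of the text columns in image_0.jpg (an assumption chosen for illustration). It uses plain Python geometry instead of the shapely/pyclipper calls PaddleOCR uses.

```python
import math
from typing import List, Tuple

Point = Tuple[float, float]


def polygon_area_perimeter(points: List[Point]) -> Tuple[float, float]:
    """Shoelace area and perimeter of a simple polygon [(x, y), ...]."""
    area2 = 0.0
    perimeter = 0.0
    n = len(points)
    for i in range(n):
        x1, y1 = points[i]
        x2, y2 = points[(i + 1) % n]
        area2 += x1 * y2 - x2 * y1
        perimeter += math.hypot(x2 - x1, y2 - y1)
    return abs(area2) / 2.0, perimeter


def shrink_offset(points: List[Point], shrink_ratio: float = 0.4) -> float:
    """Per-side offset used to shrink a ground-truth polygon: A(1 - r^2)/L."""
    area, perimeter = polygon_area_perimeter(points)
    return area * (1 - shrink_ratio ** 2) / perimeter


def unclip_offset(points: List[Point], unclip_ratio: float = 1.5) -> float:
    """Per-side offset used to expand a detected region: A' * r' / L'."""
    area, perimeter = polygon_area_perimeter(points)
    return area * unclip_ratio / perimeter


column = [(0, 0), (40, 0), (40, 800), (0, 800)]
d = shrink_offset(column)                        # about 16 px per side
core = [(0, 0), (40 - 2 * d, 0), (40 - 2 * d, 800 - 2 * d), (0, 800 - 2 * d)]
print(round(d, 2), round(unclip_offset(core), 2))
```

For this column the shrink offset is 16 px per side, leaving an 8 px wide core, and unclipping that core with r' = 1.5 adds only about 5.9 px per side. Real predictions come from a soft probability map and will not match these idealized numbers exactly, but the calculation is one way to reason about `unclip_ratio` for long, narrow columns.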
# Copy the config into the matching directory
!cp ~/ch_PP-OCRv3_det_cml.yml ~/PaddleOCR/configs/det/ch_PP-OCRv3/
%env CUDA_VISIBLE_DEVICES=0,1,2,3
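# Note: this config has no Optimizer.base_lr key (the learning rate is Optimizer.lr.learning_rate),
# so the Optimizer.base_lr=0.0001 override appears to have no effect: the lr of 0.000869 logged at
# epoch 121 below fits the 0.001 cosine schedule. Use -o Optimizer.lr.learning_rate=0.0001 to change it.
# !python tools/train.py -c configs/det/ch_PP-OCRv3/ch_PP-OCRv3_det_cml.yml -o Optimizer.base_lr=0.0001
%cd ~/PaddleOCR/
!python3 -m paddle.distributed.launch --ips="localhost" --gpus '0,1,2,3' tools/train.py -c configs/det/ch_PP-OCRv3/ch_PP-OCRv3_det_cml.yml -o Optimizer.base_lr=0.0001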
[2022/09/13 20:07:57] ppocr INFO: epoch: [121/500], global_step: 2050, lr: 0.000869, dila_dbloss_Student_Teacher: 1.288603, dila_dbloss_Student2_Teacher: 1.283752, loss: 6.739372, dml_thrink_maps_0: 0.002895, db_Student_loss_shrink_maps: 1.410274, db_Student_loss_threshold_maps: 0.388946, db_Student_loss_binary_maps: 0.280807, db_Student2_loss_shrink_maps: 1.414650, db_Student2_loss_threshold_maps: 0.391234, db_Student2_loss_binary_maps: 0.281578, avg_reader_cost: 7.17047 s, avg_batch_cost: 11.48100 s, avg_samples: 12.0, ips: 1.04521 samples/s, eta: 20:44:22
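Each training log line is a comma-separated list of `key: value` pairs, so it can be turned into a dict for plotting losses or the learning rate. The sketch assumes the line format shown above stays stable; `parse_train_log_line` is a hypothetical helper, not part of PaddleOCR.

```python
from typing import Dict, Union

Value = Union[float, str]


def parse_train_log_line(line: str) -> Dict[str, Value]:
    """Split a 'ppocr INFO: key: value, key: value, ...' line into a dict.

    Plain numbers become floats; values with units or brackets
    (e.g. '7.17047 s', '[121/500]', '20:44:22') stay strings.
    """
    _, _, body = line.partition("ppocr INFO: ")
    metrics: Dict[str, Value] = {}
    for field in body.strip().split(", "):
        key, sep, value = field.partition(": ")
        if not sep:
            continue
        try:
            metrics[key] = float(value)
        except ValueError:
            metrics[key] = value
    return metrics


line = ("[2022/09/13 20:07:57] ppocr INFO: epoch: [121/500], global_step: 2050, "
        "lr: 0.000869, loss: 6.739372, avg_batch_cost: 11.48100 s, eta: 20:44:22")
print(parse_train_log_line(line)["loss"])  # 6.739372
```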
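V. Recognition Dataset Preparation
Convert the detection dataset into a recognition dataset and use it to train the recognition model.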
# ppocr/utils/gen_label.py
# convert the official gt to rec_gt_label.txt
%cd ~/PaddleOCR
!python ppocr/utils/gen_label.py --mode="rec" --input_path="../dataset/train/train.txt" --output_label="../dataset/train/train_rec_gt_label.txt"
!python ppocr/utils/gen_label.py --mode="rec" --input_path="../dataset/train/eval.txt" --output_label="../dataset/train/eval_rec_gt_label.txt"
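From our reading of the PaddleOCR source, `gen_label.py --mode="rec"` expects ICDAR2015-style recognition ground truth (lines such as `word_1.png, "Genaxis Theatre"`) and only rewrites each line into `path<TAB>label` form by splitting on commas. It does not appear to crop text lines out of page images, so running it on the detection-format train.txt is unlikely to yield usable recognition labels. The sketch below is one alternative, under these assumptions: it crops the axis-aligned bounding rectangle of each text line with Pillow (PaddleOCR's own inference uses a perspective crop instead), turns tall crops 90 degrees so the vertical columns become horizontal, and writes `crop_path<TAB>text` lines. The function names, crop naming scheme, and output paths are hypothetical.

```python
import json
import os
from typing import List, Sequence, Tuple

Box = Tuple[int, int, int, int]
Job = Tuple[Box, str, str]  # (box, crop_path, transcription)


def bounding_box(flat: Sequence[int]) -> Box:
    """Axis-aligned (left, top, right, bottom) box around a flat polygon."""
    xs, ys = flat[0::2], flat[1::2]
    return min(xs), min(ys), max(xs), max(ys)


def crop_jobs(det_line: str, crop_dir: str) -> Tuple[str, List[Job]]:
    """Parse one 'image_name<TAB>json' line of the detection labels.

    The '<stem>_<index>.jpg' crop naming is a hypothetical choice.
    """
    image_name, payload = det_line.rstrip("\n").split("\t", 1)
    stem = os.path.splitext(image_name)[0]
    jobs = []
    for idx, item in enumerate(json.loads(payload)):
        flat = [c for point in item["points"] for c in point]
        crop_path = os.path.join(crop_dir, f"{stem}_{idx:03d}.jpg")
        jobs.append((bounding_box(flat), crop_path, item["transcription"]))
    return image_name, jobs


def build_rec_dataset(det_label_file: str, image_dir: str,
                      crop_dir: str, rec_label_file: str) -> None:
    """Crop every text line and write 'crop_path<TAB>text' label lines."""
    from PIL import Image  # assumption: Pillow is installed

    os.makedirs(crop_dir, exist_ok=True)
    with open(det_label_file, encoding="utf-8") as src, \
            open(rec_label_file, "w", encoding="utf-8") as dst:
        for line in src:
            if not line.strip():
                continue
            image_name, jobs = crop_jobs(line, crop_dir)
            with Image.open(os.path.join(image_dir, image_name)) as page:
                rgb = page.convert("RGB")
            for (left, top, right, bottom), crop_path, text in jobs:
                if right <= left or bottom <= top:
                    continue  # skip degenerate boxes
                crop = rgb.crop((left, top, right, bottom))
                # Rotate tall columns 90 degrees counter-clockwise so the text
                # runs horizontally, mirroring PaddleOCR's inference crop
                # for height / width >= 1.5.
                if crop.height >= 1.5 * crop.width:
                    crop = crop.rotate(90, expand=True)
                crop.save(crop_path)
                dst.write(f"{crop_path}\t{text}\n")
```

A possible call is `build_rec_dataset("/home/aistudio/dataset/train/train.txt", "/home/aistudio/dataset/train/image", "/home/aistudio/dataset/train/rec_crops", "/home/aistudio/dataset/train/train_rec_gt_label.txt")`, and likewise for eval.txt. Because SimpleDataSet joins `data_dir` with each listed path, an absolute `crop_dir` keeps the paths valid whatever `data_dir` the recognition config uses.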
VI. Recognition Model Training
1. Download the pretrained model
!mkdir -p ~/PaddleOCR/pretrain_models
%cd ~/PaddleOCR/pretrain_models
!wget https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/ch_PP-OCRv3_rec_train.tar
!tar -xvf ch_PP-OCRv3_rec_train.tar
2. Configure training parameters
Global:
  debug: false
  use_gpu: true
  epoch_num: 800
  log_smooth_window: 20
  print_batch_step: 10
  save_model_dir: ./output/rec_ppocr_v3_distillation
  save_epoch_step: 3
  eval_batch_step: [0, 2000]
  cal_metric_during_train: true
  # pretrained model
  pretrained_model: pretrain_models/ch_PP-OCRv3_rec_train/best_accuracy.pdparams
  checkpoints:
  save_inference_dir:
  use_visualdl: false
  infer_img: doc/imgs_words/ch/word_1.jpg
  # custom character dictionary
  character_dict_path: ../mb.txt
  max_text_length: &max_text_length 25
  infer_mode: false
  use_space_char: true
  distributed: true
  save_res_path: ./output/rec/predicts_ppocrv3_distillation.txt
Optimizer:
  name: Adam
  beta1: 0.9
  beta2: 0.999
  lr:
    name: Piecewise
    decay_epochs : [700, 800]
    values : [0.0005, 0.00005]
    warmup_epoch: 5
  regularizer:
    name: L2
    factor: 3.0e-05
Architecture:
  model_type: &model_type "rec"
  name: DistillationModel
  algorithm: Distillation
  Models:
    Teacher:
      pretrained:
      freeze_params: false
      return_all_feats: true
      model_type: *model_type
      algorithm: SVTR
      Transform:
      Backbone:
        name: MobileNetV1Enhance
        scale: 0.5
        last_conv_stride: [1, 2]
        last_pool_type: avg
      Head:
        name: MultiHead
        head_list:
          - CTCHead:
              Neck:
                name: svtr
                dims: 64
                depth: 2
                hidden_dims: 120
                use_guide: True
              Head:
                fc_decay: 0.00001
          - SARHead:
              enc_dim: 512
              max_text_length: *max_text_length
    Student:
      pretrained:
      freeze_params: false
      return_all_feats: true
      model_type: *model_type
      algorithm: SVTR
      Transform:
      Backbone:
        name: MobileNetV1Enhance
        scale: 0.5
        last_conv_stride: [1, 2]
        last_pool_type: avg
      Head:
        name: MultiHead
        head_list:
          - CTCHead:
              Neck:
                name: svtr
                dims: 64
                depth: 2
                hidden_dims: 120
                use_guide: True
              Head:
                fc_decay: 0.00001
          - SARHead:
              enc_dim: 512
              max_text_length: *max_text_length
Loss:
  name: CombinedLoss
  loss_config_list:
  - DistillationDMLLoss:
      weight: 1.0
      act: "softmax"
      use_log: true
      model_name_pairs:
      - ["Student", "Teacher"]
      key: head_out
      multi_head: True
      dis_head: ctc
      name: dml_ctc
  - DistillationDMLLoss:
      weight: 0.5
      act: "softmax"
      use_log: true
      model_name_pairs:
      - ["Student", "Teacher"]
      key: head_out
      multi_head: True
      dis_head: sar
      name: dml_sar
  - DistillationDistanceLoss:
      weight: 1.0
      mode: "l2"
      model_name_pairs:
      - ["Student", "Teacher"]
      key: backbone_out
  - DistillationCTCLoss:
      weight: 1.0
      model_name_list: ["Student", "Teacher"]
      key: head_out
      multi_head: True
  - DistillationSARLoss:
      weight: 1.0
      model_name_list: ["Student", "Teacher"]
      key: head_out
      multi_head: True
PostProcess:
  name: DistillationCTCLabelDecode
  model_name: ["Student", "Teacher"]
  key: head_out
  multi_head: True
Metric:
  name: DistillationMetric
  base_metric_name: RecMetric
  main_indicator: acc
  key: "Student"
  ignore_space: False
# modified dataset paths
Train:
  dataset:
    name: SimpleDataSet
    data_dir: /home/aistudio/dataset/train/image
    ext_op_transform_idx: 1
    label_file_list:
    - /home/aistudio/dataset/train/train_rec_gt_label.txt
    transforms:
    - DecodeImage:
        img_mode: BGR
        channel_first: false
    - RecConAug:
        prob: 0.5
        ext_data_num: 2
        image_shape: [48, 320, 3]
    - RecAug:
    - MultiLabelEncode:
    - RecResizeImg:
        image_shape: [3, 48, 320]
    - KeepKeys:
        keep_keys:
        - image
        - label_ctc
        - label_sar
        - length
        - valid_ratio
  loader:
    shuffle: true
    batch_size_per_card: 128
    drop_last: true
    num_workers: 4
# modified dataset paths
Eval:
  dataset:
    name: SimpleDataSet
    data_dir: /home/aistudio/dataset/train/image
    ext_op_transform_idx: 1
    label_file_list:
    - /home/aistudio/dataset/train/eval_rec_gt_label.txt
    transforms:
    - DecodeImage:
        img_mode: BGR
        channel_first: false
    - MultiLabelEncode:
    - RecResizeImg:
        image_shape: [3, 48, 320]
    - KeepKeys:
        keep_keys:
        - image
        - label_ctc
        - label_sar
        - length
        - valid_ratio
  loader:
    shuffle: false
    drop_last: false
    batch_size_per_card: 128
    num_workers: 4
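The config sets `max_text_length` to 25, while text columns in this dataset are often longer: the first line of image_0.jpg has 32 characters. From our reading of PaddleOCR's label encoders, a label longer than `max_text_length` makes the encoder return None and the dataset then draws a different sample, so such lines are effectively skipped during training. The sketch counts how many recognition labels exceed the limit; `length_stats` is a hypothetical helper and assumes `path<TAB>text` label lines.

```python
from typing import Dict, Iterable


def length_stats(label_lines: Iterable[str], max_text_length: int = 25) -> Dict[str, int]:
    """Count 'path<TAB>text' labels longer than max_text_length characters."""
    total = too_long = longest = 0
    for line in label_lines:
        if "\t" not in line:
            continue
        text = line.rstrip("\n").split("\t", 1)[1]
        total += 1
        longest = max(longest, len(text))
        if len(text) > max_text_length:
            too_long += 1
    return {"total": total, "too_long": too_long, "longest": longest}
```

For example: `with open("/home/aistudio/dataset/train/train_rec_gt_label.txt", encoding="utf-8") as f: print(length_stats(f))`. If many lines are too long, raising `max_text_length` (and checking that the input width leaves enough CTC time steps) is one option to evaluate.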
# Copy the configured file into place
%cd ~
!cp ~/ch_PP-OCRv3_rec_distillation.yml ~/PaddleOCR/configs/rec/PP-OCRv3/
3. Model training
%cd ~/PaddleOCR/
# Multi-GPU training; choose the GPU IDs with --gpus
!python -m paddle.distributed.launch --gpus '0,1,2,3' tools/train.py -c configs/rec/PP-OCRv3/ch_PP-OCRv3_rec_distillation.yml
VII. Chained Detection and Recognition Inference
1. Export the models
# Export the detection model
!python tools/export_model.py -c configs/det/ch_PP-OCRv3/ch_PP-OCRv3_det_cml.yml -o Global.pretrained_model=./my_exps/det/best_accuracy Global.save_inference_dir=./inference/det
# Export the recognition model
!python tools/export_model.py -c configs/rec/PP-OCRv3/ch_PP-OCRv3_rec_distillation.yml -o Global.pretrained_model=./my_exps/rec/best_accuracy Global.save_inference_dir=./inference/rec
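2. Chained inference
# A distillation model is exported as one sub-directory per sub-model (e.g. inference/det/Student),
# and the custom dictionary must be passed explicitly, since the default is ppocr_keys_v1.txt
! python tools/infer/predict_system.py \
--det_model_dir=inference/det/Student \
--rec_model_dir=inference/rec/Student \
--rec_char_dict_path=../mb.txt \
--image_dir="/home/aistudio/dataset/train/image/image_0.jpg" \
--rec_image_shape=3,48,320
# Show the visualized result (saved under ./inference_results with the input file name)
import matplotlib.pyplot as plt
plt.figure(figsize=(10, 8))
img = plt.imread("./inference_results/image_0.jpg")
plt.imshow(img)
plt.axis("off")
plt.show()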
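In the PaddleOCR versions we have looked at, predict_system.py also writes a system_results.txt file into its output directory (./inference_results by default), with one `image_name<TAB>json` line per image whose entries have "transcription" and "points" keys. Its boxes are sorted top to bottom and then left to right, which does not match vertical, right-to-left reading. The sketch converts that file into the per-image CSVs described in the submission section, re-sorting boxes with a naive rightmost-column-first rule. The file name and format, the sort rule, and the CSV conventions are assumptions to verify against your PaddleOCR version and the official sample submission; the function names are hypothetical.

```python
import csv
import io
import json
import os
from typing import List, Tuple


def result_line_to_csv(line: str) -> Tuple[str, str]:
    """Convert one 'image<TAB>json' line of system_results.txt to CSV text.

    Boxes are re-sorted by mean x, rightmost first: a naive reading order
    for vertical right-to-left pages (assumption).
    """
    image_name, payload = line.rstrip("\n").split("\t", 1)
    items = json.loads(payload)
    items.sort(key=lambda it: -sum(p[0] for p in it["points"]) / len(it["points"]))
    buf = io.StringIO()
    writer = csv.writer(buf, lineterminator="\n")
    for it in items:
        flat = [c for point in it["points"] for c in point]
        writer.writerow(flat + [it["transcription"]])
    return image_name, buf.getvalue()


def convert_system_results(results_file: str, out_dir: str) -> List[str]:
    """Write one <image stem>.csv per result line; return the written paths."""
    os.makedirs(out_dir, exist_ok=True)
    paths = []
    with open(results_file, encoding="utf-8") as f:
        for line in f:
            if "\t" not in line:
                continue
            image_name, text = result_line_to_csv(line)
            path = os.path.join(out_dir, os.path.splitext(image_name)[0] + ".csv")
            with open(path, "w", encoding="utf-8", newline="") as out:
                out.write(text)
            paths.append(path)
    return paths
```

For example, `convert_system_results("./inference_results/system_results.txt", "./submission")` followed by zipping the submission directory.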