Text Detection: the CTPN Model

HWCloudAI, posted on 2022/12/05 14:42:11

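In this case study we continue learning OCR (Optical Character Recognition) with deep learning. OCR was one of the first areas of computer vision to adopt deep learning, and many strong models have come out of it, which makes it a good topic for studying deep-learning-based OCR. Deep-learning OCR usually splits text recognition into two stages: text region detection and character recognition. The model introduced here, CTPN (Connectionist Text Proposal Network), is a text detection model: it locates the regions of an image that contain text.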

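Notes:

  1. Framework used in this case: TensorFlow-1.13.1

  2. Hardware used in this case: 8 vCPU + 64 GiB + 1 x Tesla V100-PCIE-32GB

  3. How to open the runtime environment: click this link to open AI Gallery, then click the Run in ModelArts button to enter the ModelArts runtime environment. To use a GPU, switch to it in the workspace panel on the right side of the ModelArts JupyterLab interface

  4. How to run the code: click the triangular Run button in the menu bar at the top of this page, or press Ctrl+Enter, to run the code in each cell

  5. Detailed JupyterLab usage: see the "ModelArts JupyterLab User Guide"

  6. Troubleshooting: see "ModelArts JupyterLab Common Problems and Solutions"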

1. Download the data and code

Run the code below to download the archive that contains the data and code.

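import os
from modelarts.session import Session
sess = Session()

# Choose the OBS bucket path for the current region
if sess.region_name == 'cn-north-1':
    bucket_path = "modelarts-labs/notebook/DL_ocr_ctpn_text_detection/ctpn.tar"
elif sess.region_name == 'cn-north-4':
    bucket_path = "modelarts-labs-bj4/notebook/DL_ocr_ctpn_text_detection/ctpn.tar"
else:
    # Stop here instead of failing later with an undefined bucket_path
    raise RuntimeError("Please switch your region to CN North-Beijing1 or CN North-Beijing4")

# Download only if the code has not been extracted yet
if not os.path.exists('./CTPN'):
    sess.download_data(bucket_path=bucket_path, path="./ctpn.tar")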
Successfully download file modelarts-labs-bj4/notebook/DL_ocr_ctpn_text_detection/ctpn.tar from OBS to local ./ctpn.tar

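2. Extract the files

# Extract the archive unless ./CTPN already exists, then delete the archive to save space
if os.path.exists('./ctpn.tar') and (not os.path.exists('./CTPN')):
    os.system("tar -xf ctpn.tar")
if os.path.exists('./ctpn.tar'):
    os.system("rm ./ctpn.tar")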
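The same "extract once, then delete the archive" step can also be done without shelling out, using Python's standard tarfile module. This is a minimal sketch, not part of the original notebook; the helper name extract_and_remove is hypothetical, and extractall should only be used on archives you trust.

```python
import os
import tarfile


def extract_and_remove(archive_path: str, target_dir: str, dest: str = ".") -> bool:
    """Extract archive_path into dest unless target_dir already exists,
    then delete the archive. Returns True if an extraction happened."""
    extracted = False
    if os.path.exists(archive_path) and not os.path.exists(target_dir):
        with tarfile.open(archive_path) as tar:
            tar.extractall(path=dest)  # only for trusted archives
        extracted = True
    if os.path.exists(archive_path):
        os.remove(archive_path)  # same effect as "rm ./ctpn.tar"
    return extracted
```

In this notebook the call would be extract_and_remove('./ctpn.tar', './CTPN').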

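The NMS (non-maximum suppression) and bbox utilities in this code are written in Cython (the bbox.pyx and nms.pyx files). Cython translates Python-like code into C and compiles it into a Python extension module, which runs much faster than pure Python. So we compile these modules first.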

import os
pwd = os.getcwd()
# make.sh cythonizes bbox.pyx and nms.pyx, compiles them and installs the extensions
os.chdir('./CTPN/utils/bbox')
!chmod +x make.sh
!./make.sh
os.chdir(pwd)
Compiling bbox.pyx because it depends on /home/ma-user/anaconda3/envs/TensorFlow-1.13.1/lib/python3.6/site-packages/Cython/Includes/numpy/__init__.pxd.
[1/1] Cythonizing bbox.pyx
running install
running build
running build_ext
building 'bbox' extension
creating build
creating build/temp.linux-x86_64-3.6
gcc -pthread -B /home/ma-user/anaconda3/envs/TensorFlow-1.13.1/compiler_compat -Wl,--sysroot=/ -Wsign-compare -DNDEBUG -g -fwrapv -O3 -Wall -Wstrict-prototypes -fPIC -I/home/ma-user/anaconda3/envs/TensorFlow-1.13.1/lib/python3.6/site-packages/numpy/core/include -I/home/ma-user/anaconda3/envs/TensorFlow-1.13.1/include/python3.6m -c bbox.c -o build/temp.linux-x86_64-3.6/bbox.o
In file included from /home/ma-user/anaconda3/envs/TensorFlow-1.13.1/lib/python3.6/site-packages/numpy/core/include/numpy/ndarraytypes.h:1822:0,
                 from /home/ma-user/anaconda3/envs/TensorFlow-1.13.1/lib/python3.6/site-packages/numpy/core/include/numpy/ndarrayobject.h:12,
                 from /home/ma-user/anaconda3/envs/TensorFlow-1.13.1/lib/python3.6/site-packages/numpy/core/include/numpy/arrayobject.h:4,
                 from bbox.c:580:
/home/ma-user/anaconda3/envs/TensorFlow-1.13.1/lib/python3.6/site-packages/numpy/core/include/numpy/npy_1_7_deprecated_api.h:17:2: warning: #warning "Using deprecated NumPy API, disable it with " "#define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION" [-Wcpp]
 #warning "Using deprecated NumPy API, disable it with " \
  ^
creating build/lib.linux-x86_64-3.6
gcc -pthread -shared -B /home/ma-user/anaconda3/envs/TensorFlow-1.13.1/compiler_compat -L/home/ma-user/anaconda3/envs/TensorFlow-1.13.1/lib -Wl,-rpath=/home/ma-user/anaconda3/envs/TensorFlow-1.13.1/lib -Wl,--no-as-needed -Wl,--sysroot=/ build/temp.linux-x86_64-3.6/bbox.o -o build/lib.linux-x86_64-3.6/bbox.cpython-36m-x86_64-linux-gnu.so
running install_lib
copying build/lib.linux-x86_64-3.6/bbox.cpython-36m-x86_64-linux-gnu.so -> /home/ma-user/anaconda3/envs/TensorFlow-1.13.1/lib/python3.6/site-packages
running install_egg_info
Writing /home/ma-user/anaconda3/envs/TensorFlow-1.13.1/lib/python3.6/site-packages/UNKNOWN-0.0.0-py3.6.egg-info
Compiling nms.pyx because it depends on /home/ma-user/anaconda3/envs/TensorFlow-1.13.1/lib/python3.6/site-packages/Cython/Includes/numpy/__init__.pxd.
[1/1] Cythonizing nms.pyx
running install
running build
running build_ext
building 'nms' extension
gcc -pthread -B /home/ma-user/anaconda3/envs/TensorFlow-1.13.1/compiler_compat -Wl,--sysroot=/ -Wsign-compare -DNDEBUG -g -fwrapv -O3 -Wall -Wstrict-prototypes -fPIC -I/home/ma-user/anaconda3/envs/TensorFlow-1.13.1/lib/python3.6/site-packages/numpy/core/include -I/home/ma-user/anaconda3/envs/TensorFlow-1.13.1/include/python3.6m -c nms.c -o build/temp.linux-x86_64-3.6/nms.o
In file included from /home/ma-user/anaconda3/envs/TensorFlow-1.13.1/lib/python3.6/site-packages/numpy/core/include/numpy/ndarraytypes.h:1822:0,
                 from /home/ma-user/anaconda3/envs/TensorFlow-1.13.1/lib/python3.6/site-packages/numpy/core/include/numpy/ndarrayobject.h:12,
                 from /home/ma-user/anaconda3/envs/TensorFlow-1.13.1/lib/python3.6/site-packages/numpy/core/include/numpy/arrayobject.h:4,
                 from nms.c:580:
/home/ma-user/anaconda3/envs/TensorFlow-1.13.1/lib/python3.6/site-packages/numpy/core/include/numpy/npy_1_7_deprecated_api.h:17:2: warning: #warning "Using deprecated NumPy API, disable it with " "#define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION" [-Wcpp]
 #warning "Using deprecated NumPy API, disable it with " \
  ^
gcc -pthread -shared -B /home/ma-user/anaconda3/envs/TensorFlow-1.13.1/compiler_compat -L/home/ma-user/anaconda3/envs/TensorFlow-1.13.1/lib -Wl,-rpath=/home/ma-user/anaconda3/envs/TensorFlow-1.13.1/lib -Wl,--no-as-needed -Wl,--sysroot=/ build/temp.linux-x86_64-3.6/nms.o -o build/lib.linux-x86_64-3.6/nms.cpython-36m-x86_64-linux-gnu.so
running install_lib
copying build/lib.linux-x86_64-3.6/nms.cpython-36m-x86_64-linux-gnu.so -> /home/ma-user/anaconda3/envs/TensorFlow-1.13.1/lib/python3.6/site-packages
running install_egg_info
Removing /home/ma-user/anaconda3/envs/TensorFlow-1.13.1/lib/python3.6/site-packages/UNKNOWN-0.0.0-py3.6.egg-info
Writing /home/ma-user/anaconda3/envs/TensorFlow-1.13.1/lib/python3.6/site-packages/UNKNOWN-0.0.0-py3.6.egg-info
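For readers new to NMS, here is a minimal pure-Python sketch of greedy non-maximum suppression, the kind of computation a compiled nms module speeds up. It is an illustration under assumptions: boxes are (x1, y1, x2, y2, score) tuples and areas use continuous coordinates. The Cython version in CTPN/utils/bbox may use a +1 pixel convention and a different interface, and the names iou and greedy_nms are hypothetical.

```python
def iou(a, b):
    """Intersection-over-union of two boxes given as (x1, y1, x2, y2, ...)."""
    ix1, iy1 = max(a[0], b[0]), max(a[1], b[1])
    ix2, iy2 = min(a[2], b[2]), min(a[3], b[3])
    inter = max(0.0, ix2 - ix1) * max(0.0, iy2 - iy1)
    area_a = (a[2] - a[0]) * (a[3] - a[1])
    area_b = (b[2] - b[0]) * (b[3] - b[1])
    union = area_a + area_b - inter
    return inter / union if union > 0 else 0.0


def greedy_nms(dets, thresh):
    """dets: list of (x1, y1, x2, y2, score). Returns indices of kept boxes.

    Repeatedly keep the highest-scoring remaining box and drop every other
    box whose IoU with it is greater than thresh.
    """
    order = sorted(range(len(dets)), key=lambda i: dets[i][4], reverse=True)
    keep = []
    while order:
        best = order.pop(0)
        keep.append(best)
        order = [i for i in order if iou(dets[best], dets[i]) <= thresh]
    return keep
```

Note the direction of the threshold: a box is suppressed only when its overlap exceeds thresh, so a larger thresh keeps more boxes and a smaller one discards more.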

3. Start the case: first import the required libraries and functions

import shutil
import cv2
import numpy as np
import datetime
import os
import sys
import time
from PIL import Image
import tensorflow as tf
sys.path.append(os.getcwd())
from tensorflow.contrib import slim

from CTPN import data_provider as data_provider
from CTPN.model import mean_image_subtraction,Bilstm,lstm_fc,loss
from CTPN import vgg
from CTPN import model
from CTPN.utils.rpn_msr.proposal_layer import proposal_layer
from CTPN.utils.text_connector.detectors import TextDetector
from CTPN.utils.image import resize_image
/home/ma-user/anaconda3/envs/TensorFlow-1.13.1/lib/python3.6/site-packages/tensorflow/python/framework/dtypes.py:526: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.
  _np_qint8 = np.dtype([("qint8", np.int8, 1)])
/home/ma-user/anaconda3/envs/TensorFlow-1.13.1/lib/python3.6/site-packages/tensorflow/python/framework/dtypes.py:527: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.
  _np_quint8 = np.dtype([("quint8", np.uint8, 1)])
/home/ma-user/anaconda3/envs/TensorFlow-1.13.1/lib/python3.6/site-packages/tensorflow/python/framework/dtypes.py:528: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.
  _np_qint16 = np.dtype([("qint16", np.int16, 1)])
/home/ma-user/anaconda3/envs/TensorFlow-1.13.1/lib/python3.6/site-packages/tensorflow/python/framework/dtypes.py:529: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.
  _np_quint16 = np.dtype([("quint16", np.uint16, 1)])
/home/ma-user/anaconda3/envs/TensorFlow-1.13.1/lib/python3.6/site-packages/tensorflow/python/framework/dtypes.py:530: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.
  _np_qint32 = np.dtype([("qint32", np.int32, 1)])
/home/ma-user/anaconda3/envs/TensorFlow-1.13.1/lib/python3.6/site-packages/tensorflow/python/framework/dtypes.py:535: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.
  np_resource = np.dtype([("resource", np.ubyte, 1)])

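4. Data preprocessing

4.1 Define variables

  • checkpoint_path: directory where training checkpoints are stored
  • vgg_path: path of the pretrained VGG16 checkpoint
  • image_path: directory of the training images
  • CHECKPOINT_PATH: directory of the checkpoint loaded for testing
checkpoint_path = './models/checkpoints/'  # where trained checkpoints are saved
vgg_path = "./models/vgg_16.ckpt"          # pretrained VGG16 weights
image_path = './data/CTW-200'              # training images

CHECKPOINT_PATH = './models/checkpoints'   # checkpoint directory used for testing
os.environ['CUDA_VISIBLE_DEVICES'] = '0'   # use only the first GPU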

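Define the input placeholders (images of any size) and the optimizer

# Placeholders: a batch of RGB images of arbitrary height and width,
# ground-truth boxes (5 values each), and image info (height, width, channels)
input_image = tf.placeholder(tf.float32, shape=[None, None, None, 3], name='input_image')
input_bbox = tf.placeholder(tf.float32, shape=[None, 5], name='input_bbox')
input_im_info = tf.placeholder(tf.float32, shape=[None, 3], name='input_im_info')
# Define and initialize global_step, the training step counter
global_step = tf.get_variable('global_step', [], initializer=tf.constant_initializer(0), trainable=False)
# Define learning_rate as a non-trainable variable so it can be decayed during training
learning_rate = tf.Variable(1e-5, trainable=False)
# Optimize with Adam; the learning rate 1e-5 is not fixed and can be tuned to how training progresses
opt = tf.train.AdamOptimizer(learning_rate)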
WARNING:tensorflow:From /home/ma-user/anaconda3/envs/TensorFlow-1.13.1/lib/python3.6/site-packages/tensorflow/python/framework/op_def_library.py:263: colocate_with (from tensorflow.python.framework.ops) is deprecated and will be removed in a future version.
Instructions for updating:
Colocations handled automatically by placer.
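# Subtract the mean pixel values from the input images
image = mean_image_subtraction(input_image)
# VGG16 backbone; conv5_3 is its last convolutional feature map
with slim.arg_scope(vgg.vgg_arg_scope()):
    conv5_3 = vgg.vgg_16(image)
# Add a 3x3 convolution layer with 512 output channels
rpn_conv = slim.conv2d(conv5_3, 512, 3)
# Add a bidirectional LSTM that runs along each row of the feature map
lstm_output = Bilstm(rpn_conv, 512, 128, 512, scope_name='BiLSTM')
# Add the fully connected output layers:
# predict bounding-box regression values (10 anchors x 4 values per position)
bbox_pred = lstm_fc(lstm_output, 512, 10 * 4, scope_name="bbox_pred")
# predict text / non-text scores (10 anchors x 2 classes per position)
cls_pred = lstm_fc(lstm_output, 512, 10 * 2, scope_name="cls_pred")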
WARNING:tensorflow:From /home/ma-user/work/CTPN/model.py:30: LSTMCell.__init__ (from tensorflow.python.ops.rnn_cell_impl) is deprecated and will be removed in a future version.
Instructions for updating:
This class is equivalent as tf.keras.layers.LSTMCell, and will be replaced by that in Tensorflow 2.0.
WARNING:tensorflow:From /home/ma-user/work/CTPN/model.py:33: bidirectional_dynamic_rnn (from tensorflow.python.ops.rnn) is deprecated and will be removed in a future version.
Instructions for updating:
Please use `keras.layers.Bidirectional(keras.layers.RNN(cell))`, which is equivalent to this API
WARNING:tensorflow:From /home/ma-user/anaconda3/envs/TensorFlow-1.13.1/lib/python3.6/site-packages/tensorflow/python/ops/rnn.py:443: dynamic_rnn (from tensorflow.python.ops.rnn) is deprecated and will be removed in a future version.
Instructions for updating:
Please use `keras.layers.RNN(cell)`, which is equivalent to this API
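To make the 10 * 4 and 10 * 2 output sizes concrete, here is a small shape calculator. It assumes (our assumption, not stated in this notebook) that conv5_3 is reached after four 2x2, stride-2 max-pooling layers as in standard VGG16, so the feature map is about 1/16 of the input in each direction. The ten anchor heights follow the CTPN paper's description: 11 to 273 pixels, each 1/0.7 times the previous, with a fixed width of 16 pixels. All function names are hypothetical.

```python
def conv5_3_size(h, w, num_pools=4):
    """Spatial size after num_pools 2x2, stride-2 max-pools without padding."""
    for _ in range(num_pools):
        h, w = h // 2, w // 2
    return h, w


def ctpn_anchor_heights(num=10, min_h=11.0, ratio=0.7):
    """Anchor heights that grow by a factor of 1/ratio at each step."""
    return [min_h / ratio ** k for k in range(num)]


def head_shapes(h, w, num_anchors=10):
    """Output shapes of the two heads for a single image of size h x w."""
    fh, fw = conv5_3_size(h, w)
    return {
        "bbox_pred": (1, fh, fw, num_anchors * 4),  # 4 regression values per anchor
        "cls_pred": (1, fh, fw, num_anchors * 2),   # 2 class scores per anchor
        "num_anchors_total": fh * fw * num_anchors,
    }


print([round(h) for h in ctpn_anchor_heights()])
print(head_shapes(608, 800))
```

For a 608 x 800 input this gives a 38 x 50 feature map, i.e. 19,000 anchors for one image.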

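4.2 Normalize the classification predictions with softmax

# cls_pred has shape [N, H, W, 10*2]; regroup it to [N, H, W*10, 2],
# so each row along the third axis holds one anchor's two class scores
cls_pred_shape = tf.shape(cls_pred)
cls_pred_reshape = tf.reshape(cls_pred, [cls_pred_shape[0], cls_pred_shape[1], -1, 2])
cls_pred_reshape_shape = tf.shape(cls_pred_reshape)
# Flatten to [-1, 2], apply softmax to each anchor's pair of scores, then restore the 4-D shape
cls_prob = tf.reshape(tf.nn.softmax(tf.reshape(cls_pred_reshape, [-1, cls_pred_reshape_shape[3]])),
                      [-1, cls_pred_reshape_shape[1], cls_pred_reshape_shape[2], cls_pred_reshape_shape[3]],
                      name="cls_prob")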
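The reshape, softmax, reshape sequence can be mirrored for a single feature-map position in plain Python: the 20 channels are read as 10 consecutive pairs, which is exactly what reshaping the last axis into [..., -1, 2] does, and each pair is normalized on its own. Which index of a pair means "text" is fixed by the training labels and is not shown here; the helper names are illustrative.

```python
import math


def softmax(xs):
    """Numerically stable softmax of a list of floats."""
    m = max(xs)  # subtract the max so exp() cannot overflow
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def anchor_probs(logits_at_position, num_classes=2):
    """Split one position's flat logits (10 anchors x 2 classes) into
    per-anchor pairs and apply softmax to each pair separately."""
    if len(logits_at_position) % num_classes != 0:
        raise ValueError("length must be a multiple of num_classes")
    pairs = [logits_at_position[i:i + num_classes]
             for i in range(0, len(logits_at_position), num_classes)]
    return [softmax(p) for p in pairs]
```

Each returned pair sums to 1, matching one row of cls_prob for that position.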

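4.3 Define the loss function and the training op

# loss() returns the total loss, the model loss, the classification (cross-entropy) loss and the box regression loss
total_loss, model_loss, rpn_cross_entropy, rpn_loss_box = loss(bbox_pred, cls_pred, input_bbox, input_im_info)
# Group any batch-norm update ops so they run with each training step
batch_norm_updates_op = tf.group(*tf.get_collection(tf.GraphKeys.UPDATE_OPS, '0'))
# Compute gradients, merge summaries, and apply gradients (this also increments global_step)
grads = opt.compute_gradients(total_loss)
summary_op = tf.summary.merge_all()
apply_gradient_op = opt.apply_gradients(grads, global_step=global_step)
# Keep exponential moving averages (decay 0.997) of all trainable variables
variable_averages = tf.train.ExponentialMovingAverage(0.997, global_step)
variables_averages_op = variable_averages.apply(tf.trainable_variables())
# One run of train_op performs the gradient update, the moving-average update and the BN updates
with tf.control_dependencies([variables_averages_op, apply_gradient_op, batch_norm_updates_op]):
    train_op = tf.no_op(name='train_op')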
WARNING:tensorflow:From /home/ma-user/work/CTPN/model.py:95: py_func (from tensorflow.python.ops.script_ops) is deprecated and will be removed in a future version.
Instructions for updating:
tf.py_func is deprecated in TF V2. Instead, use
    tf.py_function, which takes a python function which manipulates tf eager
    tensors instead of numpy arrays. It's easy to convert a tf eager tensor to
    an ndarray (just call tensor.numpy()) but having access to eager tensors
    means `tf.py_function`s can use accelerators such as GPUs as well as
    being differentiable using a gradient tape.

WARNING:tensorflow:From /home/ma-user/anaconda3/envs/TensorFlow-1.13.1/lib/python3.6/site-packages/tensorflow/python/ops/math_ops.py:3066: to_int32 (from tensorflow.python.ops.math_ops) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.cast instead.
/home/ma-user/anaconda3/envs/TensorFlow-1.13.1/lib/python3.6/site-packages/tensorflow/python/ops/gradients_impl.py:110: UserWarning: Converting sparse IndexedSlices to a dense Tensor of unknown shape. This may consume a large amount of memory.
  "Converting sparse IndexedSlices to a dense Tensor of unknown shape. "

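The moving average defined above keeps a shadow copy of every trainable variable. Per the TensorFlow documentation for tf.train.ExponentialMovingAverage, each update computes shadow = decay * shadow + (1 - decay) * variable, and when num_updates (here global_step) is passed, the decay actually used is min(decay, (1 + num_updates) / (10 + num_updates)). The plain-Python sketch below reproduces that arithmetic for one scalar; the function names are illustrative.

```python
def effective_decay(decay, num_updates=None):
    """Decay used by an EMA that is given a num_updates counter."""
    if num_updates is None:
        return decay
    return min(decay, (1.0 + num_updates) / (10.0 + num_updates))


def ema_update(shadow, value, decay, num_updates=None):
    """One moving-average step: move the shadow value toward the current value."""
    d = effective_decay(decay, num_updates)
    return d * shadow + (1.0 - d) * value
```

Early in training (small global_step) the effective decay is small, so the shadow tracks the weights closely; around step 50000 it is capped at 0.997, so each step moves the shadow only 0.3% toward the current weights.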
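5. Load the pretrained VGG model parameters

init = tf.global_variables_initializer()
# Function that copies the VGG16 weights in vgg_path into the matching model variables.
# ignore_missing_vars=True skips variables that are not in the VGG checkpoint (the new conv,
# BiLSTM and output layers), which is why the warnings below are expected.
# To train from scratch you would call: sess.run(init); variable_restore_op(sess)
# Step 6 below instead resumes from a full CTPN checkpoint, so it does not call these.
variable_restore_op = slim.assign_from_checkpoint_fn(vgg_path,
                                                     slim.get_trainable_variables(),
                                                     ignore_missing_vars=True)
# allow_soft_placement lets TensorFlow fall back to the CPU for ops without a GPU kernel
config = tf.ConfigProto(allow_soft_placement=True)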
WARNING:tensorflow:Variable Conv/weights missing in checkpoint ./models/vgg_16.ckpt
WARNING:tensorflow:Variable Conv/biases missing in checkpoint ./models/vgg_16.ckpt
WARNING:tensorflow:Variable BiLSTM/bidirectional_rnn/fw/lstm_cell/kernel missing in checkpoint ./models/vgg_16.ckpt
WARNING:tensorflow:Variable BiLSTM/bidirectional_rnn/fw/lstm_cell/bias missing in checkpoint ./models/vgg_16.ckpt
WARNING:tensorflow:Variable BiLSTM/bidirectional_rnn/bw/lstm_cell/kernel missing in checkpoint ./models/vgg_16.ckpt
WARNING:tensorflow:Variable BiLSTM/bidirectional_rnn/bw/lstm_cell/bias missing in checkpoint ./models/vgg_16.ckpt
WARNING:tensorflow:Variable BiLSTM/weights missing in checkpoint ./models/vgg_16.ckpt
WARNING:tensorflow:Variable BiLSTM/biases missing in checkpoint ./models/vgg_16.ckpt
WARNING:tensorflow:Variable bbox_pred/weights missing in checkpoint ./models/vgg_16.ckpt
WARNING:tensorflow:Variable bbox_pred/biases missing in checkpoint ./models/vgg_16.ckpt
WARNING:tensorflow:Variable cls_pred/weights missing in checkpoint ./models/vgg_16.ckpt
WARNING:tensorflow:Variable cls_pred/biases missing in checkpoint ./models/vgg_16.ckpt
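The warnings list exactly the layers CTPN adds on top of VGG16. The idea behind ignore_missing_vars=True can be sketched as a simple name partition: variables whose names appear in the checkpoint are restored, the rest are skipped with a warning. This is a simplified illustration, not slim's actual code; split_restorable is a hypothetical helper. In TF 1.x the checkpoint's (name, shape) pairs can be listed with tf.train.list_variables(vgg_path).

```python
def split_restorable(model_vars, checkpoint_vars):
    """Partition model variable names into those found in the checkpoint
    (restored) and those that are missing (skipped with a warning)."""
    ckpt = set(checkpoint_vars)
    restored = [v for v in model_vars if v in ckpt]
    missing = [v for v in model_vars if v not in ckpt]
    return restored, missing


restored, missing = split_restorable(
    ["vgg_16/conv1/conv1_1/weights", "Conv/weights", "cls_pred/biases"],
    ["vgg_16/conv1/conv1_1/weights", "vgg_16/conv1/conv1_1/biases"])
for name in missing:
    print("Variable %s missing in checkpoint" % name)
```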

6. Start training

with tf.Session(config=config) as sess:
    # Load the latest CTPN checkpoint in checkpoint_path; training resumes from it
    ckpt = tf.train.latest_checkpoint(checkpoint_path)
    if ckpt is None:
        raise FileNotFoundError('No checkpoint found in ' + checkpoint_path)
    # Resume training from step 50000
    restore_step = 50000
    saver = tf.train.Saver(tf.global_variables(), max_to_keep=100)
    saver.restore(sess, ckpt)
    # Read the training data with 4 workers
    data_generator = data_provider.get_batch(num_workers=4, vis=False, image_path=image_path)
    # Train steps 50000 to 50029 (30 steps, as a short demonstration)
    for step in range(restore_step, 50030):
        data = next(data_generator)
        ml, tl, _, summary_str = sess.run([model_loss, total_loss, train_op, summary_op],
                                          feed_dict={input_image: data[0],
                                                     input_bbox: data[1],
                                                     input_im_info: data[2]})
        # Learning-rate decay: multiply the rate by 0.1 every 30000 steps
        if step != 0 and step % 30000 == 0:
            sess.run(tf.assign(learning_rate, learning_rate.eval() * 0.1))
        # Save a checkpoint every 2000 steps
        if (step + 1) % 2000 == 0:
            filename = ('ctpn_{:d}'.format(step + 1) + '.ckpt')
            filename = os.path.join(checkpoint_path, filename)
            saver.save(sess, filename)
            print('Write model to: {:s}'.format(filename))
        # Print the step number every 10 steps
        if step % 10 == 0:
            print('train step' + str(step))
WARNING:tensorflow:From /home/ma-user/anaconda3/envs/TensorFlow-1.13.1/lib/python3.6/site-packages/tensorflow/python/training/saver.py:1266: checkpoint_exists (from tensorflow.python.training.checkpoint_management) is deprecated and will be removed in a future version.
Instructions for updating:
Use standard file APIs to check for files with this prefix.

Find 200 images
Find 200 images
Find 200 images
Find 200 images
train step50000
train step50010
train step50020
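The decay and checkpoint rules in the training loop can be checked without TensorFlow. This sketch assumes a single continuous run that started at step 0 with a base rate of 1e-5. When resuming, as here, the learning_rate variable is part of tf.global_variables() and is therefore restored from the checkpoint, so the actual starting value depends on that checkpoint. The helper names are hypothetical.

```python
def lr_after_step(step, base_lr=1e-5, decay_every=30000, factor=0.1):
    """Learning rate in effect after finishing `step`: the loop multiplies the
    rate by `factor` after each step s with s != 0 and s % decay_every == 0."""
    num_decays = step // decay_every  # steps 30000, 60000, ... each trigger one decay
    return base_lr * factor ** num_decays


def should_save(step, every=2000):
    """Same rule as the loop: save after steps 1999, 3999, 5999, ..."""
    return (step + 1) % every == 0


print(lr_after_step(29999), lr_after_step(30000))
print([s for s in range(50000, 50030) if should_save(s)])  # empty: no checkpoint in this short run
```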

7. Testing

tf.reset_default_graph()
# Placeholders for the model inputs
input_image = tf.placeholder(tf.float32, shape=[None, None, None, 3], name='input_image')
input_im_info = tf.placeholder(tf.float32, shape=[None, 3], name='input_im_info')
# Step counter variable, defined the same way as in the training graph
global_step = tf.get_variable('global_step', [], initializer=tf.constant_initializer(0), trainable=False)
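# Build the CTPN inference graph: box predictions, raw class scores and softmax probabilities
bbox_pred, cls_pred, cls_prob = model.model(input_image)
variable_averages = tf.train.ExponentialMovingAverage(0.997, global_step)
# Saver that loads the moving-average (shadow) values into the model variables
saver = tf.train.Saver(variable_averages.variables_to_restore())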
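variables_to_restore() makes the Saver read each trainable variable from its shadow copy in the checkpoint rather than from the raw weights. The dictionary below is a simplified model of that name mapping, assuming TensorFlow's default shadow-name suffix "/ExponentialMovingAverage"; the real method maps checkpoint names to Variable objects and handles a few more cases, and ema_restore_map is a hypothetical helper.

```python
EMA_SUFFIX = "/ExponentialMovingAverage"


def ema_restore_map(trainable, others):
    """Checkpoint name -> variable name: trainable variables are loaded from
    their shadow copies, all other variables under their own names."""
    mapping = {name + EMA_SUFFIX: name for name in trainable}
    for name in others:
        mapping.setdefault(name, name)
    return mapping


print(ema_restore_map(["cls_pred/weights"], ["global_step"]))
```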

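Predict the locations of the text regions in an image: load the model weights, feed the resized test image into the model, and get the predictions.
Put your test image at the path given by img_path.

ctpn_sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True))
with ctpn_sess.as_default():
    # Restore the trained weights from the latest checkpoint in CHECKPOINT_PATH
    ckpt_state = tf.train.get_checkpoint_state(CHECKPOINT_PATH)
    model_path = os.path.join(CHECKPOINT_PATH, os.path.basename(ckpt_state.model_checkpoint_path))
    saver.restore(ctpn_sess, model_path)
# Load the test image
img_path = './img_6.jpg'
im = cv2.imread(img_path)
if im is None:
    raise FileNotFoundError(img_path)
# OpenCV loads images as BGR; reverse the channel axis to get RGB, then resize
im = im[:, :, ::-1]
img, (rh, rw) = resize_image(im)
h, w, c = img.shape
im_info = np.array([h, w, c]).reshape([1, 3])
# Run the model to get the box regression values and the text/non-text probabilities
bbox_pred_val, cls_prob_val = ctpn_sess.run([bbox_pred, cls_prob], feed_dict={input_image: [img], input_im_info: im_info})
# Decode them into text proposals; each row is [score, x1, y1, x2, y2]
textsegs_total, _ = proposal_layer(cls_prob_val, bbox_pred_val, im_info)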
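The proposals are in the coordinates of the resized image. The sketch below shows one common resizing rule and how to map a box back to the original image. The rule (scale the short side to 600 pixels, cap the long side at 1200, round both sides up to a multiple of the 16-pixel feature stride) is an assumption based on common open-source CTPN code and has not been checked against CTPN.utils.image.resize_image. It also assumes rh and rw are new size divided by old size. Both helper names are hypothetical.

```python
import math


def target_size(h, w, short_side=600, max_long_side=1200, multiple=16):
    """Hypothetical resize rule. Returns (new_h, new_w, rh, rw), where
    rh and rw are the height and width scale factors."""
    scale = short_side / min(h, w)
    if round(scale * max(h, w)) > max_long_side:
        scale = max_long_side / max(h, w)
    new_h = math.ceil(int(h * scale) / multiple) * multiple
    new_w = math.ceil(int(w * scale) / multiple) * multiple
    return new_h, new_w, new_h / h, new_w / w


def box_to_original(box, rh, rw):
    """Map an (x1, y1, x2, y2) box on the resized image back to the original image."""
    x1, y1, x2, y2 = box
    return x1 / rw, y1 / rh, x2 / rw, y2 / rh


print(target_size(480, 640))  # a 480 x 640 photo becomes 608 x 800
```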

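Each text proposal now has predicted coordinates and a score. The cell below draws every proposal with a score of 0.7 or higher; the boxes cover almost all of the text regions.

# Reload the image and resize it the same way, so the boxes line up
img = Image.open(img_path).convert('RGB')
img, _ = resize_image(np.array(img))
img = np.array(img)
# Draw each proposal scoring at least 0.7 as a red (RGB) rectangle; OpenCV needs integer coordinates
for i in textsegs_total:
    if i[0] >= 0.7:
        cv2.rectangle(img, (int(i[1]), int(i[2])), (int(i[3]), int(i[4])), (255, 0, 0), 1)
img = Image.fromarray(img)
img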
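The filtering step above can be isolated into a small function that works on any sequence of [score, x1, y1, x2, y2] rows, including a NumPy array such as textsegs_total. It is a sketch; the name proposals_to_rects is hypothetical, and it truncates coordinates with int() as the drawing code does.

```python
def proposals_to_rects(textsegs, min_score=0.7):
    """Keep rows [score, x1, y1, x2, y2] with score >= min_score and return
    integer corner pairs ((x1, y1), (x2, y2)), the form cv2.rectangle expects."""
    rects = []
    for score, x1, y1, x2, y2 in textsegs:
        if score >= min_score:
            rects.append(((int(x1), int(y1)), (int(x2), int(y2))))
    return rects


rows = [[0.9, 10.6, 5.2, 25.9, 40.0], [0.5, 0, 0, 16, 16], [0.7, 1, 2, 3, 4]]
print(proposals_to_rects(rows))
```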

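Next, refine the text locations: merge the small proposals found above into text lines and draw the merged boxes on the image. The merging strategy is configured by five parameters.

# Reopen and resize the image for drawing
from CTPN.utils.text_connector.text_connect_cfg import Config as TextLineCfg
img = Image.open(img_path).convert('RGB')
img, _ = resize_image(np.array(img))
img = np.array(img)
scores = textsegs_total[:, 0]
textsegs = textsegs_total[:, 1:5]

# Text-line merging strategy:
TextLineCfg.MAX_HORIZONTAL_GAP = 50          # Two boxes count as neighbours only if the gap between them is below 50 pixels; the smaller this value, the stricter the requirement for merging
TextLineCfg.TEXT_PROPOSALS_MIN_SCORE = 0.7   # Minimum confidence of a single proposal; only proposals above it are merged. The larger this value, the more proposals are discarded
TextLineCfg.TEXT_PROPOSALS_NMS_THRESH = 0.2  # NMS threshold: a proposal overlapping a higher-scoring one by more than this is suppressed, so the smaller this value, the more proposals are discarded
TextLineCfg.MIN_V_OVERLAPS = 0.7             # Two boxes count as neighbours only if their vertical overlap exceeds 0.7; the larger this value, the less likely two vertically offset boxes are merged
textdetector = TextDetector(DETECT_MODE='H') # DETECT_MODE is 'H' or 'O': 'H' suits horizontal text, 'O' suits slightly tilted text

# Merge the proposals into text lines; the first 8 values of each box are the x, y coordinates of its 4 corners
boxes = textdetector.detect(textsegs, scores[:, np.newaxis], img.shape[:2])
boxes = np.array(boxes, dtype=np.int32)
# Draw each text region as a closed polygon
for box in boxes:
    cv2.polylines(img, [box[:8].reshape((-1, 1, 2))], True, color=(255, 0, 0), thickness=2)
img = Image.fromarray(img)
img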
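The two neighbour conditions described by MAX_HORIZONTAL_GAP and MIN_V_OVERLAPS can be illustrated with a much simplified chaining routine. This is not TextDetector's algorithm: the real implementation builds a graph over the proposals and applies further checks before fitting each text line. Here we assume boxes are (x1, y1, x2, y2), measure the horizontal gap between left edges, and divide the vertical overlap by the smaller height; all names are hypothetical.

```python
def vertical_overlap(a, b):
    """Overlap of the y-ranges of two (x1, y1, x2, y2) boxes, divided by the smaller height."""
    overlap = max(0.0, min(a[3], b[3]) - max(a[1], b[1]))
    height = min(a[3] - a[1], b[3] - b[1])
    return overlap / height if height > 0 else 0.0


def are_neighbours(a, b, max_gap=50, min_v_overlap=0.7):
    """b starts to the right of a within max_gap pixels and overlaps it enough vertically."""
    gap = b[0] - a[0]
    return 0 < gap <= max_gap and vertical_overlap(a, b) >= min_v_overlap


def chain_proposals(boxes, max_gap=50, min_v_overlap=0.7):
    """Greedily chain left-to-right neighbours; return one bounding box per chain."""
    lines = []
    for box in sorted(boxes, key=lambda b: b[0]):
        for line in lines:
            if are_neighbours(line[-1], box, max_gap, min_v_overlap):
                line.append(box)
                break
        else:
            lines.append([box])
    return [(min(b[0] for b in line), min(b[1] for b in line),
             max(b[2] for b in line), max(b[3] for b in line)) for line in lines]


props = [(0, 10, 16, 30), (16, 10, 32, 30), (32, 10, 48, 30), (48, 100, 64, 120), (200, 10, 216, 30)]
print(chain_proposals(props))
```

The first three 16-pixel proposals merge into one line; the fourth is vertically offset and the fifth is too far to the right, so each forms its own line.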

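[Copyright notice] This article is original content by a Huawei Cloud community user. When reposting, you must credit the source (Huawei Cloud community), the article link, the author and other basic information; otherwise the author and the community reserve the right to pursue liability. If you find suspected plagiarism in this community, please report it by email with supporting evidence; once verified, the community will immediately remove the infringing content. Report email: cloudbbs@huaweicloud.com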