手把手教你打比赛-螺母缺陷检测

举报
李长安 发表于 2023/02/24 16:23:56 2023/02/24
【摘要】 本次教程将带领大家完整的走一遍比赛流程。经过前面一系列教程,相信大家对飞桨已经熟悉了,但是我们不能一直纸上谈兵,所以本次就带领大家完成一个实际的比赛项目。

手把手教你打比赛-螺母缺陷检测

本次教程将带领大家完整的走一遍比赛流程。经过前面一系列教程,相信大家对飞桨已经熟悉了,但是我们不能一直纸上谈兵,所以本次就带领大家完成一个实际的比赛项目。

1 基础知识

本次比赛采用的baseline模型为mobilnetv2,这是一个轻量级模型。来自论文MobileNetV2: Inverted Residuals and Linear Bottlenecks

MobileNetV2是在V1基础之上的改进。V1主要思想就是深度可分离卷积。V2的创新点主要包括Linear Bottleneck 和 Inverted Residuals两个部分。

这里仅对该网络做简要介绍,如果想要深入研究请自行查阅资料,或者在评论区评论、交流。

2 数据分析

这个比赛数据量较少,共分为两个类别:neg、pos,分别为有缺陷的图像和正常图像,其数量都为200张。图像示例如下所示:

neg pos

3 解题思路

问题

(1)数据量较少,使用大网络会非常容易造成过拟合问题。

(2)正类和负类较为相似,可以尝试细粒度识别。

解决

(1)使用数据集扩增,包括旋转、水平翻转、随机剪裁等。或者使用GAN的方式进行数据增强。

(2)使用轻量级网络。

4 实战演练

4.1 解压数据集

!cd 'data/data64280' && unzip -q trainset.zip

4.2 数据集信息生成

运行data_pre.py文件,生成包含数据信息的txt文件。代码如下所示:

import os


all_file_dir = '/home/aistudio/data/data64280/trainset'

f = open('/home/aistudio/data/data64280/train.txt', 'w')

label_id = 0

class_list = [c for c in os.listdir(all_file_dir) if os.path.isdir(os.path.join(all_file_dir, c))]

# print(class_list)
for class_dir in class_list:


    image_path_pre = os.path.join(all_file_dir, class_dir)

    for img in os.listdir(image_path_pre):
        # print(img)
        f.write("{0}\t{1}\n".format(os.path.join(image_path_pre, img), label_id))

    label_id += 1


    label_id += 1

4.3 开启训练

运行train.py文件开启训练。

!cd 'work/' && python train.py

训练结果

5 代码详解

5.1 自定义数据集

此部分根据主要参考了官网上数据集的实现以及【动手学Paddle2.0系列】手把手教你自定义数据集。代码中已经添加了详细注释,如有任何疑问,欢迎在评论区交流。

class MyDataset(Dataset):
    """
    步骤一:继承paddle.io.Dataset类
    """
    def __init__(self, txt, transform=None):
        """
        步骤二:实现构造函数,定义数据读取方式,划分训练和测试数据集
        """
        super(MyDataset, self).__init__()
        imgs = []
        f = open(txt, 'r')
        for line in f:
            line = line.strip('\n')
            line = line.rstrip('\n')
            words = line.split()
            imgs.append((words[0], int(words[1])))
            self.imgs = imgs
            self.transform = transform
            # self.loader = loader
    def __getitem__(self, index):  # 这个方法是必须要有的,用于按照索引读取每个元素的具体内容
        fn, label = self.imgs[index]
        # fn是图片path #fn和label分别获得imgs[index]也即是刚才每行中word[0]和word[1]的信息
        img = Image.open(fn)
        img = img.convert("RGB")

        img =  np.array(img).astype('float32')
        # 归一化
        img *= 0.007843 
        label = np.array([label]).astype(dtype='int64')
        # 按照路径读取图片
        if self.transform is not None:
            img = self.transform(img)
            # 数据标签转换为Tensor
        return img, label
        # return回哪些内容,那么我们在训练时循环读取每个batch时,就能获得哪些内容
        # **********************************  #使用__len__()初始化一些需要传入的参数及数据集的调用**********************

    def __len__(self):
        # 这个函数也必须要写,它返回的是数据集的长度,也就是多少张图片,要和loader的长度作区分
        return len(self.imgs)

5.2 数据读取及预处理

对数据进行了简单的预处理,使用了paddle2.0中封装的高层API。相关API的详细介绍传送门

需要注意的是,在AIStudio平台中, paddle.io.DataLoader中的 num_workers=0参数如果不是默认的0,设置为8等其他数,则会报错。

transform = T.Compose([
    T.RandomResizedCrop([448, 448]),
    T.RandomHorizontalFlip(),
    T.RandomRotation(90),
    T.ToTensor(),
    T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),

])

train_dataset = MyDataset(txt='/home/aistudio/data/data64280/train.txt', transform=transform)

train_loader = paddle.io.DataLoader(train_dataset, places=paddle.CPUPlace(), batch_size=8, shuffle=True)

5.3 模型训练

本部分直接使用了paddle2.0中封装好的高层API进行迭代训练,相关API详细讲解,传送门

# build model
model = mobilenet_v2(pretrained=True,scale=1.0, num_classes=2, with_pool=True)

# 调用飞桨框架的VisualDL模块,保存信息到目录中。
# callback = paddle.callbacks.VisualDL(log_dir='visualdl_log_dir')

model = paddle.Model(model)
optim = paddle.optimizer.Adam(learning_rate=0.001, parameters=model.parameters())

# 配置模型
model.prepare(
    optim,
    paddle.nn.CrossEntropyLoss(),
    Accuracy(topk=(1, 2))
    )

model.fit(train_loader,
        epochs=2,
        verbose=1,
        )

总结

训练好的模型在data/data64280/model_路径下,将该文件打包下载。替换提交样例中的文件,打包提交即可。

提交结果示例

比赛地址

提升

(1)大家可以调整迭代次数,增加迭代次数。

(2)对数据集进行扩增。

(3)使用目标检测的方法。

大家加油!

【版权声明】本文为华为云社区用户原创内容,转载时必须标注文章的来源(华为云社区)、文章链接、文章作者等基本信息, 否则作者和本社区有权追究责任。如果您发现本社区中有涉嫌抄袭的内容,欢迎发送邮件进行举报,并提供相关证据,一经查实,本社区将立刻删除涉嫌侵权内容,举报邮箱: cloudbbs@huaweicloud.com
  • 点赞
  • 收藏
  • 关注作者

评论(0

0/1000
抱歉,系统识别当前为高风险访问,暂不支持该操作

全部回复

上滑加载中

设置昵称

在此一键设置昵称,即可参与社区互动!

*长度不超过10个汉字或20个英文字符,设置后3个月内不可修改。

*长度不超过10个汉字或20个英文字符,设置后3个月内不可修改。