AI实战 | Tensorflow自定义数据集和迁移学习(附代码下载)

举报
小小谢先生 发表于 2022/04/14 00:07:26 2022/04/14
【摘要】 自定义数据集 做深度学习项目时,我们一般都不用网上公开的数据集,而是用自己制作的数据集。那么,怎么用Tensorflow2.0来制作自己的数据集并把数据喂给神经网络呢?且看这篇文章慢慢道来。 Pokemon Datasets 这篇文章我们用的datasets是Pokemon datasets,也就是皮卡丘电影中的一些角色,如下图所...

自定义数据集

做深度学习项目时,我们一般都不用网上公开的数据集,而是用自己制作的数据集。那么,怎么用Tensorflow2.0来制作自己的数据集并把数据喂给神经网络呢?且看这篇文章慢慢道来。

Pokemon Datasets

这篇文章我们用的datasets是Pokemon datasets,也就是皮卡丘电影中的一些角色,如下图所示:

数据集

数据集

数据集下载

链接: https://pan.baidu.com/s/1V_ZJ7ufjUUFZwD2NHSNMFw

提取码:dsxl

数据集划分

划分

划分

由上图可知,60%的数据集用来train,20%的数据集用来validation,同样20%用来test

四个步骤

  • Load data:加载数据

  • Build model:建立模型

  • Train-Val-Test:训练和测试

  • Transfer Learning:迁移模型

加载数据

首先对数据进行预处理,把像素值的Numpy类型转换为Tensor类型,并归一化到[0~1]。把数据集的标签做one-hot编码。


  
  1. def preprocess(x,y):
  2. # x: 图片的路径,y:图片的数字编码
  3. x = tf.io.read_file(x)
  4. x = tf.image.decode_jpeg(x, channels=3) # RGBA
  5. x = tf.image.resize(x, [244, 244])
  6. return x, y

数据集标准处理流程

代码中load_pokemon用的是自己的数据集写的代码,具体可阅读pokemon.py文件。


  
  1. # 创建训练集Datset对象
  2. images, labels, table = load_pokemon('pokemon',mode='train')
  3. db_train = tf.data.Dataset.from_tensor_slices((images, labels))
  4. db_train = db_train.shuffle(1000).map(preprocess).batch(batchsz)
  5. # 创建验证集Datset对象
  6. images2, labels2, table = load_pokemon('pokemon',mode='val')
  7. db_val = tf.data.Dataset.from_tensor_slices((images2, labels2))
  8. db_val = db_val.map(preprocess).batch(batchsz)
  9. # 创建测试集Datset对象
  10. images3, labels3, table = load_pokemon('pokemon',mode='test')
  11. db_test = tf.data.Dataset.from_tensor_slices((images3, labels3))
  12. db_test = db_test.map(preprocess).batch(batchsz)

图片数据增强及标准化

一般数据集较少的话需要使用数据增强以增加数据集,防止训练网络过拟合。比如旋转角度、裁剪等,并归一化到[0~1]。把数据集的标签做one-hot编码。所示代码如下:


  
  1. # x = tf.image.random_flip_left_right(x)
  2. x = tf.image.random_flip_up_down(x)
  3. x = tf.image.random_crop(x, [224,224,3])
  4. # x: [0,255]=> -1~1
  5. x = tf.cast(x, dtype=tf.float32) / 255.
  6. x = normalize(x)
  7. y = tf.convert_to_tensor(y)
  8. y = tf.one_hot(y, depth=5)

建立网络

神经网络从零开始训练,backbone用李沐大神的resnet网络。详细代码请查看resnet.py文件。部分代码如下:


  
  1. class ResNet(keras.Model):
  2. def __init__(self, num_classes, initial_filters=16, **kwargs):
  3. super(ResNet, self).__init__(**kwargs)
  4. self.stem = layers.Conv2D(initial_filters, 3, strides=3, padding='valid')
  5. self.blocks = keras.models.Sequential([
  6. ResnetBlock(initial_filters * 2, strides=3),
  7. ResnetBlock(initial_filters * 2, strides=1),
  8. # layers.Dropout(rate=0.5),
  9. ResnetBlock(initial_filters * 4, strides=3),
  10. ResnetBlock(initial_filters * 4, strides=1),
  11. ResnetBlock(initial_filters * 8, strides=2),
  12. ResnetBlock(initial_filters * 8, strides=1),
  13. ResnetBlock(initial_filters * 16, strides=2),
  14. ResnetBlock(initial_filters * 16, strides=1),
  15. ])
  16. self.final_bn = layers.BatchNormalization()
  17. self.avg_pool = layers.GlobalMaxPool2D()
  18. self.fc = layers.Dense(num_classes)
  19. def call(self, inputs, training=None):
  20. # print('x:',inputs.shape)
  21. out = self.stem(inputs,training=training)
  22. out = tf.nn.relu(out)
  23. # print('stem:',out.shape)
  24. out = self.blocks(out, training=training)
  25. # print('res:',out.shape)
  26. out = self.final_bn(out, training=training)
  27. # out = tf.nn.relu(out)
  28. out = self.avg_pool(out)
  29. # print('avg_pool:',out.shape)
  30. out = self.fc(out)
  31. # print('out:',out.shape)
  32. return out

训练和测试

部分代码如下:


  
  1. resnet = keras.Sequential([
  2. layers.Conv2D(16,5,3),
  3. layers.MaxPool2D(3,3),
  4. layers.ReLU(),
  5. layers.Conv2D(64,5,3),
  6. layers.MaxPool2D(2,2),
  7. layers.ReLU(),
  8. layers.Flatten(),
  9. layers.Dense(64),
  10. layers.ReLU(),
  11. layers.Dense(5)
  12. ])
  13. resnet = ResNet(5)
  14. resnet.build(input_shape=(4, 224, 224, 3))
  15. resnet.summary()
  16. early_stopping = EarlyStopping(
  17. monitor='val_accuracy',
  18. min_delta=0.001,
  19. patience=5
  20. )
  21. resnet.compile(optimizer=optimizers.Adam(lr=1e-3),
  22. loss=losses.CategoricalCrossentropy(from_logits=True),
  23. metrics=['accuracy'])
  24. resnet.fit(db_train, validation_data=db_val, validation_freq=1, epochs=100,
  25. callbacks=[early_stopping])
  26. resnet.evaluate(db_test)

迁移网络学习

网络可以丛零开始训练,也可以从别的训练好的参数模型迁移过来,本次实战用Tensorflow预训练的vgg19模型来加载训练,从而加快训练过程。

迁移学习的原理如下图所示:

部分代码如下:


  
  1. net = keras.applications.VGG19(weights='imagenet', include_top=False,
  2. pooling='max')
  3. net.trainable = False
  4. newnet = keras.Sequential([
  5. net,
  6. layers.Dense(5)
  7. ])
  8. newnet.build(input_shape=(4,224,224,3))
  9. newnet.summary()
  10. early_stopping = EarlyStopping(
  11. monitor='val_accuracy',
  12. min_delta=0.001,
  13. patience=5
  14. )
  15. newnet.compile(optimizer=optimizers.Adam(lr=1e-3),
  16. loss=losses.CategoricalCrossentropy(from_logits=True),
  17. metrics=['accuracy'])
  18. newnet.fit(db_train, validation_data=db_val, validation_freq=1, epochs=100,
  19. callbacks=[early_stopping])
  20. newnet.evaluate(db_test)

代码下载

本篇文章完整代码在公众号”计算机视觉cv“对话框回复 “pokemon” 就可得到百度云链接,建议直接复制再去公众号回复。

参考资料

本篇文章主要参考网易云课堂龙龙老师的《深度学习与TensorFlow 2入门实战》

课程链接:https://study.163.com/course/courseMain.htm?courseId=1209092816&share=1&shareId=1026182418

文章来源: blog.csdn.net,作者:小小谢先生,版权归原作者所有,如需转载,请联系作者。

原文链接:blog.csdn.net/xiewenrui1996/article/details/106368280

【版权声明】本文为华为云社区用户转载文章,如果您发现本社区中有涉嫌抄袭的内容,欢迎发送邮件进行举报,并提供相关证据,一经查实,本社区将立刻删除涉嫌侵权内容,举报邮箱: cloudbbs@huaweicloud.com
  • 点赞
  • 收藏
  • 关注作者

评论(0

0/1000
抱歉,系统识别当前为高风险访问,暂不支持该操作

全部回复

上滑加载中

设置昵称

在此一键设置昵称,即可参与社区互动!

*长度不超过10个汉字或20个英文字符,设置后3个月内不可修改。

*长度不超过10个汉字或20个英文字符,设置后3个月内不可修改。