【MindSpore第七期两日集群营】体验MindSpore数据增强API统一
为什么要做数据增强?
Cifar数据集有60000张图片,其中50000张训练图片,10000张测试图片,如果用这样的训练集来训练参数较多的网络时,就容易出现过拟合的现象。所以,我们要想办法对原有的训练集做数据增强,目的是提高训练图片的丰富度,减少过拟合,提高网络的泛化性。
这次我们用本地的WSL环境做实验。
打开wsl ubuntu 18.04。
下载cifar10数据集:
mkdir ~/dataset
cd ~/dataset
wget http://www.cs.toronto.edu/~kriz/cifar-10-binary.tar.gz
解压:
tar -zxvf cifar-10-binary.tar.gz
先看一下数据集的目录结构,需要安装tree软件:
sudo apt install tree
查看下具体结构:
tree cifar-10-batches-bin/
python
import mindspore
import mindspore.dataset as ds
好像没问题,数据集相关包也导入正常。
编辑以下文件:
vi load_cifar.py
import cv2
import mindspore.dataset as ds
import mindspore.dataset.vision as vision
# Pad到(36, 36),随机Crop到(32, 32)
random_crop = vision.RandomCrop((32, 32), (4, 4, 4, 4))
# Resize到(128, 128)
resize = vision.Resize((512, 512))
# 随机翻转
random_horizontal = vision.RandomHorizontalFlip()
# 随机调整颜色
random_color = vision.RandomColorAdjust(brightness=(0.8, 1), contrast=(0.8, 1), saturation=(0.3, 1))
# 定义数据流水线
dataset_dir = "cifar-10-batches-bin"
dataloader = ds.Cifar10Dataset(dataset_dir)
dataloader = dataloader.map([random_crop, resize, random_horizontal, random_color], input_columns="image")
# 启动数据预处理
count = 0
for data in dataloader.create_dict_iterator(output_numpy=True):
image = data["image"]
cv2.imshow("win", image)
cv2.waitKey()
count += 1
if count > 5:
break
执行下看看:
python load_cifar.py
报错:没有装opencv。
那就装一下:
pip install opencv-python -i https://pypi.tuna.tsinghua.edu.cn/simple
再重新执行:
执行:
好像显示了一些错,但是python脚本中前6张图片也显示出来了。从Python的代码来看,它对原图片进行了随机翻转,随机变色,随机Crop。。。而只是用了一句dataloader的map方法就搞定了。。。
是不是有点太方便了。。。真的不需要人工干预一下吗?
(全文完,谢谢阅读)
- 点赞
- 收藏
- 关注作者
评论(0)