深度学习修炼(二)——数据集的加载
致谢
2 数据集的加载
与sklearn中的datasets自带数据集类似,pytorch框架也为我们提供了数据集以便一系列的模型测试。其数据集作为一个类继承自父类torch.utils.data.Dataset。
2.1 框架数据集的加载
让我们看看torch为我们提供了什么数据集。数据集种类如下所示:
-
手写字符识别:EMNIST、MNIST、QMNIST、USPS、SVHN、KMNIST、Omniglot
-
实物分类:Fashion MNIST、CIFAR、LSUN、SLT-10、ImageNet
-
人脸识别:CelebA
-
场景分类:LSUN、Places365
-
用于object detection:SVHN、VOCDetection、COCODetection
-
用于semantic/instance segmentation:
-
语义分割:Cityscapes、VOCSegmentation
-
语义边界:SBD
-
用于image captioning:Flickr、COCOCaption
-
用于video classification:HMDB51、Kinetics
-
用于3D reconstruction:PhotoTour
-
用于shadow detectors:SBU
以FashionMNIST数据集为例,我们看一下如何加载数据集。
torch.datasets.FashionMNIST(root = “data”,train = True,download = True,transform = ToTensor())
root
是存储训练/测试数据的路径train
指定训练或测试数据集,当布尔值为True则为训练集,当布尔值为False则为测试集download=True
从互联网下载数据(如果无法在本地获得)transform
指定特征转换方式,target_transform
指定标签转换方式
import torch
from torch.utils.data import Dataset
from torchvision import datasets
from torchvision.transforms import ToTensor
def load_data():
"""加载数据集"""
# 1 训练数据集的加载
train_data = datasets.FashionMNIST(
root="data",
train=True,
download=True,
transform=ToTensor()
)
# 2 测试数据集的加载
test_data = datasets.FashionMNIST(
root="data",
train=False,
download=True,
transform=ToTensor()
)
return train_data, test_data
train_data, test_data = load_data()
print(train_data)
- 1
- 2
- 3
- 4
- 5
- 6
- 7
- 8
- 9
- 10
- 11
- 12
- 13
- 14
- 15
- 16
- 17
- 18
- 19
- 20
- 21
- 22
- 23
- 24
- 25
- 26
- 27
- 28
- 29
数据集加载完实际上是以类的形式存在的,其不同于sklearn中返回的Bunch。
如果我们想要看看数据集中有啥要怎么做呢?首先,这个数据集是图像分类数据集,说明里面含有的都是图像,为此,我们可以使用subplots存放这些图片。对于这些数据集,我们可以像列表一样手动索引。如train_data[index]
。
import torch
from torch.utils.data import Dataset
from torchvision import datasets
from torchvision.transforms import ToTensor
import matplotlib.pyplot as plt
def load_data():
"""加载数据集"""
# 1 训练数据集的加载
train_data = datasets.FashionMNIST(
root="data",
train=True,
download=True,
transform=ToTensor()
)
# 2 测试数据集的加载
test_data = datasets.FashionMNIST(
root="data",
train=False,
download=True,
transform=ToTensor()
)
return train_data, test_data
def show_data(train_data):
"""数据集可视化"""
label_map = {
0: "T_Shirt",
1: "Trouser",
2: "Pullover",
3: "Dress",
4: "Coat",
5: "Sandal",
6: "Shirt",
7: "Sneaker",
8: "Bag",
9: "Ankle Boot",
}
figure = plt.figure(figsize=(8, 8))
cols, rows = 3, 3
# 从训练集中随机抽出九张图(九个样本)
for i in range(1, cols * rows + 1):
# 设置索引,索引取值为0到训练集的长度
sample_idx = torch.randint(len(train_data), size=(1,)).item()
# 取出对应样本的图片和标签
img, label = train_data[sample_idx]
# 依次画于事先指定的九宫格图上
figure.add_subplot(rows, cols, i)
# 设置对应图片的标题
plt.title(label_map[label])
# 关掉坐标轴
plt.axis("off")
# 展示图片
plt.imshow(img.squeeze(), cmap="gray")
# 释放画布
plt.show()
train_data, test_data = load_data()
show_data(train_data)
- 1
- 2
- 3
- 4
- 5
- 6
- 7
- 8
- 9
- 10
- 11
- 12
- 13
- 14
- 15
- 16
- 17
- 18
- 19
- 20
- 21
- 22
- 23
- 24
- 25
- 26
- 27
- 28
- 29
- 30
- 31
- 32
- 33
- 34
- 35
- 36
- 37
- 38
- 39
- 40
- 41
- 42
- 43
- 44
- 45
- 46
- 47
- 48
- 49
- 50
- 51
- 52
- 53
- 54
- 55
- 56
- 57
- 58
- 59
- 60
- 61
- 62
- 63
- 64
- 65
out:
上面用到了一个API,即torch.randint()
torch.
randint
(low=0, high, size, generator=None, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False) → Tensor
- 用于取随机整数,返回值为张量
- low:int类型,表明要从分布中提取的最低整数
- high:int类型,表明要从分布中提取的最高整数1
- size:元组类型,表明输出张量的形状
- dtype:返回值张量的数据类型
- device:返回张量所需的设备
- requires_grad:布尔类型,表明是否要对返回的张量自动求导。
如:
torch.randint(3, 5, (3,)) tensor([4, 3, 4])
- 1
- 2
意味生成一个一维的3元素向量,其中向量中的元素取值从3-5取。
2.2 自定义数据集
如果你不想使用框架自带的数据集,那么你可以自己定义一个数据集类。自定义Dataset类必须实现三个函数:__ init __ 、 __ len __ 、__ getitem __。其中图像部分存储于一个文件夹中,标签单独存储在CSV文件中。
在接下来的代码中,让我们看看如何创建一个自定义数据集。
import os
import pandas as pd
from torchvision.io import read_image
class CustomImageDataset(Dataset):
def __init__(self, annotations_file, img_dir, transform=None, target_transform=None):
self.img_labels = pd.read_csv(annotations_file)
self.img_dir = img_dir
self.transform = transform
self.target_transform = target_transform
def __len__(self):
return len(self.img_labels)
def __getitem__(self, idx):
img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0])
image = read_image(img_path)
label = self.img_labels.iloc[idx, 1]
if self.transform:
image = self.transform(image)
if self.target_transform:
label = self.target_transform(label)
return image, label
- 1
- 2
- 3
- 4
- 5
- 6
- 7
- 8
- 9
- 10
- 11
- 12
- 13
- 14
- 15
- 16
- 17
- 18
- 19
- 20
- 21
- 22
- 23
对于__ init __ 函数来说,包含加载图像、注释文件和两个转换的目录,在这里我们不做过多讲解,后面会详细介绍。
def __init__(self, annotations_file, img_dir, transform=None, target_transform=None):
self.img_labels = pd.read_csv(annotations_file)
self.img_dir = img_dir
self.transform = transform
self.target_transform = target_transform
- 1
- 2
- 3
- 4
- 5
对于__ len __ 函数,其功能是返回数据集中的样本数。
def __len__(self):
return len(self.img_labels)
- 1
- 2
对于 __ getitem __,其功能是给定索引便能返回对应样本。
def __getitem__(self, idx):
img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0])
image = read_image(img_path)
label = self.img_labels.iloc[idx, 1]
if self.transform:
image = self.transform(image)
if self.target_transform:
label = self.target_transform(label)
return image, label
- 1
- 2
- 3
- 4
- 5
- 6
- 7
- 8
- 9
在自定义这一部分不用过多的去了解,用着用着就会了,就算不会代码也是通用,需要用的时候看一下复制一下,别搞得自己这么焦虑。
2.3 准备数据以进行数据加载器训练
在pytorch中,数据加载的核心实际上是torch.utils.data.DataLoader类,它支持对torch数据集的python可迭代,换而言之,DataLoader相当于你拿一个水盆,而dataset相当于泉水。DataLoader可以对小批量数据集进行处理,处理内容包括:
- 地图样式和可迭代样式的数据集
- 自定义数据集加载顺序
- 多进程加载数据
- 自动内存固定
其中地图样式数据集是指自定义数据集,而可迭代样式数据集指的是自带数据集。其他详情对于初学者来说很不友好,这里不做过多解释,你可以理解为这就是个科普知识。
我们来看一下这个API吧。
torch.utils.data.DataLoader(数据集, batch_size=1, shuffle=False)
- 用于加载样本并且进行批处理
- 数据集:要加载的数据集
- batch_size:整数类型,表明每批要加载的样本数,默认为1
- shuffle:布尔类型,表明是否要洗牌
我们利用上面的API来加载我们上面的Fashion_MNIST吧。
def load_batch_data():
"""数据集批处理加载器"""
train_dataloader = DataLoader(train_data, batch_size=64, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=64, shuffle=True)
return train_dataloader, test_dataloader
- 1
- 2
- 3
- 4
- 5
既然已经将样本导入加载器,那么我们如何从加载器中读取数据呢?我们可以根据需要循环访问数据集。
import torch
from torch.utils.data import Dataset
from torchvision import datasets
from torchvision.transforms import ToTensor
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
def load_data():
"""加载数据集"""
# 1 训练数据集的加载
train_data = datasets.FashionMNIST(
root="data",
train=True,
download=True,
transform=ToTensor()
)
# 2 测试数据集的加载
test_data = datasets.FashionMNIST(
root="data",
train=False,
download=True,
transform=ToTensor()
)
return train_data, test_data
def show_data(train_data):
"""数据集可视化"""
label_map = {
0: "T_Shirt",
1: "Trouser",
2: "Pullover",
3: "Dress",
4: "Coat",
5: "Sandal",
6: "Shirt",
7: "Sneaker",
8: "Bag",
9: "Ankle Boot",
}
figure = plt.figure(figsize=(8, 8))
cols, rows = 3, 3
# 从训练集中随机抽出九张图(九个样本)
for i in range(1, cols * rows + 1):
# 设置索引,索引取值为0到训练集的长度
sample_idx = torch.randint(len(train_data), size=(1,)).item()
# 取出对应样本的图片和标签
img, label = train_data[sample_idx]
# 依次画于事先指定的九宫格图上
figure.add_subplot(rows, cols, i)
# 设置对应图片的标题
plt.title(label_map[label])
# 关掉坐标轴
plt.axis("off")
# 展示图片
plt.imshow(img.squeeze(), cmap="gray")
# 释放画布
plt.show()
def load_batch_data():
"""数据集批处理加载器"""
train_dataloader = DataLoader(train_data, batch_size=64, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=64, shuffle=True)
return train_dataloader, test_dataloader
def show_batch_data():
"""循环访问数据加载器"""
train_dataloader, test_dataloader = load_batch_data()
train_feature, train_labels = next(iter(train_dataloader))
print(f"特征大小:{train_feature.size()}")
print(f"标签大小:{train_labels.size()}")
img = train_feature[0].squeeze()
label = train_labels[0]
plt.imshow(img, cmap="gray")
plt.show()
print(f"label:{label}")
train_data, test_data = load_data()
# show_data(train_data)
show_batch_data()
- 1
- 2
- 3
- 4
- 5
- 6
- 7
- 8
- 9
- 10
- 11
- 12
- 13
- 14
- 15
- 16
- 17
- 18
- 19
- 20
- 21
- 22
- 23
- 24
- 25
- 26
- 27
- 28
- 29
- 30
- 31
- 32
- 33
- 34
- 35
- 36
- 37
- 38
- 39
- 40
- 41
- 42
- 43
- 44
- 45
- 46
- 47
- 48
- 49
- 50
- 51
- 52
- 53
- 54
- 55
- 56
- 57
- 58
- 59
- 60
- 61
- 62
- 63
- 64
- 65
- 66
- 67
- 68
- 69
- 70
- 71
- 72
- 73
- 74
- 75
- 76
- 77
- 78
- 79
- 80
- 81
- 82
- 83
- 84
- 85
- 86
文章来源: blog.csdn.net,作者:ArimaMisaki,版权归原作者所有,如需转载,请联系作者。
原文链接:blog.csdn.net/chengyuhaomei520/article/details/123496767
- 点赞
- 收藏
- 关注作者
评论(0)