深度学习修炼(二)——数据集的加载

举报
ArimaMisaki 发表于 2022/08/09 01:06:38 2022/08/09
【摘要】 文章目录 致谢 2 数据集的加载2.1 框架数据集的加载2.2 自定义数据集2.3 准备数据以进行数据加载器训练 致谢 Pytorch自带数据集介绍_godblesstao的...

致谢

Pytorch自带数据集介绍_godblesstao的博客-CSDN博客_pytorch自带数据集

2 数据集的加载

与sklearn中的datasets自带数据集类似,pytorch框架也为我们提供了数据集以便一系列的模型测试。其数据集作为一个类继承自父类torch.utils.data.Dataset。

2.1 框架数据集的加载

让我们看看torch为我们提供了什么数据集。数据集种类如下所示:

  • 手写字符识别:EMNIST、MNIST、QMNIST、USPS、SVHN、KMNIST、Omniglot

  • 实物分类:Fashion MNIST、CIFAR、LSUN、SLT-10、ImageNet

  • 人脸识别:CelebA

  • 场景分类:LSUN、Places365

  • 用于object detection:SVHN、VOCDetection、COCODetection

  • 用于semantic/instance segmentation:

  • 语义分割:Cityscapes、VOCSegmentation

  • 语义边界:SBD

  • 用于image captioning:Flickr、COCOCaption

  • 用于video classification:HMDB51、Kinetics

  • 用于3D reconstruction:PhotoTour

  • 用于shadow detectors:SBU

以FashionMNIST数据集为例,我们看一下如何加载数据集。

torch.datasets.FashionMNIST(root = “data”,train = True,download = True,transform = ToTensor())

  • root是存储训练/测试数据的路径
  • train指定训练或测试数据集,当布尔值为True则为训练集,当布尔值为False则为测试集
  • download=True从互联网下载数据(如果无法在本地获得)
  • transform指定特征转换方式,target_transform指定标签转换方式
import torch
from torch.utils.data import Dataset
from torchvision import datasets
from torchvision.transforms import ToTensor


def load_data():
    """加载数据集"""
    # 1 训练数据集的加载
    train_data = datasets.FashionMNIST(
        root="data",
        train=True,
        download=True,
        transform=ToTensor()
    )

    # 2 测试数据集的加载
    test_data = datasets.FashionMNIST(
        root="data",
        train=False,
        download=True,
        transform=ToTensor()
    )

    return train_data, test_data


train_data, test_data = load_data()
print(train_data)

  
 
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29

数据集加载完实际上是以类的形式存在的,其不同于sklearn中返回的Bunch。

如果我们想要看看数据集中有啥要怎么做呢?首先,这个数据集是图像分类数据集,说明里面含有的都是图像,为此,我们可以使用subplots存放这些图片。对于这些数据集,我们可以像列表一样手动索引。如train_data[index]

import torch
from torch.utils.data import Dataset
from torchvision import datasets
from torchvision.transforms import ToTensor
import matplotlib.pyplot as plt


def load_data():
    """加载数据集"""
    # 1 训练数据集的加载
    train_data = datasets.FashionMNIST(
        root="data",
        train=True,
        download=True,
        transform=ToTensor()
    )

    # 2 测试数据集的加载
    test_data = datasets.FashionMNIST(
        root="data",
        train=False,
        download=True,
        transform=ToTensor()
    )

    return train_data, test_data


def show_data(train_data):
    """数据集可视化"""
    label_map = {
        0: "T_Shirt",
        1: "Trouser",
        2: "Pullover",
        3: "Dress",
        4: "Coat",
        5: "Sandal",
        6: "Shirt",
        7: "Sneaker",
        8: "Bag",
        9: "Ankle Boot",
    }
    figure = plt.figure(figsize=(8, 8))
    cols, rows = 3, 3
    # 从训练集中随机抽出九张图(九个样本)
    for i in range(1, cols * rows + 1):
        # 设置索引,索引取值为0到训练集的长度
        sample_idx = torch.randint(len(train_data), size=(1,)).item()
        # 取出对应样本的图片和标签
        img, label = train_data[sample_idx]
        # 依次画于事先指定的九宫格图上
        figure.add_subplot(rows, cols, i)
        # 设置对应图片的标题
        plt.title(label_map[label])
        # 关掉坐标轴
        plt.axis("off")
        # 展示图片
        plt.imshow(img.squeeze(), cmap="gray")
    # 释放画布
    plt.show()


train_data, test_data = load_data()
show_data(train_data)


  
 
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65

out:

image-20220315095159288

上面用到了一个API,即torch.randint()

torch.randint(low=0, high, size, generator=None, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False) → Tensor

  • 用于取随机整数,返回值为张量
  • low:int类型,表明要从分布中提取的最低整数
  • high:int类型,表明要从分布中提取的最高整数1
  • size:元组类型,表明输出张量的形状
  • dtype:返回值张量的数据类型
  • device:返回张量所需的设备
  • requires_grad:布尔类型,表明是否要对返回的张量自动求导。

如:

torch.randint(3, 5, (3,))
tensor([4, 3, 4])

   
  
  • 1
  • 2

意味生成一个一维的3元素向量,其中向量中的元素取值从3-5取。

2.2 自定义数据集

如果你不想使用框架自带的数据集,那么你可以自己定义一个数据集类。自定义Dataset类必须实现三个函数:__ init __ 、 __ len __ 、__ getitem __。其中图像部分存储于一个文件夹中,标签单独存储在CSV文件中。

在接下来的代码中,让我们看看如何创建一个自定义数据集。

import os
import pandas as pd
from torchvision.io import read_image

class CustomImageDataset(Dataset):
    def __init__(self, annotations_file, img_dir, transform=None, target_transform=None):
        self.img_labels = pd.read_csv(annotations_file)
        self.img_dir = img_dir
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        return len(self.img_labels)

    def __getitem__(self, idx):
        img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0])
        image = read_image(img_path)
        label = self.img_labels.iloc[idx, 1]
        if self.transform:
            image = self.transform(image)
        if self.target_transform:
            label = self.target_transform(label)
        return image, label

  
 
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23

对于__ init __ 函数来说,包含加载图像、注释文件和两个转换的目录,在这里我们不做过多讲解,后面会详细介绍。

def __init__(self, annotations_file, img_dir, transform=None, target_transform=None):
    self.img_labels = pd.read_csv(annotations_file)
    self.img_dir = img_dir
    self.transform = transform
    self.target_transform = target_transform

  
 
  • 1
  • 2
  • 3
  • 4
  • 5

对于__ len __ 函数,其功能是返回数据集中的样本数。

def __len__(self):
    return len(self.img_labels)

  
 
  • 1
  • 2

对于 __ getitem __,其功能是给定索引便能返回对应样本。

def __getitem__(self, idx):
    img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0])
    image = read_image(img_path)
    label = self.img_labels.iloc[idx, 1]
    if self.transform:
        image = self.transform(image)
    if self.target_transform:
        label = self.target_transform(label)
    return image, label

  
 
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9

在自定义这一部分不用过多的去了解,用着用着就会了,就算不会代码也是通用,需要用的时候看一下复制一下,别搞得自己这么焦虑。

2.3 准备数据以进行数据加载器训练

在pytorch中,数据加载的核心实际上是torch.utils.data.DataLoader类,它支持对torch数据集的python可迭代,换而言之,DataLoader相当于你拿一个水盆,而dataset相当于泉水。DataLoader可以对小批量数据集进行处理,处理内容包括:

  • 地图样式和可迭代样式的数据集
  • 自定义数据集加载顺序
  • 多进程加载数据
  • 自动内存固定

其中地图样式数据集是指自定义数据集,而可迭代样式数据集指的是自带数据集。其他详情对于初学者来说很不友好,这里不做过多解释,你可以理解为这就是个科普知识。

我们来看一下这个API吧。

torch.utils.data.DataLoader(数据集, batch_size=1, shuffle=False)

  • 用于加载样本并且进行批处理
  • 数据集:要加载的数据集
  • batch_size:整数类型,表明每批要加载的样本数,默认为1
  • shuffle:布尔类型,表明是否要洗牌

我们利用上面的API来加载我们上面的Fashion_MNIST吧。

def load_batch_data():
    """数据集批处理加载器"""
    train_dataloader = DataLoader(train_data, batch_size=64, shuffle=True)
    test_dataloader = DataLoader(test_data, batch_size=64, shuffle=True)
    return train_dataloader, test_dataloader

  
 
  • 1
  • 2
  • 3
  • 4
  • 5

既然已经将样本导入加载器,那么我们如何从加载器中读取数据呢?我们可以根据需要循环访问数据集。

import torch
from torch.utils.data import Dataset
from torchvision import datasets
from torchvision.transforms import ToTensor
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader


def load_data():
    """加载数据集"""
    # 1 训练数据集的加载
    train_data = datasets.FashionMNIST(
        root="data",
        train=True,
        download=True,
        transform=ToTensor()
    )

    # 2 测试数据集的加载
    test_data = datasets.FashionMNIST(
        root="data",
        train=False,
        download=True,
        transform=ToTensor()
    )

    return train_data, test_data


def show_data(train_data):
    """数据集可视化"""
    label_map = {
        0: "T_Shirt",
        1: "Trouser",
        2: "Pullover",
        3: "Dress",
        4: "Coat",
        5: "Sandal",
        6: "Shirt",
        7: "Sneaker",
        8: "Bag",
        9: "Ankle Boot",
    }
    figure = plt.figure(figsize=(8, 8))
    cols, rows = 3, 3
    # 从训练集中随机抽出九张图(九个样本)
    for i in range(1, cols * rows + 1):
        # 设置索引,索引取值为0到训练集的长度
        sample_idx = torch.randint(len(train_data), size=(1,)).item()
        # 取出对应样本的图片和标签
        img, label = train_data[sample_idx]
        # 依次画于事先指定的九宫格图上
        figure.add_subplot(rows, cols, i)
        # 设置对应图片的标题
        plt.title(label_map[label])
        # 关掉坐标轴
        plt.axis("off")
        # 展示图片
        plt.imshow(img.squeeze(), cmap="gray")
    # 释放画布
    plt.show()


def load_batch_data():
    """数据集批处理加载器"""
    train_dataloader = DataLoader(train_data, batch_size=64, shuffle=True)
    test_dataloader = DataLoader(test_data, batch_size=64, shuffle=True)
    return train_dataloader, test_dataloader


def show_batch_data():
    """循环访问数据加载器"""
    train_dataloader, test_dataloader = load_batch_data()
    train_feature, train_labels = next(iter(train_dataloader))
    print(f"特征大小:{train_feature.size()}")
    print(f"标签大小:{train_labels.size()}")
    img = train_feature[0].squeeze()
    label = train_labels[0]
    plt.imshow(img, cmap="gray")
    plt.show()
    print(f"label:{label}")


train_data, test_data = load_data()
# show_data(train_data)
show_batch_data()

  
 
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69
  • 70
  • 71
  • 72
  • 73
  • 74
  • 75
  • 76
  • 77
  • 78
  • 79
  • 80
  • 81
  • 82
  • 83
  • 84
  • 85
  • 86

文章来源: blog.csdn.net,作者:ArimaMisaki,版权归原作者所有,如需转载,请联系作者。

原文链接:blog.csdn.net/chengyuhaomei520/article/details/123496767

【版权声明】本文为华为云社区用户转载文章,如果您发现本社区中有涉嫌抄袭的内容,欢迎发送邮件进行举报,并提供相关证据,一经查实,本社区将立刻删除涉嫌侵权内容,举报邮箱: cloudbbs@huaweicloud.com
  • 点赞
  • 收藏
  • 关注作者

评论(0

0/1000
抱歉,系统识别当前为高风险访问,暂不支持该操作

全部回复

上滑加载中

设置昵称

在此一键设置昵称,即可参与社区互动!

*长度不超过10个汉字或20个英文字符,设置后3个月内不可修改。

*长度不超过10个汉字或20个英文字符,设置后3个月内不可修改。