Pytorch使用篇之DataLoader使用

举报
bzp123 发表于 2020/08/29 09:33:33 2020/08/29
【摘要】 ​DataLoader是Pytorch中用来处理模型输入数据的一个工具类。借助DataLoader,可以方便地对输入数据进行操作。例如,我们可以将数据按照设定的batch大小进行自动划分,不用手动写循环;同时我们可以设定shuffle参数,在每个迭代周期里,对数据进行打乱。

一、简述


    DataLoader是Pytorch中用来处理模型输入数据的一个工具类。借助DataLoader,可以方便地对输入数据进行操作。例如,我们可以将数据按照设定的batch大小进行自动划分,不用手动写循环;同时我们可以设定shuffle参数,在每个迭代周期里,对数据进行打乱。


    它的基本使用流程为:

    1. 首先自定义DataLoader类,需要包含__getitem__、__len__方法;

    2. 将原始数据加载到DataLoader类中,并设置好shuffle等参数;

    2. 再使用一个迭代器来按照设置好的batch大小来迭代输出shuffle之后的数据。


二、代码实战篇

下面给出自定义的一个DataLoader类:

    from torch.utils.data.dataset import Dataset

    from torch.utils.data import DataLoader

    # 自定义一个Dataset

    class MyCustomData(Dataset):

    def __init__(self, x_data, y_data):

           # 注意:输入进来的x_datay_data都是np.array为了稳妥起见,把它们的dtype均转换为float

                x_data = x_data.astype(float)

                y_data = y_data astype(float)

                self.x_data = torch.from_numpy(x_data)

                self.y_data = torch.from_numpy(y_data)

                self.len = self.x_data.shape[0]

     

        def __getitem__(self, index):

            return self.x_data[index], self.y_data[index]

     

        def __len__(self):

            return self.len

 

在实际使用时:

train_data = MyCustomData(X, Y) 

假设Xshape[21740, 6, 5],是一个numpy.array; Yshape[21740, 2, 1] ,是一个numpy.array,即表示共有21740个样本;每个样本的timestep长度为6,每个timestep的输入向量长度为5label是一个2维向量


【版权声明】本文为华为云社区用户原创内容,转载时必须标注文章的来源(华为云社区)、文章链接、文章作者等基本信息, 否则作者和本社区有权追究责任。如果您发现本社区中有涉嫌抄袭的内容,欢迎发送邮件进行举报,并提供相关证据,一经查实,本社区将立刻删除涉嫌侵权内容,举报邮箱: cloudbbs@huaweicloud.com
  • 点赞
  • 收藏
  • 关注作者

评论(0

0/1000
抱歉,系统识别当前为高风险访问,暂不支持该操作

全部回复

上滑加载中

设置昵称

在此一键设置昵称,即可参与社区互动!

*长度不超过10个汉字或20个英文字符,设置后3个月内不可修改。

*长度不超过10个汉字或20个英文字符,设置后3个月内不可修改。