Pytorch使用篇之DataLoader使用
一、简述
DataLoader是Pytorch中用来处理模型输入数据的一个工具类。借助DataLoader,可以方便地对输入数据进行操作。例如,我们可以将数据按照设定的batch大小进行自动划分,不用手动写循环;同时我们可以设定shuffle参数,在每个迭代周期里,对数据进行打乱。
它的基本使用流程为:
1. 首先自定义DataLoader类,需要包含__getitem__、__len__方法;
2. 将原始数据加载到DataLoader类中,并设置好shuffle等参数;
2. 再使用一个迭代器来按照设置好的batch大小来迭代输出shuffle之后的数据。
二、代码实战篇
下面给出自定义的一个DataLoader类:
from torch.utils.data.dataset import Dataset
from torch.utils.data import DataLoader
# 自定义一个Dataset
class MyCustomData(Dataset):
def __init__(self, x_data, y_data):
# 注意:输入进来的x_data和y_data都是np.array,为了稳妥起见,把它们的dtype均转换为float
x_data = x_data.astype(float)
y_data = y_data astype(float)
self.x_data = torch.from_numpy(x_data)
self.y_data = torch.from_numpy(y_data)
self.len = self.x_data.shape[0]
def __getitem__(self, index):
return self.x_data[index], self.y_data[index]
def __len__(self):
return self.len
在实际使用时:
train_data = MyCustomData(X, Y)
假设X的shape:[21740, 6, 5],是一个numpy.array; Y的shape:[21740, 2, 1] ,是一个numpy.array,即表示共有21740个样本;每个样本的timestep长度为6,每个timestep的输入向量长度为5;label是一个2维向量
- 点赞
- 收藏
- 关注作者
评论(0)