关于Pytorch中dataset的迭代问题(这就是为什么我们要使用dataloader的原因之一)
【摘要】 目录
一、问题
二、思考
三、实验
四、解决方法
一、问题
博主在写采样器时将dataset的类对象赋值给data_source,然后准备对data_source取样,总是提示在__getitem__()函数提示越界。
二、思考
换言之,在对dataset对象进行迭代取样时其__len__()方法似乎失效了。。。
三、实验
博主做了如下实验,利用Py...
目录
一、问题
博主在写采样器时将dataset的类对象赋值给data_source,然后准备对data_source取样,总是提示在__getitem__()函数提示越界。
二、思考
换言之,在对dataset对象进行迭代取样时其__len__()方法似乎失效了。。。
三、实验
博主做了如下实验,利用Pytorch的FakeData类进行以上猜想的证实。FakeData定义如下:
-
import torch
-
import torch.utils.data as data
-
from .. import transforms
-
-
-
class FakeData(data.Dataset):
-
"""A fake dataset that returns randomly generated images and returns them as PIL images
-
-
Args:
-
size (int, optional): Size of the dataset. Default: 1000 images
-
image_size(tuple, optional): Size if the returned images. Default: (3, 224, 224)
-
num_classes(int, optional): Number of classes in the datset. Default: 10
-
transform (callable, optional): A function/transform that takes in an PIL image
-
and returns a transformed version. E.g, ``transforms.RandomCrop``
-
target_transform (callable, optional): A function/transform that takes in the
-
target and transforms it.
-
random_offset (int): Offsets the index-based random seed used to
-
generate each image. Default: 0
-
-
"""
-
-
def __init__(self, size=1000, image_size=(3, 224, 224), num_classes=10,
-
transform=None, target_transform=None, random_offset=0):
-
self.size = size
-
self.num_classes = num_classes
-
self.image_size = image_size
-
self.transform = transform
-
self.target_transform = target_transform
-
self.random_offset = random_offset
-
-
def __getitem__(self, index):
-
"""
-
Args:
-
index (int): Index
-
-
Returns:
-
tuple: (image, target) where target is class_index of the target class.
-
"""
-
# create random image that is consistent with the index id
-
rng_state = torch.get_rng_state()
-
torch.manual_seed(index + self.random_offset)
-
img = torch.randn(*self.image_size)
-
target = torch.Tensor(1).random_(0, self.num_classes)[0]
-
torch.set_rng_state(rng_state)
-
-
# convert to PIL Image
-
img = transforms.ToPILImage()(img)
-
if self.transform is not None:
-
img = self.transform(img)
-
if self.target_transform is not None:
-
target = self.target_transform(target)
-
-
return img, target
-
-
def __len__(self):
-
return self.size
-
-
def __repr__(self):
-
fmt_str = 'Dataset ' + self.__class__.__name__ + '\n'
-
fmt_str += ' Number of datapoints: {}\n'.format(self.__len__())
-
tmp = ' Transforms (if any): '
-
fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
-
tmp = ' Target Transforms (if any): '
-
fmt_str += '{0}{1}'.format(tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
-
return fmt_str
代码如下:
-
from torch.utils.data import DataLoader
-
from torchvision import transforms
-
import torchvision
-
fake_dataset = torchvision.datasets.FakeData(
-
size=100,
-
image_size=(3, 256, 128),
-
num_classes=751,
-
transform=transforms.Compose([
-
transforms.Resize((384, 128),interpolation=3),
-
transforms.ToTensor(),
-
])
-
)
-
for i, batches in enumerate(fake_dataset):
-
print(i)
-
fake_loader = DataLoader(
-
fake_dataset,
-
batch_size=4,
-
shuffle=True
-
)
-
# for i, batches in enumerate(fake_loader):
-
# print(i)
结果是无限输入更新的i。。。
-
0
-
1
-
2
-
3
-
4
-
5
-
6
-
7
-
8
-
9
-
...
-
100
-
101
-
102
-
....
那么怎么办呢?
加上一个DataLoader作为dataset的迭代器就好了:
-
from torch.utils.data import DataLoader
-
from torchvision import transforms
-
import torchvision
-
fake_dataset = torchvision.datasets.FakeData(
-
size=100,
-
image_size=(3, 256, 128),
-
num_classes=751,
-
transform=transforms.Compose([
-
transforms.Resize((384, 128),interpolation=3),
-
transforms.ToTensor(),
-
])
-
)
-
# for i, batches in enumerate(fake_dataset):
-
# print(i)
-
fake_loader = DataLoader(
-
fake_dataset,
-
batch_size=4,
-
shuffle=True
-
)
-
for i, batches in enumerate(fake_loader):
-
print(i)
但是,这并不是解决问题的方法,因为我们需要先写采样器,再去构建Dataloader,本末不可倒置。
四、解决方法
可在dataset里新建一个变量items(例如一个dict),该变量的长度就是dataset中__len__()函数的返回值:
-
def __len__(self):
-
return len(self.items)
最后利用for i, item in self.data_source.items.items():访问新建的items变量,并获取每个变量的元素即可,这样一来__getitem__()函数也不用在采样器初始化时执行:
-
for i, item in self.data_source.items.items():
-
print(i)
是不是很方便呢~
五、总结
相信还有更多的方法解决这个问题。总之,灵活地使用Pytorch中各种元件,对编程实现实验细节有很大的帮助~
文章来源: nickhuang1996.blog.csdn.net,作者:悲恋花丶无心之人,版权归原作者所有,如需转载,请联系作者。
原文链接:nickhuang1996.blog.csdn.net/article/details/100103875
【版权声明】本文为华为云社区用户转载文章,如果您发现本社区中有涉嫌抄袭的内容,欢迎发送邮件进行举报,并提供相关证据,一经查实,本社区将立刻删除涉嫌侵权内容,举报邮箱:
cloudbbs@huaweicloud.com
- 点赞
- 收藏
- 关注作者
评论(0)