关于Pytorch中dataset的迭代问题(这就是为什么我们要使用dataloader的原因之一)

举报
悲恋花丶无心之人 发表于 2021/02/03 02:36:24 2021/02/03
【摘要】 目录 一、问题 二、思考 三、实验 四、解决方法 一、问题 博主在写采样器时将dataset的类对象赋值给data_source,然后准备对data_source取样,总是提示在__getitem__()函数提示越界。 二、思考 换言之,在对dataset对象进行迭代取样时其__len__()方法似乎失效了。。。 三、实验 博主做了如下实验,利用Py...

目录

一、问题

二、思考

三、实验

四、解决方法


一、问题

博主在写采样器时将dataset的类对象赋值给data_source,然后准备对data_source取样,总是提示在__getitem__()函数提示越界。


二、思考

换言之,在对dataset对象进行迭代取样时其__len__()方法似乎失效了。。。


三、实验

博主做了如下实验,利用Pytorch的FakeData类进行以上猜想的证实。FakeData定义如下:


  
  1. import torch
  2. import torch.utils.data as data
  3. from .. import transforms
  4. class FakeData(data.Dataset):
  5. """A fake dataset that returns randomly generated images and returns them as PIL images
  6. Args:
  7. size (int, optional): Size of the dataset. Default: 1000 images
  8. image_size(tuple, optional): Size if the returned images. Default: (3, 224, 224)
  9. num_classes(int, optional): Number of classes in the datset. Default: 10
  10. transform (callable, optional): A function/transform that takes in an PIL image
  11. and returns a transformed version. E.g, ``transforms.RandomCrop``
  12. target_transform (callable, optional): A function/transform that takes in the
  13. target and transforms it.
  14. random_offset (int): Offsets the index-based random seed used to
  15. generate each image. Default: 0
  16. """
  17. def __init__(self, size=1000, image_size=(3, 224, 224), num_classes=10,
  18. transform=None, target_transform=None, random_offset=0):
  19. self.size = size
  20. self.num_classes = num_classes
  21. self.image_size = image_size
  22. self.transform = transform
  23. self.target_transform = target_transform
  24. self.random_offset = random_offset
  25. def __getitem__(self, index):
  26. """
  27. Args:
  28. index (int): Index
  29. Returns:
  30. tuple: (image, target) where target is class_index of the target class.
  31. """
  32. # create random image that is consistent with the index id
  33. rng_state = torch.get_rng_state()
  34. torch.manual_seed(index + self.random_offset)
  35. img = torch.randn(*self.image_size)
  36. target = torch.Tensor(1).random_(0, self.num_classes)[0]
  37. torch.set_rng_state(rng_state)
  38. # convert to PIL Image
  39. img = transforms.ToPILImage()(img)
  40. if self.transform is not None:
  41. img = self.transform(img)
  42. if self.target_transform is not None:
  43. target = self.target_transform(target)
  44. return img, target
  45. def __len__(self):
  46. return self.size
  47. def __repr__(self):
  48. fmt_str = 'Dataset ' + self.__class__.__name__ + '\n'
  49. fmt_str += ' Number of datapoints: {}\n'.format(self.__len__())
  50. tmp = ' Transforms (if any): '
  51. fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
  52. tmp = ' Target Transforms (if any): '
  53. fmt_str += '{0}{1}'.format(tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
  54. return fmt_str

代码如下:


  
  1. from torch.utils.data import DataLoader
  2. from torchvision import transforms
  3. import torchvision
  4. fake_dataset = torchvision.datasets.FakeData(
  5. size=100,
  6. image_size=(3, 256, 128),
  7. num_classes=751,
  8. transform=transforms.Compose([
  9. transforms.Resize((384, 128),interpolation=3),
  10. transforms.ToTensor(),
  11. ])
  12. )
  13. for i, batches in enumerate(fake_dataset):
  14. print(i)
  15. fake_loader = DataLoader(
  16. fake_dataset,
  17. batch_size=4,
  18. shuffle=True
  19. )
  20. # for i, batches in enumerate(fake_loader):
  21. # print(i)

结果是无限输入更新的i。。。


  
  1. 0
  2. 1
  3. 2
  4. 3
  5. 4
  6. 5
  7. 6
  8. 7
  9. 8
  10. 9
  11. ...
  12. 100
  13. 101
  14. 102
  15. ....

那么怎么办呢?

加上一个DataLoader作为dataset的迭代器就好了:


  
  1. from torch.utils.data import DataLoader
  2. from torchvision import transforms
  3. import torchvision
  4. fake_dataset = torchvision.datasets.FakeData(
  5. size=100,
  6. image_size=(3, 256, 128),
  7. num_classes=751,
  8. transform=transforms.Compose([
  9. transforms.Resize((384, 128),interpolation=3),
  10. transforms.ToTensor(),
  11. ])
  12. )
  13. # for i, batches in enumerate(fake_dataset):
  14. # print(i)
  15. fake_loader = DataLoader(
  16. fake_dataset,
  17. batch_size=4,
  18. shuffle=True
  19. )
  20. for i, batches in enumerate(fake_loader):
  21. print(i)

但是,这并不是解决问题的方法,因为我们需要先写采样器再去构建Dataloader,本末不可倒置。


四、解决方法

可在dataset里新建一个变量items(例如一个dict),该变量的长度就是dataset中__len__()函数的返回值:


  
  1. def __len__(self):
  2. return len(self.items)

最后利用for i, item in self.data_source.items.items():访问新建的items变量,并获取每个变量的元素即可,这样一来__getitem__()函数也不用在采样器初始化时执行:


  
  1. for i, item in self.data_source.items.items():
  2. print(i)

是不是很方便呢~


五、总结

相信还有更多的方法解决这个问题。总之,灵活地使用Pytorch中各种元件,对编程实现实验细节有很大的帮助~

文章来源: nickhuang1996.blog.csdn.net,作者:悲恋花丶无心之人,版权归原作者所有,如需转载,请联系作者。

原文链接:nickhuang1996.blog.csdn.net/article/details/100103875

【版权声明】本文为华为云社区用户转载文章,如果您发现本社区中有涉嫌抄袭的内容,欢迎发送邮件进行举报,并提供相关证据,一经查实,本社区将立刻删除涉嫌侵权内容,举报邮箱: cloudbbs@huaweicloud.com
  • 点赞
  • 收藏
  • 关注作者

评论(0

0/1000
抱歉,系统识别当前为高风险访问,暂不支持该操作

全部回复

上滑加载中

设置昵称

在此一键设置昵称,即可参与社区互动!

*长度不超过10个汉字或20个英文字符,设置后3个月内不可修改。

*长度不超过10个汉字或20个英文字符,设置后3个月内不可修改。