PyTorch：使用 torch.utils.data.DataLoader、__iter__ 和 __getitem__ 模拟 batch 数据的处理过程

使用__iter__, __getitem__来模拟数据处理部分

import torch.utils.data
class Model:
    """Map-style dataset over a list of animal names.

    Indexing yields one sample: a dict holding the name under ``'A'`` and the
    constant ``1`` under ``'B'``. No batching happens here; a DataLoader asks
    for samples one index at a time and groups them itself.
    """

    def __init__(self, animal_list):
        self.animal_list = animal_list

    def __getitem__(self, index):
        # A single sample for one index; invalid indices raise from the list.
        return {'A': self.animal_list[index], 'B': 1}

    def __len__(self):
        # Sample count, i.e. the length of the wrapped list.
        return len(self.animal_list)

class Animal:
    """Iterable that yields mini-batches of animal samples through a DataLoader.

    Wraps ``animal_list`` in a ``Model`` dataset and a
    ``torch.utils.data.DataLoader``. Iterating an ``Animal`` yields exactly what
    the DataLoader yields, one batch at a time.

    Args:
        animal_list: sequence of animal names used as samples.
        batch_size: number of samples grouped into each batch (default 2,
            the previously hard-coded value). The last batch may be smaller.
        shuffle: if True (default, as before), the DataLoader draws indices in
            random order; if False, batches follow the list order.
    """

    def __init__(self, animal_list, batch_size=2, shuffle=True):
        self.animals_name = animal_list
        self.m = Model(self.animals_name)
        # NOTE: despite its name, ``self.model`` holds a DataLoader, not a
        # network; the attribute name is kept for existing callers.
        self.model = torch.utils.data.DataLoader(
            self.m,
            batch_size=batch_size,  # samples from self.m collated per batch
            shuffle=shuffle,  # randomize the index order
        )

    def __iter__(self):
        # Delegate to the DataLoader; each item is an already-collated batch.
        yield from self.model


animals = Animal(['dog', 'cat', 'fish'])

# Print each batch the DataLoader produces. With three samples and a batch
# size of 2 there are two batches per pass; shuffling makes the grouping vary.
for animal in animals:
    print(animal)