pytorch的dataset与dataloader解析

整理一下pytorch获取的流程:

  1. 创建Dataset对象
  2. 创建DataLoader对象,装载有dataset对象
  3. 循环DataLoader对象,DataLoader.__iter__返回的是DataLoaderIter对象
dataset = MyDataset()
dataloader = DataLoader(dataset)
num_epoches = 100
for epoch in range(num_epoches):
    for data in dataloader:
        ....

根据源码分析:torch.utils.data

1 - Dataset:

class Dataset(object):
    """An abstract class representing a Dataset.
    All other datasets should subclass it. All subclasses should override
    ``__len__``, that provides the size of the dataset, and ``__getitem__``,
    supporting integer indexing in range from 0 to len(self) exclusive.
    """

    def __getitem__(self, index):
        raise NotImplementedError

    def __len__(self):
        raise NotImplementedError

    def __add__(self, other):
        return ConcatDataset([self, other])

Dataset这是一个抽象类,不能实例化,需要重写类方法,关键点有两个:

  • __getitem__ 这个很重要,规定了如何读数据,比如常用的transform
  • __len__ 这个就是返回数据集的长度,比如:return len(self.data)

2 - DataLoader:

DataLoader(dataset, batch_size=1, shuffle=False, sampler=None,
           batch_sampler=None, num_workers=0, collate_fn=None,
           pin_memory=False, drop_last=False, timeout=0,
           worker_init_fn=None)

先看一下主要参数:

  • dataset:就是 torch.utils.data.Dataset 类的实例。也就是说为了使用 DataLoader 类,需要先定义一个 torch.utils.data.Dataset 类的实例。
  • batch_size:每一个批次需要加载的训练样本个数。
  • shuffle:如果设置为 True 表示训练样本数据会被随机打乱,默认值为 False。一般会设置为 True 。
  • sampler:自定义从数据集中取样本的策略,如果指定这个参数,那么 shuffle 必须为 False 。从源码中可以看到,如果指定了该参数,同时 shuffle 设定为 True,DataLoader 的 __init__ 函数就会抛出一个异常 。
  • batch_sampler:与 sampler 类似,但是一次只返回一个 batch 的 indices(索引),需要注意的是,一旦指定了这个参数,那么 batch_size,shuffle,sampler,drop_last 就不能再指定了。源码中同样做了限制。
  • num_workers:表示会使用多少个线程来加载训练数据;默认值为 0,表示数据加载直接在主线程中进行。
  • collate_fn:对每一个 batch 的数据做一些你想要的操作。一个例子,https://zhuanlan.zhihu.com/p/346332974
  • pin_memory:把数据转移到和 GPU 相关联的 CPU 内存,加速 GPU 载入数据的速度。
  • drop_last:比如你的batch_size设置为 32,而一个 epoch 只有 100 个样本;如果设置为 True,那么训练的时候后面的 4 个就被扔掉了。如果为 False(默认),那么会继续正常执行,只是最后的 batch_size 会小一点。
  • timeout:加载一个 batch 数据的超时时间。
  • worker_init_fn:指定每个数据加载线程的入口函数。

源码分析:

class DataLoader(object):
    def __init__(self, dataset, batch_size=1, shuffle=False, sampler=None, 
                 batch_sampler=None,
                 num_workers=0, collate_fn=default_collate, pin_memory=False, 
                 drop_last=False):
        self.dataset = dataset
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.collate_fn = collate_fn
        self.pin_memory = pin_memory
        self.drop_last = drop_last

        if batch_sampler is not None:
            if batch_size > 1 or shuffle or sampler is not None or drop_last:
                raise ValueError('batch_sampler is mutually exclusive with '
                                 'batch_size, shuffle, sampler, and drop_last')

        if sampler is not None and shuffle:
            raise ValueError('sampler is mutually exclusive with shuffle')

        if batch_sampler is None:
            if sampler is None:
                if shuffle:
                    # dataset.__len__() 在 Sampler 中被使用。
                    # 目的是生成一个 长度为 len(dataset) 的 序列索引(随机的)。
                    sampler = RandomSampler(dataset)
                else:
                    # dataset.__len__() 在 Sampler 中被使用。
                    # 目的是生成一个 长度为 len(dataset) 的 序列索引(顺序的)。
                    sampler = SequentialSampler(dataset)
            # Sampler 是个迭代器,一次之只返回一个 索引
            # BatchSampler 也是个迭代器,但是一次返回 batch_size 个 索引
            batch_sampler = BatchSampler(sampler, batch_size, drop_last)

        self.sampler = sampler
        self.batch_sampler = batch_sampler

    def __iter__(self):
        return DataLoaderIter(self)

    def __len__(self):
        return len(self.batch_sampler) 

可以发现__iter__返回的是DataLoaderIter

3 - DataLoaderIter

先看init初始化:

if self.num_workers > 0:
    self.worker_init_fn = loader.worker_init_fn
# 定义了workers相同数量个Queue并放置在index_queues这个list中, # 这些Queue与worker一一对应,用来给worker传递“工作内容” self.index_queues = [multiprocessing.Queue() for _ in range(self.num_workers)]
# worker_queue_idx用于下一个工作的workre序号,主进程轮询使用不同workers self.worker_queue_idx = 0
# 各个workre将自己所取得的数据传递给wokrker_result_queue,供主进程fetch self.worker_result_queue = multiprocessing.SimpleQueue() # 记录当前时刻分配了多少个任务(可能有处于等待状态的任务) self.batches_outstanding = 0 self.worker_pids_set = False self.shutdown = False # 发送出去数据的编号 self.send_idx = 0 # 接受到数据的编号 self.rcvd_idx = 0 # 缓存区 self.reorder_dict = {} self.workers = [ multiprocessing.Process( target=_worker_loop, args=(self.dataset, self.index_queues[i], self.worker_result_queue, self.collate_fn, base_seed + i, self.worker_init_fn, i)) for i in range(self.num_workers)] # 初始化相应的进程,目标函数为_worker_loop # 参数:dataset(用于数据读取),index_queues[i]为worker对应的index_queue # 以及用于输出的queue # 此处主要用于数据读取后的pin_memory操作,不影响多进程主逻辑,暂不展开 if self.pin_memory or self.timeout > 0: ... else: self.data_queue = self.worker_result_queue for w in self.workers: w.daemon = True # ensure that the worker exits on process exit # 将父进程设置为守护进程,保证父进程结束后,worker进程也结束,必须设置在start之前 w.start() # 下面是一些系统信号处理逻辑,对这方面我还不太熟悉就不介绍了。 _update_worker_pids(id(self), tuple(w.pid for w in self.workers)) _set_SIGCHLD_handler() self.worker_pids_set = True # 初始化后生成2*num_workers数量个prefetch的数据,使dataloader提前工作,提升整体效率。 # prime the prefetch loop for _ in range(2 * self.num_workers): self._put_indices()

init过程有两个函数,一个是worker_loop,另个是put_indices

a. 先看worker_loop:

def _worker_loop(dataset, index_queue, data_queue, collate_fn, seed, init_fn, worker_id):
    global _use_shared_memory
    _use_shared_memory = True

    # Intialize C side signal handlers for SIGBUS and SIGSEGV. Python signal
    # module's handlers are executed after Python returns from C low-level
    # handlers, likely when the same fatal signal happened again already.
    # https://docs.python.org/3/library/signal.html Sec. 18.8.1.1
    _set_worker_signal_handlers()

    torch.set_num_threads(1)
    random.seed(seed)
    torch.manual_seed(seed)

    if init_fn is not None:
        init_fn(worker_id)
    
    # 父进程状态监测
    watchdog = ManagerWatchdog()
    
    # 死循环查询是否有任务传进来
    while True:
        try:
            # 从index_queue获取相应数据
            r = index_queue.get(timeout=MANAGER_STATUS_CHECK_INTERVAL)
        except queue.Empty:
            if watchdog.is_alive():
                continue
            else:
                break
        if r is None:
            break
        idx, batch_indices = r
        try:
            # 获得以后for循环进行读取数据读取,此处和单进程的工作原理是一样的
            # 因此时间花费和batchsize数量呈线性关系
            samples = collate_fn([dataset[i] for i in batch_indices])
            # 经过collate_fn后变成torch.Tensor
        except Exception:
            # 异常处理
            data_queue.put((idx, ExceptionWrapper(sys.exc_info())))
        else:
            # 通过data_queue传回处理好的batch数据
            data_queue.put((idx, samples))
            # 显示删除中间变量,降低内存消耗
            del samples

这里就是不停地轮询,从index_queues队列里获得索引,然后通过collate_fn函数和索引获取tensor,然后塞入data_queue

b. 再看put_indices

def _put_indices(self):
    assert self.batches_outstanding < 2 * self.num_workers
    # 默认设定是只允许分配2*num_workers个任务,保证内存等资源不被耗尽
    indices = next(self.sample_iter, None)
    # 从sample_iter中拿到dataset中下一轮次的索引,用于fetch数据
    if indices is None:
        return
    self.index_queues[self.worker_queue_idx].put((self.send_idx, indices))
    # 轮询选择worker,找到其对应的队列,向其中发送工作内容(数据编号,数据索引)
    self.worker_queue_idx = (self.worker_queue_idx + 1) % self.num_workers
    # worker_queue_idx自增
    self.batches_outstanding += 1
    # 任务分配数+1
    self.send_idx += 1
    # 已发送任务总数+1(下批数据编号) 

这个就是把索引塞进队列index_queues

以上就是init,当for循环时,会调用next:

c. __next__返回一个batch

def __next__(self):
    if self.num_workers == 0:  # same-process loading  (主进程阻塞式读取数据)
        indices = next(self.sample_iter)  # may raise StopIteration
        batch = self.collate_fn([self.dataset[i] for i in indices])
        if self.pin_memory:
            batch = pin_memory_batch(batch)
        return batch
    
    # check if the next sample has already been generated
    # 先查看数据是否在缓存dict中
    if self.rcvd_idx in self.reorder_dict:
        batch = self.reorder_dict.pop(self.rcvd_idx)
        return self._process_next_batch(batch)
    # 异常处理
    if self.batches_outstanding == 0:
        self._shutdown_workers()
        raise StopIteration
    while True:
        assert (not self.shutdown and self.batches_outstanding > 0)
        # 阻塞式的从data_queue里面获取处理好的批数据
        idx, batch = self._get_batch() 
        # 任务数减一
        self.batches_outstanding -= 1
        # 这一步可能会造成的周期阻塞现象
        # 每次获取data以后,要校验和rcvd_idx是否一致
        # 若不一致,则先把获取到的数据放到reorder_dict这个缓存dict中,继续死循环
        # 直到获取到相应的idx编号于rcvd_idx可以对应上,并将数据返回
        if idx != self.rcvd_idx:
            # store out-of-order samples
            self.reorder_dict[idx] = batch
            continue
        return self._process_next_batch(batch)

__next__里的while True,要从data_queue里面读到的数据idx和rcvd_idx一致才将数据返回。因此可能会存在如下这种情况:

假设num_workers=8,现在发送了8个数据给相应的worker,此时send_idx=8,rcvd_idx=0。过了一段时间以后,{1,2,3,5,6,7}进程数据准备完毕,此时主进程从data_queue读取到相关的数据,但由于和rcvd_idx不匹配,只能将其放在缓存里。直到send_idx=0数据准备齐以后,才能将数据返回出去,随后从缓存中弹出2,3的数据,之后又阻塞等待idx=4的数据。即输出的数据必须保持顺序性!因此在worker变多,出现这种逆序现象可能性会更大,这种现象也会出现在非num_workrers次迭代,只要相应的rcvd_idx没有得到相关数据,则主进程就会一直等待。

d. process_next_batch

def _process_next_batch(self, batch):
    # 序号对上以后,rcvd_idx自加1
    self.rcvd_idx += 1
    # 添加一个fetchdata任务给worker
    self._put_indices()
    if isinstance(batch, ExceptionWrapper):
        raise batch.exc_type(batch.exc_msg)
    return batch

  

这个函数注意的是,只有在__next__中,idx == self.rcvd_idx时才会调用,也就是可能出现多个worker已经准备好了,但是只能放在缓存区,并且无法向index_queues塞入索引,使worker无法保持活跃状态。

最后对于for循环从dataloader获取data总体流程:

for epoch in range(num_epoches):
    for data in dataloader:

对于这个for,其实就是调用了dataloader 的__iter__() 方法, 产生了一个DataLoaderIter,如果是num_worker>0,init里就会创建多线程,并且有两个队列,一个是存放dataset的索引index_queues,一个是从index_queues里拿到索引,调用dataset的__getitem__()方法 (如果num_worker>0就多线程调用), 然后用collate_fn来把它们打包成batch,放到data_queue队列里,反复调用DataLoaderIter 的__next__,从data_queue中获取batch。

参考:

Pytorch数据读取(Dataset, DataLoader, DataLoaderIter) https://zhuanlan.zhihu.com/p/30934236

PyTorch 之 Dataset 和 Dataloader https://zhuanlan.zhihu.com/p/339675188

PyTorch36.DataLoader源代码剖析 https://zhuanlan.zhihu.com/p/169497395

PyTorch DataLoader初探 https://zhuanlan.zhihu.com/p/91521705

一文弄懂Pytorch的DataLoader, DataSet, Sampler之间的关系 https://zhuanlan.zhihu.com/p/76893455