PyTorch源码解读之torch.utils.data.DataLoader,转

https://blog.csdn.net/u014380165/article/details/79058479

写得特别好!最近正好在学习pytorch,学习一下!

https://github.com/pytorch/pytorch/blob/master/torch/utils/data/dataloader.py

DataLoader类源码如下。先看看__init__中的几个重要的输入:1、dataset,这个就是PyTorch已有的数据读取接口(比如torchvision.datasets.ImageFolder)或者自定义的数据接口的输出,该输出要么是torch.utils.data.Dataset类的对象,要么是继承自torch.utils.data.Dataset类的自定义类的对象。2、batch_size,根据具体情况设置即可。3、shuffle,一般在训练数据中会采用。4、collate_fn,是用来处理不同情况下的输入dataset的封装,一般采用默认即可,除非你自定义的数据读取输出非常少见。5、batch_sampler,从注释可以看出,其和batch_size、shuffle等参数是互斥的,一般采用默认。6、sampler,从代码可以看出,其和shuffle是互斥的,一般默认即可。7、num_workers,从注释可以看出这个参数必须大于等于0,0的话表示数据导入在主进程中进行,其他大于0的数表示通过多个进程来导入数据,可以加快数据导入速度。8、pin_memory,注释写得很清楚了: pin_memory (bool, optional): If True, the data loader will copy tensors into CUDA pinned memory before returning them. 也就是一个数据拷贝的问题。9、timeout,是用来设置数据读取的超时时间的,但超过这个时间还没读取到数据的话就会报错。

class DataLoader(object): """ Data loader. Combines a dataset and a sampler, and provides single- or multi-process iterators over the dataset. Arguments: dataset (Dataset): dataset from which to load the data. batch_size (int, optional): how many samples per batch to load (default: 1). shuffle (bool, optional): set to ``True`` to have the data reshuffled at every epoch (default: False). sampler (Sampler, optional): defines the strategy to draw samples from the dataset. If specified, ``shuffle`` must be False. batch_sampler (Sampler, optional): like sampler, but returns a batch of indices at a time. Mutually exclusive with batch_size, shuffle, sampler, and drop_last. num_workers (int, optional): how many subprocesses to use for data loading. 0 means that the data will be loaded in the main process. (default: 0) collate_fn (callable, optional): merges a list of samples to form a mini-batch. pin_memory (bool, optional): If ``True``, the data loader will copy tensors into CUDA pinned memory before returning them. drop_last (bool, optional): set to ``True`` to drop the last incomplete batch, if the dataset size is not divisible by the batch size. If ``False`` and the size of dataset is not divisible by the batch size, then the last batch will be smaller. (default: False) timeout (numeric, optional): if positive, the timeout value for collecting a batch from workers. Should always be non-negative. (default: 0) worker_init_fn (callable, optional): If not None, this will be called on each worker subprocess with the worker id as input, after seeding and before data loading. (default: None) """ def __init__(self, dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None, num_workers=0, collate_fn=default_collate, pin_memory=False, drop_last=False, timeout=0, worker_init_fn=None): self.dataset = dataset self.batch_size = batch_size self.num_workers = num_workers self.collate_fn = collate_fn self.pin_memory = pin_memory self.drop_last = drop_last self.timeout = timeout self.worker_init_fn = worker_init_fn if timeout < 0: raise ValueError('timeout option should be non-negative') if batch_sampler is not None: if batch_size > 1 or shuffle or sampler is not None or drop_last: raise ValueError('batch_sampler is mutually exclusive with ' 'batch_size, shuffle, sampler, and drop_last') if sampler is not None and shuffle: raise ValueError('sampler is mutually exclusive with shuffle') if self.num_workers < 0: raise ValueError('num_workers cannot be negative; ' 'use num_workers=0 to disable multiprocessing.') if batch_sampler is None: if sampler is None: if shuffle: sampler = RandomSampler(dataset) else: sampler = SequentialSampler(dataset) batch_sampler = BatchSampler(sampler, batch_size, drop_last) self.sampler = sampler self.batch_sampler = batch_sampler def __iter__(self): return DataLoaderIter(self) def __len__(self): return len(self.batch_sampler)

DataLoaderIter类源码如下。self.index_queue = multiprocessing.SimpleQueue()中的multiprocessing是Python中的多进程管理包,而threading则是Python中的多线程管理包,二者很大一部分的接口用法类似。还是照例先看看__init__,前面部分都是一些赋值操作,比较特殊的是self.sample_iter = iter(self.batch_sampler),得到的self.sample_iter可以通过next(self.sample_iter)来获取batch size个数据的index。self.rcvd_idx表示读取到的一个batch数据的index,初始化为0,该值在迭代读取数据的时候会用到。if self.num_workers语句是针对多进程或单进程的情况进行初始化,如果不是设置为多进程读取数据,那么就不需要这些初始化操作,后面会介绍单进程数据读取。在if语句中通过multiprocessing.SimpleQueue()类创建了一个简单的队列对象。multiprocessing.Process类就是构造进程的类,这里根据设定的进程数来启动,然后赋值给self.workers。接下来的一个for循环就通过调用start方法依次启动self.workers中的进程。接下来关于self.pin_memory的判断语句,该判断语句内部主要是实现了多线程操作。self.pin_memory的含义在前面已经介绍过了,当为True的时候,就会把数据拷到CUDA中。self.data_queue = queue.Queue()是通过Python的queue模块初始化得到一个先进先出的队列(queue模块也可以初始化得到先进后出的队列,需要用queue.LifoQueue()初始化),queue模块主要应用在多线程读取数据中。在threading.Thread的args参数中,第一个参数in_data就是一个进程的数据,一个进程中不同线程的数据也是通过队列来维护的,这里采用的是Python的queue模块来初始化得到一个队列:queue.Queue()。初始化结束后,就会调用__next__方法,接下来介绍。

class DataLoaderIter(object): "Iterates once over the DataLoader's dataset, as specified by the sampler" def __init__(self, loader): self.dataset = loader.dataset self.collate_fn = loader.collate_fn self.batch_sampler = loader.batch_sampler self.num_workers = loader.num_workers self.pin_memory = loader.pin_memory and torch.cuda.is_available() self.timeout = loader.timeout self.done_event = threading.Event() self.sample_iter = iter(self.batch_sampler) if self.num_workers > 0: self.worker_init_fn = loader.worker_init_fn self.index_queue = multiprocessing.SimpleQueue() self.worker_result_queue = multiprocessing.SimpleQueue() self.batches_outstanding = 0 self.worker_pids_set = False self.shutdown = False self.send_idx = 0 self.rcvd_idx = 0 self.reorder_dict = {} base_seed = torch.LongTensor(1).random_()[0] self.workers = [ multiprocessing.Process( target=_worker_loop, args=(self.dataset, self.index_queue, self.worker_result_queue, self.collate_fn, base_seed + i, self.worker_init_fn, i)) for i in range(self.num_workers)] if self.pin_memory or self.timeout > 0: self.data_queue = queue.Queue() self.worker_manager_thread = threading.Thread( target=_worker_manager_loop, args=(self.worker_result_queue, self.data_queue, self.done_event, self.pin_memory, torch.cuda.current_device())) self.worker_manager_thread.daemon = True self.worker_manager_thread.start() else: self.data_queue = self.worker_result_queue for w in self.workers: w.daemon = True # ensure that the worker exits on process exit w.start() _update_worker_pids(id(self), tuple(w.pid for w in self.workers)) _set_SIGCHLD_handler() self.worker_pids_set = True # prime the prefetch loop for _ in range(2 * self.num_workers): self._put_indices()

DataLoaderIter类的__next__方法如下,包含3个if语句和1个while语句。

 def __next__(self): if self.num_workers == 0: # same-process loading indices = next(self.sample_iter) # may raise StopIteration batch = self.collate_fn([self.dataset[i] for i in indices]) if self.pin_memory: batch = pin_memory_batch(batch) return batch # check if the next sample has already been generated if self.rcvd_idx in self.reorder_dict: batch = self.reorder_dict.pop(self.rcvd_idx) return self._process_next_batch(batch) if self.batches_outstanding == 0: self._shutdown_workers() raise StopIteration while True: assert (not self.shutdown and self.batches_outstanding > 0) idx, batch = self._get_batch() self.batches_outstanding -= 1 if idx != self.rcvd_idx: # store out-of-order samples self.reorder_dict[idx] = batch continue return self._process_next_batch(batch)

pin_memory_batch函数不是定义在DataLoader类或DataLoaderIter类中。该函数主要是对batch中的Tensor执行batch.pin_memory()操作,这里的很多条件语句只是用来判断batch的类型,假如batch是一个列表,列表中的每个值是Tensor,那么就会执行 elif isinstance(batch, collections.Sequence):这个条件,从而遍历该列表中的每个Tensor,然后执行第一个条件语句的内容: return batch.pin_memory()

def pin_memory_batch(batch): if torch.is_tensor(batch): return batch.pin_memory() elif isinstance(batch, string_classes): return batch elif isinstance(batch, collections.Mapping): return {k: pin_memory_batch(sample) for k, sample in batch.items()} elif isinstance(batch, collections.Sequence): return [pin_memory_batch(sample) for sample in batch] else: return batch
iter() 与 next()的调用方法
dataiter = iter(trainloader) # 每次迭代取的是一个batch images, labels = dataiter.next() # 如果batch_size为4,则取出来的images是4×c×h×w的tensor,labels是1×4的向量

DataloaderIter类的_get_batch方法。主要根据是否设置了超时时间来操作,如果超过指定的超时时间后没有从队列中读到数据就报错,如果不设置超时时间且一致没有从队列中读到数据,那么就会一直卡着且不报错,这部分是PyTorch后来修的一个bug。

 def _get_batch(self): if self.timeout > 0: try: return self.data_queue.get(True, self.timeout) except queue.Empty: raise RuntimeError('DataLoader timed out after {} seconds'.format(self.timeout)) else: return self.data_queue.get()

DataLoaderIter类的_process_next_batch方法。首先对self.rcvd_idx进行加一,也就是更新下下一个要读取的batch数据的index。然后调用_put_indices()方法获取下一个batch的每个数据的index。

 def _process_next_batch(self, batch): self.rcvd_idx += 1 self._put_indices() if isinstance(batch, ExceptionWrapper): raise batch.exc_type(batch.exc_msg) return batch

DataLoaderIter类的_put_indices方法。该方法主要实现从self.sample_iter中读取下一个batch数据中每个数据的index:indices = next(self.sample_iter, None),注意这里的index和前面idx是不一样的,这里的index是一个batch中每个数据的index,idx是一个batch的index;然后将读取到的index通过调用queue对象的put方法压到队列self.index_queue中:self.index_queue.put((self.send_idx, indices))

def _put_indices(self): assert self.batches_outstanding < 2 * self.num_workers indices = next(self.sample_iter, None) if indices is None: return self.index_queue.put((self.send_idx, indices)) self.batches_outstanding += 1 self.send_idx += 1