pytorch之DataLoader,函数

在训练神经网络时,最好是对一个batch的数据进行操作,同时还需要对数据进行shuffle和并行加速等。对此,PyTorch提供了DataLoader帮助我们实现这些功能。

DataLoader的函数定义如下:

DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, 
num_workers=0, collate_fn=default_collate, pin_memory=False, 
drop_last=False)

dataset:加载的数据集(Dataset对象)

batch_size:batch size

shuffle::是否将数据打乱

sampler: 样本抽样,后续会详细介绍

num_workers:使用多进程加载的进程数,0代表不使用多进程

collate_fn: 如何将多个样本数据拼接成一个batch,一般使用默认的拼接方式即可

pin_memory:是否将数据保存在pin memory区,pin memory中的数据转到GPU会快一些

drop_last:dataset中的数据个数可能不是batch_size的整数倍,drop_last为True会将多出来不足一个batch的数据丢弃

def main():
    import visdom
    import time

    viz = visdom.Visdom()

    db = Pokemon('pokeman', 224, 'train')

    x,y = next(iter(db))   ##
    print('sample:',x.shape,y.shape,y)

    viz.image(db.denormalize(x),win='sample_x',opts=dict(title='sample_x'))

    loader = DataLoader(db,batch_size=32,shuffle=True)

    for x,y in loader:  #为了得一个一个的数据集形式的数据每一组32张
        viz.images(db.denormalize(x),nrow=8,win='batch',opts = dict(title = 'batch'))
        viz.text(str(y.numpy()),win = 'label',opts=dict(title='batch-y'))

        time.sleep(10)

在数据处理中,有时会出现某个样本无法读取等问题,比如某张图片损坏。这时在_ getitem _函数中将出现异常,此时最好的解决方案即是将出错的样本剔除