【Pytorch】dataloader使用教程

# -*- coding: utf-8 -*-
"""
Created on Mon Aug  3 23:30:39 2020

@author: Administrator
"""

import torch                       # 导入模块
import torch.utils.data as Data

BATCH_SIZE = 8                     # 每一批的数据量

x=torch.linspace(1,10,10)          # 定义X为 1 到 10 等距离大小的数
y=torch.linspace(10,1,10)

# 转换成torch能识别的Dataset
# 这个可以自定义DataSet:https://www.cnblogs.com/douzujun/p/13429912.html
torch_dataset = Data.TensorDataset(x, y) # 将数据放入 torch_dataset

loader=Data.DataLoader(
        dataset=torch_dataset,           # 将数据放入loader
        batch_size=BATCH_SIZE,           # 每个数据段大小为  BATCH_SIZE=5
        shuffle=True ,                   # 是否打乱数据的排布
        num_workers=0                    # 使用多进程加载的进程数,0代表不使用多进程
        )

for epoch in range(3):
    
    for step, (batch_x,batch_y) in enumerate(loader):
        
        print('epoch',epoch,'|step:',step," | batch_x",batch_x.numpy(),

              '|batch_y:',batch_y.numpy())
epoch 0 |step: 0  | batch_x [ 7.  3.  1.  8. 10.  9.  5.  4.] |batch_y: [ 4.  8. 10.  3.  1.  2.  6.  7.]
epoch 0 |step: 1  | batch_x [2. 6.] |batch_y: [9. 5.]
epoch 1 |step: 0  | batch_x [ 6.  7.  5.  4.  1. 10.  2.  9.] |batch_y: [ 5.  4.  6.  7. 10.  1.  9.  2.]
epoch 1 |step: 1  | batch_x [3. 8.] |batch_y: [8. 3.]
epoch 2 |step: 0  | batch_x [ 4.  5.  7.  1.  6.  9. 10.  3.] |batch_y: [ 7.  6.  4. 10.  5.  2.  1.  8.]
epoch 2 |step: 1  | batch_x [8. 2.] |batch_y: [3. 9.]

DataLoader的函数定义如下:

DataLoader(dataset, batch_size=1, shuffle=False, sampler=None,
           num_workers=0, collate_fn=default_collate, pin_memory=False,
           drop_last=False)
  • dataset:加载的数据集(Dataset对象)

  • batch_size:batch size

  • shuffle::是否将数据打乱

  • sampler: 样本抽样,后续会详细介绍

  • num_workers:使用多进程加载的进程数,0代表不使用多进程

  • collate_fn: 如何将多个样本数据拼接成一个batch,一般使用默认的拼接方式即可

  • pin_memory:是否将数据保存在pin memory区,pin memory中的数据转到GPU会快一些

  • drop_last:dataset中的数据个数可能不是batch_size的整数倍,drop_last为True会将多出来不足一个batch的数据丢弃