深度之眼PyTorch训练营第二期 ---5、Dataloader与Dataset 以及 transforms与normalize

一、人民币二分类

描述:输入人民币,通过模型判定类别并输出。

  • 数据:四个子模块

    数据收集 -> img,label 原始数据和标签

    数据划分 -> train训练集 valid验证集 test测试集

    数据读取 -> DataLoader ->(1)Sampler(生成index) (2)Dataset(读取Img,Label)

    数据预处理 -> transforms

1、DataLoader

  •   torch.utils.data.DataLoader 功能:构建可迭代的数据装载器
    •   dataset:Dataset类,决定数据从哪里读取及如何读取
    • batchsize:批大小
    • num_works:是否多进程读取数据
    • shuffle:每个epoch是否乱序
    • drop_last:当样本数不能被batchsize整除时,是否舍弃最后一批数据
  • Epoch、Iteration、Batchsize关系

    Epoch:所有训练样本都已输入到模型中,称为一个epoch

    Iteration:一批样本输入到模型中,称之为一个Iteration

    Batchsize:批大小,决定一个Epoch有多少个iteration

    例如:样本总数:80 batchsize:8

      1 epoch = 10 iteration 一次iteration输入8个样本,所以一次的epoch=8

        样本总数:87 batchsize:8

      if drop_last = true 1 epoch = 10 iteration

else drop_last = false 1 epoch = 11 iteration

2、Dataset

  • torch.utils.data.Dataset 功能:Dataset抽象类,所有自定义的Dataset需要继承它,并且复写
    •   __getitem__() :接收一个索引,返回一个样本及标签

3、数据读取:

思考:读哪些数据?从哪里读数据?怎么读数据?

二、Dataloader与Dataset

三、transform运行机制

torchvision:计算机视觉工具包

  • torchvision.transforms:常用的图像预处理方法
  • torchvision.datasets:常用数据集的dataset实现,MNIST,CIFAR-10,ImageNet
  • torchvision.model:常用的模型预训练,AlexNet,VGG,ResNet,GoogLeNet等

1、torchvision.transforms --- 提高泛化能力

  • l 数据中心化
  • l 数据标准化
  • l 缩放
  • l 剪裁
  • l 旋转
  • l 翻转
  • l 填充
  • l 噪声添加
  • l 灰度变换
  • l 线性变换
  • l 仿射变换
  • l 亮度、饱和度及对比度变换

四、数据标准化—transforms.normalize

transforms.normalize 功能:逐channel的对图像进行标准化

    output = (input - mean)/ std;

    •   mean:各通道的均值
    •   std:各通道的标准差
    •   inplace:是否原地操作