pytorch学习---dataset

1、dataset是初入pytorch最重要的东西,在复现项目的时候,最需要改的就是数据集。

如果弄明白了pytorch中dataset类,你可以创建适应任意模型的数据集接口。

2、所谓数据集,无非就是一组{x:y}的集合吗,你只需要在这个类里说明“有一组{x:y}的集合”就可以了。

对于图像分类任务,图像+分类

对于目标检测任务,图像+bbox、分类

对于超分辨率任务,低分辨率图像+超分辨率图像

对于文本分类任务,文本+分类

...

你只需定义好这个项目的x和y是什么。好了,上面都是扯闲篇,我们直接看dataset代码:

链接:https://blog.csdn.net/leviopku/article/details/99958182

这个链接非常的详细。

Pytorch用torch.utils.data.Dataset构建数据集,想要构建自己的数据集,则需继承Dataset类,并重写两个方法:

    • __len__ :定义整个数据集的长度。使用len(dataset)时会被调用。
    • __getitem__:用于索引数据集中的数据,比如dataset[i]

Dataset基类

PyTorch 读取图片,主要是通过 Dataset 类,所以先简单了解一下 Dataset 类。Dataset

类作为所有的 datasets 的基类存在,所有的 datasets 都需要继承它。

看一下源码:

这里有一个getitem函数,getitem函数接收一个index,然后返回图片数据和标签,这个index通常是指一个list的index,这个list的每个元素就包含了图片数据的路径和标签信息。

list的制作方法通常是将图片的路径和标签信息存储在一个txt中,然后从txt中读取,所以总结一下基本流程:

制作存储了图片路径和标签信息的txt

将这些信息转化成list,list的每一个元素对应一个样本

通过getitem函数,读取数据和标签。

其实说着了些都没用,因为在训练代码里是感觉不到这些操作的,只会看到通过DataLoader就可以获取一个batch的数据,其实触发去读取图片这些操作的是DataLoader里的__iter__(self)(后面再将)。

总而言之,要让PyTorch读取自己的数据集,只要两步:

制作图片数据的索引

构建Dataset子类

制作图片数据索引

非常简单,就是一些基本的操作,百度一下“”python如何保存txt文件“”就可以知道了。

然后一般来说,txt都是这样的格式

./Data/train/01.png 0

./Data/train/02.png 0

./Data/train/03.png 1

./Data/train/04.png 1

构建Dataset子类

下面我们构建一下Dataset的子类,叫他MyDataset类:

from PIL import Image

from torch.utils.data import Dataset

class MyDataset(Datset):

def __init__(self,txt_path,transform=None,target_transform=None):

fh = open(txt_path,'r')

imgs = []

for line in fh:

line = line.rstrip()

words = line.split()

imgs.append((words[0].int(words[1])))

self.imgs = imgs

self.transform = transform

def __getitem__(self,index):

fn,label = self.imgs[index]

img=Image.open(fn).convert('RGB')

if self.transform is not None:

img = self.transform(img)

return img,label

def __len__(self):

return len(sefl.imgs)

Init

初始化中,我们从已经准备好的txt中获取了图片的路径和表亲啊,并且春初在self.imgs这意味着self.imgs是一个list就像上面我们讲的那样

初始化中,初始化了transform,transform是一个Compose类型,transform中包含一个list,list中定义了各种对图像进行的操作,比如随机剪裁,旋转反转等。

一个图片都进来之后,会经过数据处理(数据增强),最终变成另外一张图片,也就是模型的输入数据。但是PyTorch的数据增强是将原始图片进行处理,是不会生成新的图片。因此假如我们使用randomcrop这样的随机操作的时候,每次epoch输入进来的图片不会是一摸一样的,达到样本多样性的功能

getitem

self.imgs是一个list,每一个元素都是一个二元tuple,这很好理解(str1,str2)这样的

利用Image.open对图片进行读取,img类型为Image,mode=‘RGB’

用transform对图片进行处理,里面可能有什么 标准化(减均值除以标准差),随机剪裁什么的(后面会细说)

这样Mydataset就构建好了,剩下的操作就交给DataLoader,在DataLoader中,会触发Mydataset中的getitem函数读取一张图片的数据和标签,并将多个图片拼接成一个batch返回,每一个batch才是模型真正的输入。