PyTorch: implement a custom DataLoader in which every batch is class-balanced, i.e. the per-class counts in each batch follow a ratio you configure

#!/usr/bin/python3
# -*- coding: utf-8 -*-


'''
Custom Dataset + BatchSampler for class balance: every batch is sampled according to a
user-defined per-class ratio; multi-worker loading and distributed training are supported.
'''
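# How it works: the Dataset stores its samples so that every class occupies one contiguous
# index range [start, end); the custom BatchSampler then fills each batch with
# floor(batch_size * class_weight[c]) indices from every class's range (wrapping around
# when a class runs out), and in distributed mode each rank is first given a disjoint
# slice of every range.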

# `check_pkgs` was a catch-all import module in the original post; the explicit imports
# below cover everything this script actually uses.
import argparse
import glob
import math
import os
import os.path as osp
import random

import numpy as np
import torch
import torch.distributed as dist
import torchvision.transforms as TF
from PIL import Image
from torch.utils.data import DataLoader, Dataset, Sampler

IMG_EXT = ['.png', '.jpg']


class MyClassBalanceDataset(Dataset):
    def __init__(self, root, transform=None):
        super(MyClassBalanceDataset, self).__init__()
        assert osp.exists(root)
        classes = sorted(d.name for d in os.scandir(root) if d.is_dir())
        classes_to_idx = {name: idx for idx, name in enumerate(classes)}
        idxs = list(classes_to_idx.values())
        class_idx_num = {i: 0 for i in idxs}
        class_idx_samples = {i: [] for i in idxs}
        samples = []
        start, end = 0, 0
        for cls in classes:
            _idx = classes_to_idx[cls]
            # scan only this class's own sub-directory, so that each class ends up
            # with a contiguous index range [start, end) inside `samples`
            cls_dir = osp.join(root, cls)
            for f in [i for i in glob.glob(f'{cls_dir}/**/*.*', recursive=True) if osp.splitext(i)[-1] in IMG_EXT]:
                class_idx_num[_idx] += 1
                samples.append((f, _idx))
            end = len(samples)
            class_idx_samples[_idx] = [start, end]
            start = end

        print(f'number of each category: {class_idx_num}')
        print(f'class_idx_samples: {class_idx_samples}')

        self.samples = samples
        self.class_idx_samples = class_idx_samples
        self.transform = transform

    def __len__(self):
        total = len(self.samples)
        if dist.is_available() and dist.is_initialized():
            num_replicas = dist.get_world_size()
            # each rank only sees its own 1/world_size share of the samples
            total = math.ceil(total / num_replicas)
        return total

    def __getitem__(self, index):
        _path, _target = self.samples[index]
        ########### DEBUG ###########
        # return the raw dataset index instead of the decoded image, so the tests below
        # can print which indices end up in each batch; delete this early return (and the
        # DEBUG markers) for real training
        return index, _target
        ########### DEBUG ###########

        _sample = Image.open(_path).convert('RGB')
        if self.transform is not None:
            _sample = self.transform(_sample)
        else:
            _sample = np.asarray(_sample)
        _target = torch.tensor(_target)
        return _sample, _target


# A custom BatchSampler that precisely controls how many samples of each class appear in every batch
class MyBatchSampler(Sampler):
    def __init__(self, data_source, batch_size, class_weight):
        super(MyBatchSampler, self).__init__(data_source)
        self.data_source = data_source
        assert isinstance(class_weight, list)
        assert abs(1 - sum(class_weight)) < 1e-5, 'class_weight must sum to 1'
        self.batch_size = batch_size

        _num = len(class_weight)
        number_in_batch = {i: 0 for i in range(_num)}
        for c in range(_num):
            number_in_batch[c] = math.floor(batch_size * class_weight[c])
        _remain_num = batch_size - sum(number_in_batch.values())
        number_in_batch[random.choice(range(_num))] += _remain_num
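        # e.g. with batch_size=32 and class_weight=[0.5, 0.2, 0.1, 0.1, 0.1] (the values used
        # in the tests below), floor() gives {0: 16, 1: 6, 2: 3, 3: 3, 4: 3} = 31 samples,
        # and the one left-over slot is handed to a randomly chosen class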
        self.number_in_batch = number_in_batch
        self.offset_per_class = {i: 0 for i in range(_num)}
        print(f'setting number_in_batch: {number_in_batch}')
        print('my sampler is initialized.')

        # in distributed mode, re-partition each class's index range across ranks to avoid duplicate sampling
        if dist.is_available() and dist.is_initialized():
            rank = dist.get_rank()
            num_replicas = dist.get_world_size()
            for c, (start, end) in self.data_source.class_idx_samples.items():
                total = end - start
                num_samples = math.ceil(total / num_replicas)
                start_rank = start + rank * num_samples
                end_rank = min(start_rank + num_samples, end)
                # update idx range
                self.data_source.class_idx_samples[c] = [start_rank, end_rank]

            print('using torch distributed mode.')
            print(f'current rank data sample setting: {self.data_source.class_idx_samples}')
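            # e.g. if class 0 originally occupies indices [0, 600) and world_size == 2,
            # rank 0 is assigned [0, 300) and rank 1 [300, 600), so the two ranks never
            # draw the same image (600 is just an illustrative number)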

    def __iter__(self):
        print('======= start __iter__ =======')
        batch = []
        i = 0
        while i < len(self):
            for c, num in self.number_in_batch.items():
                start, end = self.data_source.class_idx_samples[c]
                for _ in range(num):
                    # wrap around once this class's (rank-local) range is exhausted
                    if start + self.offset_per_class[c] >= end:
                        self.offset_per_class[c] = 0
                    idx = start + self.offset_per_class[c]
                    batch.append(idx)
                    self.offset_per_class[c] += 1

            assert len(batch) == self.batch_size
            # random.shuffle(batch)
            yield batch
            batch = []
            i += 1

    def __len__(self):
        return len(self.data_source) // self.batch_size
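
# A minimal sanity-check sketch for MyBatchSampler that needs no image files. Both the
# helper and the stub dataset below are hypothetical: the stub only mimics the `samples`
# list and the `class_idx_samples` ranges that the sampler actually reads.
def _sampler_sanity_check():
    class _StubDataset:
        def __init__(self, counts):
            self.samples, self.class_idx_samples, start = [], {}, 0
            for c, n in enumerate(counts):
                self.samples += [(f'fake_{c}_{i}.jpg', c) for i in range(n)]
                self.class_idx_samples[c] = [start, start + n]
                start += n

        def __len__(self):
            return len(self.samples)

    ds = _StubDataset([50, 20, 10, 10, 10])
    sampler = MyBatchSampler(ds, batch_size=10, class_weight=[0.5, 0.2, 0.1, 0.1, 0.1])
    for batch in sampler:
        labels = [ds.samples[i][1] for i in batch]
        print({c: labels.count(c) for c in range(5)})  # always {0: 5, 1: 2, 2: 1, 3: 1, 4: 1}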


# single-GPU version
def test_1():
    root = 'G:/Project/DataSet/flower_photos/flower_photos'
    assert osp.exists(root)
    batch_size = 32
    num_workers = 8
    transform = TF.Compose([
        TF.Resize((5, 5)),
        TF.Grayscale(1),
        TF.ToTensor(),
    ])
    n_class = 5
    class_weight = [0.5, 0.2, 0.1, 0.1, 0.1]
    ds = MyClassBalanceDataset(root, transform)
    _batchSampler = MyBatchSampler(ds, batch_size, class_weight)
    data_loader = DataLoader(ds, batch_size=1, num_workers=num_workers, pin_memory=True, batch_sampler=_batchSampler)
    print(f'dataloader total: {len(data_loader)}')
    for epoch in range(3):
        for step, (x, y) in enumerate(data_loader):
            # print(step)
            print(step, x)
            # print('batch hist:', torch.histc(y.float(), n_class, 0, n_class - 1))
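            # with the DEBUG return still active in __getitem__, x here is the batch of raw
            # dataset indices; un-comment the histc line above to see the per-batch class
            # counts (roughly 16/6/3/3/3 for these weights, see number_in_batch)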


# multi-GPU (distributed) version
def test_2():
    '''
    Launch with, e.g.:
    CUDA_VISIBLE_DEVICES=0,1 python -m torch.distributed.launch --nproc_per_node=2 --nnodes=1 --master_port=29734 \
    dataset_customized.py --distributed=1
    :return:
    '''
    parser = argparse.ArgumentParser()
    parser.add_argument('--local_rank', type=int, default=0, help='local process rank, set by torch.distributed.launch')
    parser.add_argument('--distributed', type=int, default=0, help='distributed mode')
    args = parser.parse_args()
    torch.cuda.set_device(args.local_rank)
    device_num = torch.cuda.device_count()
    distributed_mode = device_num >= 2 and args.distributed
    if distributed_mode:
        assert torch.distributed.is_nccl_available()
        dist.init_process_group('nccl', world_size=device_num, rank=args.local_rank)
        rank = dist.get_rank()
        num_rep = dist.get_world_size()
        print(rank, num_rep)
        print('torch distributed is initialized.')

    root = 'G:/Project/DataSet/flower_photos/flower_photos'
    assert osp.exists(root)
    batch_size = 32
    num_workers = 8
    transform = TF.Compose([
        TF.Resize((5, 5)),
        TF.Grayscale(1),
        TF.ToTensor(),
    ])
    n_class = 5
    class_weight = [0.5, 0.2, 0.1, 0.1, 0.1]
    ds = MyClassBalanceDataset(root, transform)
    _batchSampler = MyBatchSampler(ds, batch_size, class_weight)
    data_loader = DataLoader(ds, batch_size=1, num_workers=num_workers, pin_memory=True, batch_sampler=_batchSampler)
    print(f'dataloader total: {len(data_loader)}')
    for epoch in range(3):
        for step, (x, y) in enumerate(data_loader):
            print(step, x)
            print('batch hist:', torch.histc(y.float(), n_class, 0, n_class - 1))


if __name__ == '__main__':
    test_1()  # switch to test_2() for the distributed run described in its docstring