#!/usr/bin/python3
# _*_coding:utf-8 _*_
'''
Custom Dataset with class balancing: every batch is sampled according to a
user-defined per-class ratio. Supports multi-process loading and distributed mode.
'''
from check_pkgs import *
import torch.distributed as dist
IMG_EXT = ['.png', '.jpg']


class MyClassBalanceDataset(Dataset):
    """Image-folder dataset that records per-class index ranges in `samples`.

    Expected layout: ``root/<class_name>/**`` with image files filtered by
    ``IMG_EXT``. Samples are stored grouped by class so a batch sampler can
    draw a configurable number of indices from each class via
    ``class_idx_samples`` ({class_idx: [start, end)}).
    """

    def __init__(self, root, transform=None):
        super(MyClassBalanceDataset, self).__init__()
        assert osp.exists(root)
        classes = sorted(d.name for d in os.scandir(root) if d.is_dir())
        classes_to_idx = {name: idx for idx, name in enumerate(classes)}
        idxs = list(classes_to_idx.values())
        class_idx_num = {i: 0 for i in idxs}       # image count per class
        class_idx_samples = {i: [] for i in idxs}  # [start, end) range per class
        samples = []
        start, end = 0, 0
        for cls in classes:
            _idx = classes_to_idx[cls]
            # BUGFIX: glob only inside this class's own directory. The old
            # pattern f'{root}/**/*.*' scanned the whole tree once per class,
            # so every class counted ALL images and owned them with wrong labels.
            for f in [i for i in glob.glob(f'{root}/{cls}/**/*.*', recursive=True)
                      if osp.splitext(i)[-1] in IMG_EXT]:
                class_idx_num[_idx] += 1
                samples.append((f, _idx))
            end = len(samples)
            class_idx_samples[_idx] = [start, end]
            start = end
        print(f'number of each category: {class_idx_num}')
        print(f'class_idx_samples: {class_idx_samples}')
        self.samples = samples
        self.class_idx_samples = class_idx_samples
        self.transform = transform

    def __len__(self):
        # Under torch.distributed each rank only iterates its own shard, so
        # report the per-rank length (ceil: every rank sees the same length).
        total = len(self.samples)
        if dist.is_available() and dist.is_initialized():
            num_replicas = dist.get_world_size()
            total = math.ceil(total / num_replicas)
        return total

    def __getitem__(self, index):
        # BUGFIX: removed a leftover "DEBUG" early-return that yielded
        # (index, target) and made the image-loading path unreachable.
        _path, _target = self.samples[index]
        _sample = Image.open(_path).convert('RGB')
        if self.transform is not None:
            _sample = self.transform(_sample)
        else:
            _sample = np.asarray(_sample)
        _target = torch.tensor(_target)
        return _sample, _target
# A hand-rolled batch sampler that precisely controls how many samples of
# each class appear in every batch.
class MyBatchSampler(Sampler):
    """Yield batches of dataset indices with a fixed per-class quota.

    `data_source` must expose `class_idx_samples` ({class_idx: [start, end)})
    and `__len__`; `class_weight` is a list of per-class fractions summing to 1.
    Classes with fewer samples than requested wrap around (are resampled).
    """

    def __init__(self, data_source, batch_size, class_weight):
        super(MyBatchSampler, self).__init__(data_source)
        self.data_source = data_source
        assert isinstance(class_weight, list)
        # BUGFIX: use abs() so a weight sum GREATER than 1 is rejected too.
        # The old check `1 - sum(...) < 1e-5` accepted any sum >= 1.
        assert abs(1 - sum(class_weight)) < 1e-5
        self.batch_size = batch_size
        _num = len(class_weight)
        number_in_batch = {i: 0 for i in range(_num)}
        for c in range(_num):
            number_in_batch[c] = math.floor(batch_size * class_weight[c])
        # Hand the rounding remainder to one random class so the batch is full.
        _remain_num = batch_size - sum(number_in_batch.values())
        number_in_batch[random.choice(range(_num))] += _remain_num
        self.number_in_batch = number_in_batch
        self.offset_per_class = {i: 0 for i in range(_num)}
        print(f'setting number_in_batch: {number_in_batch}')
        print('my sampler is inited.')
        # In distributed mode, give each rank a disjoint slice of every class
        # range so different ranks do not sample the same images.
        if dist.is_available() and dist.is_initialized():
            rank = dist.get_rank()
            num_replicas = dist.get_world_size()
            t = self.data_source.class_idx_samples.items()
            for c, (start, end) in t:
                total = end - start
                num_samples = math.ceil(total / num_replicas)
                start_rank = rank * num_samples + start
                end_rank = min(start_rank + num_samples, end)
                # update idx range
                self.data_source.class_idx_samples[c] = [start_rank, end_rank]
            print('using torch distributed mode.')
            print(f'current rank data sample setting: {self.data_source.class_idx_samples}')

    def __iter__(self):
        print('======= start __iter__ =======')
        batch = []
        i = 0
        while i < len(self):
            for c, num in self.number_in_batch.items():
                start, end = self.data_source.class_idx_samples[c]
                for _ in range(num):
                    idx = start + self.offset_per_class[c]
                    if idx >= end:
                        # wrap around: minority classes restart from their first index
                        self.offset_per_class[c] = 0
                        idx = start
                    batch.append(idx)
                    self.offset_per_class[c] += 1
            assert len(batch) == self.batch_size
            # random.shuffle(batch)
            yield batch
            batch = []
            i += 1

    def __len__(self):
        # Number of batches per epoch (last incomplete batch is dropped).
        return len(self.data_source) // self.batch_size
# Single-GPU / single-process demo.
def test_1():
    """Build the balanced dataset + sampler and iterate it for a few epochs."""
    root = 'G:/Project/DataSet/flower_photos/flower_photos'
    assert osp.exists(root)
    batch_size = 32
    num_workers = 8
    transform = TF.Compose([
        TF.Resize((5, 5)),
        TF.Grayscale(1),
        TF.ToTensor(),
    ])
    n_class = 5
    clas_weight = [0.5, 0.2, 0.1, 0.1, 0.1]
    ds = MyClassBalanceDataset(root, transform)
    _batchSampler = MyBatchSampler(ds, batch_size, clas_weight)
    # batch_size must stay at its default (1) when batch_sampler is given.
    data_loader = DataLoader(ds, batch_size=1, num_workers=num_workers,
                             pin_memory=True, batch_sampler=_batchSampler)
    print(f'dataloader total: {len(data_loader)}')
    for epoch in range(3):
        for step, (x, y) in enumerate(data_loader):
            # print(step)
            print(step, x)
            # print('batch hist:', torch.histc(y.float(), n_class, 0, n_class - 1))
# Multi-GPU distributed demo.
def test_2():
    """Distributed (DDP) demo of the class-balanced sampler.

    Launch with e.g.::

        CUDA_VISIBLE_DEVICES=0,1 python -m torch.distributed.launch \
            --nproc_per_node=2 --nnodes=1 --master_port=29734 \
            dataset_customized.py --distributed=1
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--local_rank', type=int, default=0, help='node rank for distributed parallel')
    parser.add_argument('--distributed', type=int, default=0, help='distributed mode')
    args = parser.parse_args()
    torch.cuda.set_device(args.local_rank)
    device_num = torch.cuda.device_count()
    distributed_mode = device_num >= 2 and args.distributed
    if distributed_mode:
        # BUGFIX: assert NCCL availability only when actually entering
        # distributed mode; the old unconditional assert crashed CPU-only /
        # single-GPU runs even with --distributed=0.
        assert torch.distributed.is_nccl_available()
        dist.init_process_group('nccl', world_size=device_num, rank=args.local_rank)
        rank = dist.get_rank()
        num_rep = dist.get_world_size()
        print(rank, num_rep)
        print('torch distributed work is inited.')
    root = 'G:/Project/DataSet/flower_photos/flower_photos'
    assert osp.exists(root)
    batch_size = 32
    num_workers = 8
    transform = TF.Compose([
        TF.Resize((5, 5)),
        TF.Grayscale(1),
        TF.ToTensor(),
    ])
    n_class = 5
    clas_weight = [0.5, 0.2, 0.1, 0.1, 0.1]
    ds = MyClassBalanceDataset(root, transform)
    _batchSampler = MyBatchSampler(ds, batch_size, clas_weight)
    data_loader = DataLoader(ds, batch_size=1, num_workers=num_workers,
                             pin_memory=True, batch_sampler=_batchSampler)
    print(f'dataloader total: {len(data_loader)}')
    for epoch in range(3):
        for step, (x, y) in enumerate(data_loader):
            print(step, x)
            print('batch hist:', torch.histc(y.float(), n_class, 0, n_class - 1))
# Run the single-process demo by default; test_2 requires a distributed launch.
if __name__ == '__main__':
    test_1()