PyTorch `RandomSampler.__iter__` 问题
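PyTorch 报错：`RuntimeError: randperm is only implemented for CPU`
该错误通常出现在把默认张量类型设为 CUDA（例如 `torch.set_default_tensor_type('torch.cuda.FloatTensor')`）之后。此时 `RandomSampler.__iter__` 中的 `torch.randperm` 会尝试在 GPU 上生成随机排列，而该版本的 PyTorch 只支持在 CPU 上执行 `randperm`。
解决方法：将 `/anaconda3/lib/python3.6/site-packages/torch/utils/data/sampler.py` 文件中 `RandomSampler` 的 `__iter__` 函数改为如下形式，即显式指定在 CPU 上生成随机排列：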
class RandomSampler(Sampler):
    r"""Samples elements randomly, without replacement.

    Patched variant of ``torch.utils.data.sampler.RandomSampler`` that always
    draws the index permutation on the CPU. This avoids
    ``RuntimeError: randperm is only implemented for CPU`` when the default
    tensor type points at a device where ``randperm`` is unsupported.

    Arguments:
        data_source (Dataset): dataset to sample from
    """

    def __init__(self, data_source):
        self.data_source = data_source

    def __iter__(self):
        """Return an iterator over ``range(len(self.data_source))`` in random order.

        Each index appears exactly once, as a plain Python ``int``.
        """
        # Pass the device explicitly. Without it, randperm follows the default
        # tensor type, which may be CUDA, and randperm is CPU-only in the
        # PyTorch version this patch targets.
        # The CPU device has no meaningful ordinal, so no index is given.
        num_items = len(self.data_source)
        permutation = torch.randperm(num_items, device=torch.device('cpu'))
        return iter(permutation.tolist())

    def __len__(self):
        """Return the number of indices produced per pass (the dataset size)."""
        return len(self.data_source)