PyTorch 上 space2depth（space_to_depth）的实现

声明：转载请注明出处

https://www.cnblogs.com/wioponsen/p/14570312.html

方式一：使用 `torch.nn.functional.unfold`

def space_to_depth(in_tensor, down_scale):
    """Move non-overlapping spatial blocks of a 4-D tensor into channels.

    Implemented with ``torch.nn.functional.unfold``. The kernel size equals
    the stride, so every ``down_scale x down_scale`` block is extracted
    exactly once. ``unfold`` lays out its output as ``(N, C * k * k, L)``,
    with ``L`` running row-major over block positions. That layout can
    therefore be viewed directly as an image of the reduced spatial size.

    Output channel ``c * down_scale**2 + i * down_scale + j`` holds the
    pixel at in-block offset ``(i, j)`` of input channel ``c``.

    Args:
        in_tensor: tensor of shape ``(N, C, H, W)``.
        down_scale: block size ``k``.

    Returns:
        Tensor of shape ``(N, C * k**2, H // k, W // k)``. Trailing
        rows/columns that do not fill a whole block are ignored by
        ``unfold`` (no padding is applied).
    """
    batch, channels, height, width = in_tensor.shape
    patches = torch.nn.functional.unfold(
        in_tensor, kernel_size=down_scale, stride=down_scale
    )
    target_shape = (
        batch,
        channels * down_scale ** 2,
        height // down_scale,
        width // down_scale,
    )
    return patches.view(target_shape)

方式二：使用 `view` + `permute`

def space_to_depth(in_tensor, down_scale):
    """Move non-overlapping spatial blocks of a 4-D tensor into channels.

    Implemented with reshape + permute. An input of shape ``(N, C, H, W)``
    becomes ``(N, C * down_scale**2, H // down_scale, W // down_scale)``.
    Output channel ``c * down_scale**2 + i * down_scale + j`` holds the
    pixel at in-block offset ``(i, j)`` of input channel ``c``. This is the
    same ordering as
    ``torch.nn.functional.unfold(in_tensor, down_scale, stride=down_scale)``.

    Args:
        in_tensor: tensor of shape ``(N, C, H, W)``. It does not need to be
            contiguous (e.g. ``channels_last`` inputs are accepted).
        down_scale: block size ``k``.

    Returns:
        A contiguous tensor of shape ``(N, C * k**2, H // k, W // k)``.
        Trailing rows/columns that do not fill a whole block are dropped,
        matching the ``unfold``-based implementation.
    """
    batch, channels, height, width = in_tensor.size()
    out_height = height // down_scale
    out_width = width // down_scale
    out_channels = channels * down_scale ** 2

    # Keep only whole blocks so the split below is always exact. This is a
    # no-op when H and W are multiples of down_scale; otherwise the
    # remainder is discarded, just as unfold does.
    cropped = in_tensor[:, :, : out_height * down_scale, : out_width * down_scale]

    # Split each spatial axis into (block index, in-block offset):
    # (N*C, oH, k, oW, k). reshape rather than view: merging N and C is
    # impossible without a copy for non-contiguous inputs, and the crop
    # above may also leave a non-contiguous tensor.
    blocks = cropped.reshape(
        batch * channels, out_height, down_scale, out_width, down_scale
    )

    # Bring the in-block offsets ahead of the block indices,
    # giving (N*C, k, k, oH, oW). Then fold (C, k, k) into the channel axis.
    return (
        blocks.permute(0, 2, 4, 1, 3)
        .contiguous()
        .view(batch, out_channels, out_height, out_width)
    )