pytorch detach函数

用于截断反向传播

detach()源码:

def detach(self):
    result = NoGrad()(self)  # this is needed, because it merges version counters
    result._grad_fn = None
    return result

它的返回结果与调用者共享一个data tensor,且会将grad_fn设为None,这样就不知道该Tensor是由什么操作建立的,截断反向传播

这个时候再一个tensor使用In_place操作会导致另一个的data tensor也会发生改变

import torch

a = torch.tensor([1, 2, 3.], requires_grad=True)
out = a.sigmoid()
print(out)#tensor([0.7311, 0.8808, 0.9526], grad_fn=<SigmoidBackward>)

c = out.detach()
print(c)#tensor([0.7311, 0.8808, 0.9526])

这个时候可以看到,c和out的区别就是一个有grad_fn,一个没有grad_fn

执行out.sum().backward()没有问题,但执行c.sum().backward()报错:

RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn

这个时候不论是对out还是对c进行inplace操作改变它们的data,这个改动会被autograd追踪,这个时候再执行out.sum().backward()会报错

假设对out进行inplace操作,会出现:

out.zero_()
#tensor([0., 0., 0.], grad_fn=<ZeroBackward>)

out.sum().backward()
#报错

错误信息为

RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.FloatTensor [3]], which is output 0 of SigmoidBackward, is at version 1; expected version 0 instead. Hint: enable anomaly detection to find the operation that failed to compute its gradient, with torch.autograd.set_detect_anomaly(True).

如果不对out进行inplace操作而是对c进行inplace操作,结果是一样的,Out不能再进行反向传播了

为了解决这种情况,就要对tensor的data操作,使其不被autograd记录

重新得到一个out,把它的data部分给c

c = out.data
#tensor([0.7311, 0.8808, 0.9526])

out
#tensor([0.7311, 0.8808, 0.9526], grad_fn=<SigmoidBackward>)

这里可以看到,c中没有Out中有的grad_fn信息

这回修改c的值,发现out的data值依然改了,但是执行out.sum().backward()不报错了

detach_()

def detach_(self):
    """Detaches the Variable from the graph that created it, making it a leaf.
    """
    self._grad_fn = None
    self.requires_grad = False

做了两件事:1grad_fn设none2requires_grad设false

它不会新生成一个Variable而是使用原来的variable