PyTorch中,关于model.eval,和torch.no_grad

  • 一直对于model.eval()和torch.no_grad()有些疑惑
  • 之前看博客说,只用torch.no_grad()即可
  • 但是今天查资料,发现不是这样,而是两者都用,因为两者有着不同的作用
  • 引用stackoverflow:

Use both. They do different things, and have different scopes.

with torch.no_grad: disables tracking of gradients in autograd.

model.eval(): changes the forward() behaviour of the module it is called upon. eg, it disables dropout and has batch norm use the entire population statistics