pytorch中y.data.norm,的含义

import torch
x = torch.randn(3, requires_grad=True)
y = x*2
print(y.data.norm())
print(torch.sqrt(torch.sum(torch.pow(y,2))))  #其实就是对y张量L2范数,先对y中每一项取平方,之后累加,最后取根号
i=0
while y.data.norm()<1000:
  y = y*2
  i+=1
print(y)
print(i)

结果:

tensor(3.7025)
tensor(3.7025, grad_fn=<SqrtBackward>)
tensor([ 1066.4563, -1511.3652,  -414.6933], grad_fn=<MulBackward0>)
9