pytorch 调用forward 的具体流程

forward方法的具体流程:

以一个Module为例:

1. 调用module的call方法

2. module的call里面调用module的forward方法

3. forward里面如果碰到Module的子类,回到第1步,如果碰到的是Function的子类,继续往下

4. 调用Function的call方法

5. Function的call方法调用了Function的forward方法。

6. Function的forward返回值

7. module的forward返回值

8. 在module的call进行forward_hook操作,然后返回值。

上述中“调用module的call方法”是指nn.Module 的__call__方法。定义__call__方法的类可以当作函数调用,具体参考Python的面向对象编程。

也就是说,当把定义的网络模型model当作函数调用的时候就自动调用定义的网络模型的forward方法。nn.Module 的__call__方法部分源码如下所示:

def __call__(self, *input, **kwargs):

result = self.forward(*input, **kwargs)

for hook in self._forward_hooks.values():

#将注册的hook拿出来用

hook_result = hook(self, input, result)

...

return result