PyTorch 源码阅读:为什么把 forward 写成一个 Callable 属性,而不是普通方法?
最近在阅读 PyTorch 核心源码 torch.nn.modules.module.py 时,我注意到了一个非常反直觉,但细想之下极具工程智慧的细节。 在我们的日常认知中,forward 是所有神经网络模块的核心方法。当我们自定义一个模型时,第一件事就是继承 nn.Module 并重写 forward。按照面向对象的惯例,基类 nn.Module 应该定义一个标准的方法签名,或者至少是一个抽象方法。 但在 PyTorch 源码里,情况却有些不同:
ounter(lineounter(lineounter(line# torch/nn/modules/module.pyclass Module:forward: Callable[..., Any] = _forward_unimplemented
可以看到,forward 并没有被定义为类的一个成员方法(def forward(self, ...)),而是被声明为一个类型为 Callable 的类属性,并将一个名为 _forward_unimplemented 的函数赋值给了它。 更有趣的是,在该文件顶部不仅定义了这个默认实现,还留下了一段意味深长的注释:
ounter(lineounter(lineounter(lineounter(lineounter(lineounter(lineounter(lineounter(lineounter(line# Trick mypy into not applying contravariance rules to inputs by defining# forward as a value, rather than a function. See also# https://github.com/python/mypy/issues/8795def _forward_unimplemented(self, *input: Any) -> None:r"""Defines the computation performed at every call.Should be overridden by all subclasses...."""raise NotImplementedError(f"Module [{type(self).__name__}] is missing the required \"forward\" function")
注释直言不讳:Trick mypy into not applying contravariance rules…(为了骗过 mypy,不让它应用逆变规则…) 这行代码背后,其实隐藏着 Python 动态特性、静态类型检查(Static Typing)与深度学习框架工程实践之间的激烈博弈。
1. 核心冲突:签名的多样性 vs. 里氏替换原则
假设 PyTorch 按照教科书式的面向对象规范设计,将 forward 定义为一个普通的实例方法:
ounter(lineounter(lineounter(lineounter(lineclass Module:# 父类可以接受1个任意类型的 Tensor 参数def forward(self, x:Tensor) -> Tensor:return x
按照里氏替换原则(Liskov Substitution Principle, LSP),子类在重写父类方法时,必须保证能安全地替换掉父类。对于参数而言,这遵循逆变(Contravariance)规则:子类接受的参数类型必须是父类参数类型的超集。
如果父类声明 forward 接受一切 Tensor(包括浮点、长整型、布尔等),那么它就向外界做出了一个承诺。子类在继承时,如果将参数类型收窄为特定的 LongTensor,本质上是违背了父类的承诺,因为收窄了接口的兼容性。从类型系统的视角看,这种收窄导致子类能处理的场景变少了,一旦某个函数按照父类的契约传入了普通的 FloatTensor,这个原本合法的调用在子类这里就会失效。
ounter(lineounter(lineounter(lineounter(lineounter(lineounter(lineounter(lineounter(lineounter(lineounter(lineounter(lineclass IntEmbedding(Module):# 子类要求传入 LongTensor,比父类收窄了参数范围def forward(self, x: torch.LongTensor) -> Tensor:return xdef run_strict_test(model:Module):# 根据父类契约传一个浮点 Tensordata = torch.randn(1, 10)return model.forward(data)run_strict_test(IntEmbedding())
这段代码可被 Python 解释器正常运行,但用 MyPy 执行静态类型检查时,就会遇到问题。
ounter(lineounter(lineounter(lineounter(lineounter(linemypy test.pytest.py:102: error: Argument 1 of "forward" is incompatible with supertype "Module"; supertype defines the argument type as "Tensor" [override]test.py:102: note: This violates the Liskov substitution principletest.py:102: note: See https://mypy.readthedocs.io/en/stable/common_issues.html#incompatible-overridesFound 1 error in 1 file (checked 1 source file)
另一个直观的冲突在于参数个数(Arity)。如果基类约定 forward 只接受一个 x,而子类(比如 RNN)需要 (input, hidden),或者 Transformer 需要 (input_ids, attention_mask):
ounter(lineounter(lineounter(lineclass MyRNN(Module):def forward(self, x: Tensor, hidden: Tensor) -> Tensor:return x
如果用 MyPy 扫描这段代码,会直接收到报错。因为新增了必选参数,破坏了父类的调用契约。
ounter(lineounter(lineounter(lineounter(lineounter(lineounter(lineounter(line$ mypy test.pytest.py:10: error: Signature of "forward" incompatible with supertype "Module" [override]test.py:10: note: Superclass:test.py:10: note: def forward(self, x: Tensor) -> Tensortest.py:10: note: Subclass:test.py:10: note: def forward(self, x: Tensor, hidden: Tensor) -> TensorFound 1 errors in 1 file (checked 1 source file)
但在深度学习中,层与层的接口差异是天然存在的。强制统一签名是不可能的,强制所有层都用 (*args, **kwargs) 又会失去代码提示和可读性。PyTorch 将 forward 定义为属性而非方法,正是为了切断这种严格的继承约束。
2. PyTorch 的“欺骗”艺术
为了解决这个问题,PyTorch 并没有选择让运行时妥协,而是选择在类型系统层面做一个“伪装”。 当写下 forward: Callable[..., Any] = _forward_unimplemented 时,发生了一下两件事:
-
骗过静态检查(Type Checker): 在 MyPy 眼里,
forward不再是一个需要严格重写的“方法”,而是一个“属性”。这在 Python 类型社区是一个已知的 Hack,正如注释中引用的 MyPy Issue #8795 所讨论的,Guido van Rossum 本人也参与了讨论:这本质上是通过放宽类型定义来绕过 LSP 检查。
-
父类声明:这个属性是一个 Callable,参数不限(...),返回不限(Any)。 -
子类行为:当在子类写 def forward(self, x, mask):时,MyPy 认为是在用一个具体的函数对象覆盖(Override)这个属性。 -
结果:由于 Callable[..., Any]的包容性极强,无论写成什么签名,都被认为是合法的属性赋值。 -
保持运行时行为(Runtime): 虽然在类型定义上它是属性,但在 Python 运行时,函数是描述符(Descriptor)。当通过实例访问
self.forward时,如果它指向的是一个函数,Python 依然会自动绑定self。 所以,运行时它依然表现得像一个方法,完全不影响我们将模型导出或调用。
示例 A:传统继承
ounter(lineounter(lineounter(lineounter(lineounter(lineounter(lineounter(lineounter(lineclass Parent:def call_me(self, x: int) -> None:passclass Child(Parent):# Error: Signature incompatible with supertypedef call_me(self, x: int, y: int) -> None:pass
MyPy 报错
ounter(lineounter(lineounter(lineounter(lineounter(lineounter(lineounter(line$ mypy test.pytest.py:9: error: Signature of "call_me" incompatible with supertype "StandardParent" [override]test.py:9: note: Superclass:test.py:9: note: def call_me(self, x: int) -> Nonetest.py:9: note: Subclass:test.py:9: note: def call_me(self, x: int, y: int) -> NoneFound 1 error in 1 file (checked 1 source file)
示例 B:PyTorch 的 Hack 方式
ounter(lineounter(lineounter(lineounter(lineounter(lineounter(lineounter(lineounter(lineounter(lineounter(lineounter(lineounter(lineounter(linefrom typing import Callable, Anydef _unimplemented(self, *args: Any) -> None:raise NotImplementedErrorclass Parent:# 声明为一个 Callable 类型的属性call_me: Callable[..., Any] = _unimplementedclass Child(Parent):# 此时随意定义签名,MyPy 不会报错def call_me(self, x: int, y: int) -> None:print(x, y)
MyPy 通过:
ounter(lineounter(line$ mypy test.pySuccess: no issues found in 1 source file
3. 工程视角:为什么不用 @abstractmethod?
有人可能会问,为什么不直接用 abc.ABCMeta 和 @abstractmethod? 除了前面提到的签名冲突问题(抽象方法也需要定义签名),还有一个更现实的工程原因:Hook 机制与调用入口的统一。 仔细看 _forward_unimplemented 的文档字符串,有一句警告:
…one should call the :class:Module instance afterwards instead of this since the former takes care of running the registered hooks…
PyTorch 极度依赖 __call__ 来管理生命周期。请看 module.py 中 __call__ 的实现:
ounter(line__call__ : Callable[..., Any] = _wrapped_call_impl
以及它调用的 _call_impl:
ounter(lineounter(lineounter(lineounter(lineounter(linedef _call_impl(self, *args, **kwargs):forward_call = (self._slow_forward if torch._C._get_tracing_state() else self.forward)# If we don't have any hooks, we want to skip the rest of the logic...if not (self._backward_hooks or ...):return forward_call(*args, **kwargs)
PyTorch 不希望用户直接调用 model.forward(x),而是希望调用 model(x)。 如果使用抽象方法,虽然能强制子类实现 forward,但无法提供这种“默认的防御性编程”。 通过将基类的 forward 指向 _forward_unimplemented,PyTorch 实现了两个目的:
-
友好的报错:如果忘了写 forward,调用时会抛出清晰的NotImplementedError,而不是奇怪的AttributeError。 -
文档提醒:这个默认函数的文档明确告诉用户去调用 model(x),而不是model.forward(x)。
总结
forward: Callable[..., Any] = _forward_unimplemented 这行代码,是 PyTorch 工程团队在 Python 动态灵活性、静态类型检查严格性 以及 框架功能完备性 三者之间找到的一个绝妙平衡点。 它牺牲了一点点代码的直观性(看起来像赋值),换来了:
-
IDE 的智能提示(被标注为 Callable)。 -
自由的参数定义(绕过了 LSP 逆变检查)。 -
运行时的安全性(忘记实现有明确报错)。

夜雨聆风
