乐于分享
好东西不私藏

PyTorch 源码阅读:为什么把 forward 写成一个 Callable 属性,而不是普通方法?

PyTorch 源码阅读:为什么把 forward 写成一个 Callable 属性,而不是普通方法?

最近在阅读 PyTorch 核心源码 torch.nn.modules.module.py 时,我注意到了一个非常反直觉,但细想之下极具工程智慧的细节。 在我们的日常认知中,forward 是所有神经网络模块的核心方法。当我们自定义一个模型时,第一件事就是继承 nn.Module 并重写 forward。按照面向对象的惯例,基类 nn.Module 应该定义一个标准的方法签名,或者至少是一个抽象方法。 但在 PyTorch 源码里,情况却有些不同:

ounter(lineounter(lineounter(line# torch/nn/modules/module.pyclass Module:    forward: Callable[..., Any] = _forward_unimplemented

可以看到,forward 并没有被定义为类的一个成员方法(def forward(self, ...)),而是被声明为一个类型为 Callable 的类属性,并将一个名为 _forward_unimplemented 的函数赋值给了它。 更有趣的是,在该文件顶部不仅定义了这个默认实现,还留下了一段意味深长的注释:

ounter(lineounter(lineounter(lineounter(lineounter(lineounter(lineounter(lineounter(lineounter(line# Trick mypy into not applying contravariance rules to inputs by defining# forward as a value, rather than a function.  See also# https://github.com/python/mypy/issues/8795def _forward_unimplemented(self, *inputAny) -> None:    r"""Defines the computation performed at every call.    Should be overridden by all subclasses.    ...    """    raise NotImplementedError(f"Module [{type(self).__name__}] is missing the required \"forward\" function")

注释直言不讳:Trick mypy into not applying contravariance rules…(为了骗过 mypy,不让它应用逆变规则…) 这行代码背后,其实隐藏着 Python 动态特性、静态类型检查(Static Typing)与深度学习框架工程实践之间的激烈博弈。

1. 核心冲突:签名的多样性 vs. 里氏替换原则

假设 PyTorch 按照教科书式的面向对象规范设计,将 forward 定义为一个普通的实例方法:

ounter(lineounter(lineounter(lineounter(lineclass Module:	  # 父类可以接受1个任意类型的 Tensor 参数    def forward(selfx:Tensor) -> Tensor:        return x

按照里氏替换原则(Liskov Substitution Principle, LSP),子类在重写父类方法时,必须保证能安全地替换掉父类。对于参数而言,这遵循逆变(Contravariance)规则:子类接受的参数类型必须是父类参数类型的超集。

如果父类声明 forward 接受一切 Tensor(包括浮点、长整型、布尔等),那么它就向外界做出了一个承诺。子类在继承时,如果将参数类型收窄为特定的 LongTensor,本质上是违背了父类的承诺,因为收窄了接口的兼容性。从类型系统的视角看,这种收窄导致子类能处理的场景变少了,一旦某个函数按照父类的契约传入了普通的 FloatTensor,这个原本合法的调用在子类这里就会失效。

ounter(lineounter(lineounter(lineounter(lineounter(lineounter(lineounter(lineounter(lineounter(lineounter(lineounter(lineclass IntEmbedding(Module):    # 子类要求传入 LongTensor,比父类收窄了参数范围    def forward(selfx: torch.LongTensor) -> Tensor:        return xdef run_strict_test(model:Module):    # 根据父类契约传一个浮点 Tensor    data = torch.randn(110)    return model.forward(data)run_strict_test(IntEmbedding()) 

这段代码可被 Python 解释器正常运行,但用 MyPy 执行静态类型检查时,就会遇到问题。

ounter(lineounter(lineounter(lineounter(lineounter(linemypy test.pytest.py:102: error: Argument 1 of "forward" is incompatible with supertype "Module"; supertype defines the argument type as "Tensor"  [override]test.py:102: note: This violates the Liskov substitution principletest.py:102: note: See https://mypy.readthedocs.io/en/stable/common_issues.html#incompatible-overridesFound 1 error in 1 file (checked 1 source file)

另一个直观的冲突在于参数个数(Arity)。如果基类约定 forward 只接受一个 x,而子类(比如 RNN)需要 (input, hidden),或者 Transformer 需要 (input_ids, attention_mask)

ounter(lineounter(lineounter(lineclass MyRNN(Module):    def forward(self, x: Tensor, hidden: Tensor) -> Tensor:        return x

如果用 MyPy 扫描这段代码,会直接收到报错。因为新增了必选参数,破坏了父类的调用契约。

ounter(lineounter(lineounter(lineounter(lineounter(lineounter(lineounter(line$ mypy test.pytest.py:10: error: Signature of "forward" incompatible with supertype "Module"  [override]test.py:10: note:      Superclass:test.py:10: note:          def forward(self, x: Tensor) -> Tensortest.py:10: note:      Subclass:test.py:10: note:          def forward(self, x: Tensor, hidden: Tensor) -> TensorFound 1 errors in 1 file (checked 1 source file)

但在深度学习中,层与层的接口差异是天然存在的。强制统一签名是不可能的,强制所有层都用 (*args, **kwargs) 又会失去代码提示和可读性。PyTorch 将 forward 定义为属性而非方法,正是为了切断这种严格的继承约束。

2. PyTorch 的“欺骗”艺术

为了解决这个问题,PyTorch 并没有选择让运行时妥协,而是选择在类型系统层面做一个“伪装”。 当写下 forward: Callable[..., Any] = _forward_unimplemented 时,发生了一下两件事:

  1. 骗过静态检查(Type Checker): 在 MyPy 眼里,forward 不再是一个需要严格重写的“方法”,而是一个“属性”。

    这在 Python 类型社区是一个已知的 Hack,正如注释中引用的 MyPy Issue #8795 所讨论的,Guido van Rossum 本人也参与了讨论:这本质上是通过放宽类型定义来绕过 LSP 检查。

    • 父类声明:这个属性是一个 Callable,参数不限(...),返回不限(Any)。
    • 子类行为:当在子类写 def forward(self, x, mask): 时,MyPy 认为是在用一个具体的函数对象覆盖(Override)这个属性。
    • 结果:由于 Callable[..., Any] 的包容性极强,无论写成什么签名,都被认为是合法的属性赋值。
  2. 保持运行时行为(Runtime): 虽然在类型定义上它是属性,但在 Python 运行时,函数是描述符(Descriptor)。当通过实例访问 self.forward 时,如果它指向的是一个函数,Python 依然会自动绑定 self。 所以,运行时它依然表现得像一个方法,完全不影响我们将模型导出或调用。

示例 A:传统继承

ounter(lineounter(lineounter(lineounter(lineounter(lineounter(lineounter(lineounter(lineclass Parent:    def call_me(self, x: int) -> None:        passclass Child(Parent):    # Error: Signature incompatible with supertype    def call_me(self, x: int, y: int) -> None:        pass

MyPy 报错

ounter(lineounter(lineounter(lineounter(lineounter(lineounter(lineounter(line$ mypy test.pytest.py:9: error: Signature of "call_me" incompatible with supertype "StandardParent"  [override]test.py:9: note:      Superclass:test.py:9: note:          def call_me(self, xint) -> Nonetest.py:9: note:      Subclass:test.py:9: note:          def call_me(self, xint, yint) -> NoneFound 1 error in 1 file (checked 1 source file)

示例 B:PyTorch 的 Hack 方式

ounter(lineounter(lineounter(lineounter(lineounter(lineounter(lineounter(lineounter(lineounter(lineounter(lineounter(lineounter(lineounter(linefrom typing import CallableAnydef _unimplemented(self, *args: Any) -> None:    raise NotImplementedErrorclass Parent:    # 声明为一个 Callable 类型的属性    call_me: Callable[..., Any] = _unimplementedclass Child(Parent):    # 此时随意定义签名,MyPy 不会报错    def call_me(self, x: int, y: int) -> None:        print(x, y)

MyPy 通过:

ounter(lineounter(line$ mypy test.pySuccess: no issues found in 1 source file

3. 工程视角:为什么不用 @abstractmethod?

有人可能会问,为什么不直接用 abc.ABCMeta 和 @abstractmethod? 除了前面提到的签名冲突问题(抽象方法也需要定义签名),还有一个更现实的工程原因:Hook 机制与调用入口的统一。 仔细看 _forward_unimplemented 的文档字符串,有一句警告:

…one should call the :class:Module instance afterwards instead of this since the former takes care of running the registered hooks…

PyTorch 极度依赖 __call__ 来管理生命周期。请看 module.py 中 __call__ 的实现:

ounter(line__call__ : Callable[..., Any] = _wrapped_call_impl

以及它调用的 _call_impl

ounter(lineounter(lineounter(lineounter(lineounter(linedef _call_impl(self, *args, **kwargs):        forward_call = (self._slow_forward if torch._C._get_tracing_state() else self.forward)        # If we don't have any hooks, we want to skip the rest of the logic...        if not (self._backward_hooks or ...):            return forward_call(*args, **kwargs)

PyTorch 不希望用户直接调用 model.forward(x),而是希望调用 model(x)。 如果使用抽象方法,虽然能强制子类实现 forward,但无法提供这种“默认的防御性编程”。 通过将基类的 forward 指向 _forward_unimplemented,PyTorch 实现了两个目的:

  • 友好的报错:如果忘了写 forward,调用时会抛出清晰的 NotImplementedError,而不是奇怪的 AttributeError
  • 文档提醒:这个默认函数的文档明确告诉用户去调用 model(x),而不是 model.forward(x)

总结

forward: Callable[..., Any] = _forward_unimplemented 这行代码,是 PyTorch 工程团队在 Python 动态灵活性静态类型检查严格性 以及 框架功能完备性 三者之间找到的一个绝妙平衡点。 它牺牲了一点点代码的直观性(看起来像赋值),换来了:

  • IDE 的智能提示(被标注为 Callable)。
  • 自由的参数定义(绕过了 LSP 逆变检查)。
  • 运行时的安全性(忘记实现有明确报错)。
本站文章均为手工撰写未经允许谢绝转载:夜雨聆风 » PyTorch 源码阅读:为什么把 forward 写成一个 Callable 属性,而不是普通方法?

评论 抢沙发

1 + 6 =
  • 昵称 (必填)
  • 邮箱 (必填)
  • 网址
×
订阅图标按钮