pytest的"透视眼":断言重写背后的魔法

前言:一个困扰测试新手的谜题

你是否也曾有过这样的疑惑:

def test_calculate():
    result = calculate(2, 3)
    expected = 5
    assert result == expected

当这个测试失败时,pytest竟然能告诉你:

assert 4 == 5

等等,它怎么知道result的值是4,expected的值是5?难道pytest偷偷在你的代码里插了眼?

更神奇的是,当你写下更复杂的断言:

def test_complex_assertion():
    user_data = {"name": "Alice", "age": 25}
    assert user_data["name"] == "Bob" and user_data["age"] > 30

失败时,它会告诉你:

assert ({"name": "Alice", "age": 25}["name"] == "Bob" and {"name": "Alice", "age": 25}["age"] > 30)

甚至还能拆解成更详细的形式!

今天,我们就来揭开这个"魔法"背后的秘密——pytest断言重写(Assertion Rewriting)

标准assert的"哑巴"表现

在开始之前,先看看标准Python的assert有多"冷漠":

def test_with_standard_assert():
    result = 4
    expected = 5
    assert result == expected

python -m pytest test_file.py运行时,你会看到详细的错误信息。但如果你直接用python test_file.py运行(当然你得先import一个不存在的函数让它报错,或者自己写个测试框架),你会得到:

AssertionError

就这?对,就这。标准Python的assert就像个不善言辞的理工男,只告诉你"错了",但不说"哪里错了"、"为什么错"。

这是因为Python的assert本质上就是一个简单的条件检查,失败时抛出AssertionError异常,仅此而已。

pytest的"偷梁换柱":断言重写

pytest是如何让assert变得如此"话痨"的?答案就在断言重写技术。

工作原理:在字节码层面动手脚

pytest的断言重写工作流程如下:

  1. 拦截导入:当pytest加载测试模块时,它会拦截这个导入过程
  2. 字节码转换:pytest使用Python的ast模块解析源代码为抽象语法树(AST),然后修改AST,再重新编译为字节码
  3. 魔法替换:将原本简单的assert expression替换为更复杂的表达式,包含变量值的捕获和格式化逻辑

让我们看一个简化版的例子:

原始代码

assert x == y

重写后的代码(简化版):

if not (x == y):
    raise AssertionError(
        f"assert {repr(x)} == {repr(y)}\n"
        f"  + where {repr(x)} = {type(x).__name__}(...)\n"
        f"  + and {repr(y)} = {type(y).__name__}(...)"
    )

当然,pytest的实现要复杂得多,但基本思路就是这样——在测试代码真正运行前,先把assert语句"动过手术"。

深入源码:pytest的魔法师们

如果你想亲眼看看pytest是怎么做到的,可以看看pytest源码中的这些关键模块:

  • _pytest.assertion.rewrite:断言重写的核心模块
  • _pytest.assertion.util:断言信息的格式化工具
  • _pytest.assertion.tracer:用于追踪表达式的执行

最有意思的是rewrite.py中的AssertionRewritingHook类,它是一个自定义的导入钩子:

class AssertionRewritingHook(importlib.abc.MetaPathFinder, importlib.abc.Loader):
    # ... 省略细节 ...

    def find_module(self, name, path=None):
        # 检查是否是测试模块,如果是,返回自己作为loader
        if self._should_rewrite(name):
            return self
        return None

    def load_module(self, name):
        # 加载模块时,先重写断言,再返回模块
        source = self._get_source(name)
        rewritten = self._rewrite_assertions(source)
        return self._exec_module(rewritten, name)

当然,实际代码比我写的这个简化版复杂得多,但核心思想就是:在模块导入时,先把代码改一改,再让Python执行修改后的代码

断言重写的实战案例

案例1:简单比较

def test_simple_comparison():
    a = 42
    b = 100
    assert a < b

pytest输出:

assert 42 < 100

案例2:复杂的表达式

def test_complex_expression():
    numbers = [1, 2, 3, 4, 5]
    assert sum(numbers[:3]) * 2 == len(numbers) * 3

pytest输出:

assert 12 == 15

pytest不仅计算了sum(numbers[:3]) * 2的值(12),还计算了len(numbers) * 3的值(15),然后告诉你它们不相等。

案例3:字典比较

def test_dict_comparison():
    actual = {"name": "Alice", "age": 30, "city": "Beijing"}
    expected = {"name": "Alice", "age": 25, "city": "Shanghai"}
    assert actual == expected

pytest输出:

assert {'name': 'Alice', 'age': 30, 'city': 'Beijing'} == {'name': 'Alice', 'age': 25, 'city': 'Shanghai'}
  Common items:
    {'name': 'Alice'}
  Differing items:
    {'age': 30} != {'age': 25}
    {'city': 'Beijing'} != {'city': 'Shanghai'}

pytest甚至还能智能地比较字典,告诉你哪些字段相同,哪些字段不同!

如何禁用断言重写

有时候,你可能不想要断言重写(比如调试时),pytest提供了几种禁用方式:

方法1:命令行参数

pytest --assert=plain test_file.py

这样pytest就会使用标准的Python assert,不再显示变量值。

方法2:在pyproject.toml中配置

[tool.pytest.ini_options]
assertmode = "plain"

方法3:代码中临时禁用

import pytest

def test_something():
    # 这个assert会被重写
    assert True

    # 临时禁用断言重写
    with pytest.raises(AssertionError, match=r"plain assertion"):
        assert False  # 这行不会被重写,因为pytest.raises会阻止重写

断言重写的限制

虽然断言重写很强大,但它也有一些限制:

  1. 仅对测试文件生效:只有符合test_*.py*_test.py命名规则的文件才会被重写
  2. 不支持生成器表达式:在生成器表达式中使用assert不会被重写
  3. 不支持异步函数(虽然pytest 7.0+已经支持异步测试,但断言重写对异步代码的支持有限)
  4. 性能开销:断言重写需要在导入时解析和修改字节码,会有轻微的性能开销(但通常可以忽略不计)

内幕:pytest是如何"偷看"你的变量的

现在,我们来揭晓最终的秘密:pytest究竟是如何在不修改你源代码的情况下,"偷看"到你的变量值的?

答案就是:它在你的代码真正运行前,先把代码改了

当你运行pytest test_file.py时:

  1. pytest会创建一个自定义的导入钩子(AssertionRewritingHook
  2. 当Python尝试导入test_file模块时,pytest拦截了这个导入
  3. pytest读取test_file.py的源代码
  4. pytest使用ast模块解析源代码,得到抽象语法树
  5. pytest遍历AST,找到所有assert语句
  6. pytest将这些assert语句替换为更复杂的表达式,包含变量值的捕获和格式化逻辑
  7. pytest将修改后的AST重新编译为字节码
  8. pytest返回这个修改后的模块给Python
  9. Python执行这个模块时,实际上执行的是被pytest修改过的代码

整个过程就像一个"特工",在你的代码真正执行前,先悄悄地把代码改一改,再让Python执行修改后的代码。

而这一切,对你来说是完全透明的——你只写了一次assert,但pytest让它变成了一个"话痨"。

总结:断言重写的艺术

pytest的断言重写技术是一个精妙的工程,它在不改变你代码风格的前提下,提供了极其有用的调试信息。它的核心思想是:

  1. 利用Python的导入机制:通过自定义导入钩子,拦截模块导入
  2. 操作抽象语法树:在AST层面修改代码,而不是直接修改源代码
  3. 透明的魔法:对用户来说,整个过程是透明的,不需要任何额外操作

这就像是pytest在你不知情的情况下,悄悄给你的assert语句"装了个麦克风",让它能在失败时大喊:"我知道你哪里错了!"

彩蛋:你也可以写一个断言重写器

如果你想挑战一下,可以尝试写一个简单的断言重写器:

import ast

class AssertRewriter(ast.NodeTransformer):
    def visit_Assert(self, node):
        # 简单的重写:添加一条打印语句
        print_stmt = ast.Expr(
            value=ast.Call(
                func=ast.Name(id='print', ctx=ast.Load()),
                args=[ast.Constant(value="断言被执行了!")],
                keywords=[]
            )
        )
        # 在assert之前插入print语句
        ast.fix_missing_locations(print_stmt)
        node.body.insert(0, print_stmt)
        return node

def rewrite_assertions(source_code):
    tree = ast.parse(source_code)
    rewriter = AssertRewriter()
    new_tree = rewriter.visit(tree)
    return ast.unparse(new_tree)

# 测试
original = """
def test_something():
    assert True
"""

rewritten = rewrite_assertions(original)
print(rewritten)

输出:

def test_something():
    print("断言被执行了!")
    assert True

当然,这个简单的例子只是抛砖引玉。真正的pytest断言重写器要复杂得多,它需要处理各种边缘情况,生成详细的错误信息,还要保持性能。

结语

pytest的断言重写技术展示了Python的强大和灵活性——你可以在不修改源代码的情况下,在运行时"修改"代码的行为。这种技术不仅有用,而且很酷。

下次当你看到pytest的详细错误信息时,记得感谢那些幕后英雄——那些在你的代码真正运行前,悄悄把代码改了一下的"特工们"。

Happy Testing! 🚀

声明:本站所有文章,均为本站原创发布。任何个人或组织,在未征得本站同意时,禁止复制、盗用、采集、发布本站内容到任何网站、书籍等各类媒体平台。-- mikigo