Python Assert: Debug Your Code Smarter

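You're debugging a data-processing function. Somewhere between reading the CSV and writing the output, a column that should hold integers has turned into floats, a list that should never be empty is empty, and a user ID that should be positive has become -1. You scatter print() statements everywhere, rerun the script, squint at the output, and repeat. An hour later you find the bug: a function quietly accepted bad input and passed the corrupted data along. The real failure happened fifty lines before the crash, and nothing warned you.

This is a common problem in Python development: bugs propagate silently. A None value travels through three function calls before it finally triggers an AttributeError. A negative index silently wraps around to the wrong element. A dictionary that should have five keys has only four, and the missing key causes a subtle logic error that only shows up in production. By the time the error becomes visible, you have lost all context about where the problem actually started.

Python's assert statement addresses exactly this problem. It catches a bug at the precise point where an assumption breaks and reports it with a clear error message. Instead of hoping bad data will eventually cause an obvious crash, you state your assumptions explicitly and let Python enforce them immediately.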

What Is the Assert Statement?

The assert statement tests a condition. If the condition is true (truthy), nothing happens. If it is false, Python immediately raises an AssertionError.

assert 2 + 2 == 4      # Passes silently
assert 2 + 2 == 5      # Raises AssertionError

The basic syntax is:

assert condition
assert condition, "Error message explaining what went wrong"

The Python language reference defines assert in terms of an if statement. The equivalent form is:

# assert condition, message
# is equivalent to:
if __debug__:
    if not condition:
        raise AssertionError(message)

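The built-in constant __debug__ is True by default. It becomes False only when Python runs with the -O (optimize) flag, or with the PYTHONOPTIMIZE environment variable set. This means assertions can be disabled completely in production. That has important consequences, which we'll discuss later.

Here is what happens when an assertion fails: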

x = -1
assert x >= 0, f"Expected non-negative value, got {x}"

Output:

Traceback (most recent call last):
  File "example.py", line 2, in <module>
    assert x >= 0, f"Expected non-negative value, got {x}"
AssertionError: Expected non-negative value, got -1

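The traceback points to the exact line where the assumption was violated, and the message states the problem clearly. Compare that with a mysterious IndexError twenty lines later, caused by a negative value that propagated unchecked.

Basic Assert Usage

Simple Assertions

The simplest assertions check a single condition: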

# Check that a variable is not None
config = load_config()
assert config is not None
 
# Check that a list is not empty
items = get_items()
assert len(items) > 0
 
# Check a mathematical property
result = calculate_discount(price=100, percent=20)
assert result == 80

Assertions with Custom Messages

Always include a message. Without one, a failing assertion gives you almost no context:

# Bad: no message
assert len(users) > 0
 
# Good: descriptive message
assert len(users) > 0, "User list is empty -- database query may have failed"
 
# Good: include the actual value
assert temperature >= -273.15, f"Temperature {temperature}C is below absolute zero"

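The message is the second expression in the assert statement, after the comma. It can be any expression that produces a string, including an f-string with runtime values: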

def process_batch(items, batch_size):
    assert batch_size > 0, f"batch_size must be positive, got {batch_size}"
    assert len(items) >= batch_size, (
        f"Not enough items: need {batch_size}, have {len(items)}"
    )
    # Process the batch...
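The equivalent if-statement form shown earlier implies that the message expression is evaluated only when the condition is false. An expensive or verbose message therefore costs nothing on the passing path. The following is a minimal sketch of that behavior. `build_message` and `message_calls` are hypothetical names used only for illustration.

```python
message_calls = []


def build_message(value):
    """Hypothetical message builder that records every time it runs."""
    message_calls.append(value)
    return f"Expected a positive value, got {value}"


# Passing assertion: the message expression is never evaluated
assert 5 > 0, build_message(5)

# Failing assertion: now the message is built and attached to the error
caught = None
try:
    assert -3 > 0, build_message(-3)
except AssertionError as exc:
    caught = str(exc)

# With assertions enabled (no -O flag):
print(message_calls)  # [-3]
print(caught)         # Expected a positive value, got -3
```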

Parenthesized Assertions: A Common Pitfall

There is a subtle bug that regularly catches Python developers:

# WARNING: This assertion NEVER fails!
assert(condition, "error message")

This creates the tuple (condition, "error message"). A non-empty tuple is always truthy, so the assertion always passes. Python even warns you:

SyntaxWarning: assertion is always true, perhaps remove parentheses?

The correct forms:

# Correct: no parentheses
assert condition, "error message"
 
# Also correct: parentheses only around the condition
assert (condition), "error message"
 
# Also correct: multi-line with implicit line continuation
assert (
    very_long_condition_that_needs_wrapping
), "error message"

Assert with Complex Conditions

Multiple Conditions

You can combine conditions with and, or, and not:

def create_user(name, age, email):
    assert name and isinstance(name, str), f"Invalid name: {name!r}"
    assert 0 < age < 150, f"Invalid age: {age}"
    assert "@" in email and "." in email, f"Invalid email format: {email}"
 
    # Proceed with user creation...

Type Checking with isinstance

During development, isinstance assertions can verify data types:

def calculate_mean(values):
    assert isinstance(values, (list, tuple)), (
        f"Expected list or tuple, got {type(values).__name__}"
    )
    assert all(isinstance(v, (int, float)) for v in values), (
        "All values must be numeric"
    )
    assert len(values) > 0, "Cannot calculate mean of empty sequence"
 
    return sum(values) / len(values)

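If you need type checking that holds in production, consider Python type hints with a static type checker such as mypy, which verifies types before the code ever runs. Assertions are meant to catch bugs during development, not to enforce type constraints at runtime.

Container and Collection Checks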

# Check dictionary has required keys
required_keys = {"name", "email", "role"}
assert required_keys.issubset(user_data.keys()), (
    f"Missing keys: {required_keys - user_data.keys()}"
)
 
# Check list contains no duplicates
ids = [item.id for item in items]
assert len(ids) == len(set(ids)), (
    f"Duplicate IDs found: {[x for x in ids if ids.count(x) > 1]}"
)
 
# Check that all elements satisfy a condition
scores = [85, 92, 78, 95, 88]
assert all(0 <= s <= 100 for s in scores), (
    f"Scores out of range: {[s for s in scores if not 0 <= s <= 100]}"
)
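The duplicate-ID message above calls ids.count() for every element. That is quadratic in the list length, although it only runs when the assertion fails. It also lists a repeated ID once per occurrence. The sketch below is one linear-time alternative that uses collections.Counter. The helper name and the sample IDs are made up for illustration.

```python
from collections import Counter


def duplicate_ids(ids):
    """Return {id: count} for every id that appears more than once."""
    return {value: count for value, count in Counter(ids).items() if count > 1}


ids = [101, 102, 103, 102, 104, 101, 102]  # made-up sample data

dupes = duplicate_ids(ids)
print(dupes)  # {101: 2, 102: 3}

# Same check as above, with a message that lists each duplicate once
try:
    assert not dupes, f"Duplicate IDs found (id: count): {dupes}"
except AssertionError as exc:
    print(exc)
```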

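Common Debugging Patterns with Assert

Function Preconditions

Preconditions verify that a function has received valid inputs before it starts working. Put them at the top of the function: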

def transfer_money(from_account, to_account, amount):
    # Preconditions
    assert from_account != to_account, "Cannot transfer to the same account"
    assert amount > 0, f"Transfer amount must be positive, got {amount}"
    assert from_account.balance >= amount, (
        f"Insufficient funds: balance={from_account.balance}, transfer={amount}"
    )
 
    from_account.balance -= amount
    to_account.balance += amount

Function Postconditions

Postconditions verify that a function produced correct output before it returns. Put them just before the return statement:

def sort_descending(items):
    result = sorted(items, reverse=True)
 
    # Postconditions
    assert len(result) == len(items), "Sort changed the number of elements"
    assert all(result[i] >= result[i+1] for i in range(len(result)-1)), (
        "Result is not sorted in descending order"
    )
 
    return result

Loop Invariants

Loop invariants verify that a condition holds on every iteration of a loop. They can catch off-by-one errors, infinite loops, and logic bugs:

def binary_search(sorted_list, target):
    low = 0
    high = len(sorted_list) - 1
 
    while low <= high:
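        # Loop invariant: if target is present, it lies in sorted_list[low:high+1].
        # Checking that fully costs O(n), so this assert checks only the cheap part:
        # the search bounds stay inside the list.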
        assert low >= 0 and high < len(sorted_list), (
            f"Bounds out of range: low={low}, high={high}, len={len(sorted_list)}"
        )
 
        mid = (low + high) // 2
        if sorted_list[mid] == target:
            return mid
        elif sorted_list[mid] < target:
            low = mid + 1
        else:
            high = mid - 1
 
    return -1
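The bounds check above covers only the cheap part of the invariant. During development you can also assert the full invariant, which says the target, if present, lies in sorted_list[low:high+1]. You can add a progress condition as well, requiring that the search window shrinks on every pass. The progress check is what actually catches a loop that would never terminate. The sketch below does both. The full check is O(n) per iteration, so it belongs only in debugging runs. `binary_search_checked` is a hypothetical name.

```python
def binary_search_checked(sorted_list, target):
    """Binary search with development-time invariant and progress assertions."""
    low, high = 0, len(sorted_list) - 1

    while low <= high:
        # Full invariant (O(n)): if target is present, it lies within [low, high]
        assert target not in sorted_list or target in sorted_list[low:high + 1], (
            f"Target fell outside the window: low={low}, high={high}"
        )
        window = high - low

        mid = (low + high) // 2
        if sorted_list[mid] == target:
            return mid
        elif sorted_list[mid] < target:
            low = mid + 1
        else:
            high = mid - 1

        # Progress check: the window must shrink, otherwise the loop never ends
        assert high - low < window, f"No progress: low={low}, high={high}"

    return -1


data = [2, 3, 5, 7, 11, 13]
print(binary_search_checked(data, 7))  # 3
print(binary_search_checked(data, 4))  # -1
```

If someone "optimizes" the update to low = mid, the progress assertion fails as soon as the window stops shrinking. Without it, the loop would hang.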

Class Invariants

Class invariants verify that an object's internal state is still consistent after every operation:

class BoundedQueue:
    """A queue with a maximum capacity."""
 
    def __init__(self, capacity):
        assert capacity > 0, f"Capacity must be positive, got {capacity}"
        self._items = []
        self._capacity = capacity
        self._check_invariant()
 
    def _check_invariant(self):
        assert 0 <= len(self._items) <= self._capacity, (
            f"Queue size {len(self._items)} violates capacity {self._capacity}"
        )
 
    def enqueue(self, item):
        assert len(self._items) < self._capacity, "Queue is full"
        self._items.append(item)
        self._check_invariant()
 
    def dequeue(self):
        assert len(self._items) > 0, "Queue is empty"
        item = self._items.pop(0)
        self._check_invariant()
        return item
 
    def __len__(self):
        return len(self._items)
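Calling self._check_invariant() by hand at the end of every method is easy to forget. One option is a small decorator that runs the check after each decorated method returns. The sketch below assumes the class defines a _check_invariant() method, as BoundedQueue does. `checks_invariant` and `BoundedCounter` are hypothetical names used for illustration.

```python
import functools


def checks_invariant(method):
    """Run self._check_invariant() after the decorated method returns."""
    @functools.wraps(method)
    def wrapper(self, *args, **kwargs):
        result = method(self, *args, **kwargs)
        self._check_invariant()
        return result
    return wrapper


class BoundedCounter:
    """Hypothetical example class: a counter that must stay in [0, limit]."""

    def __init__(self, limit):
        assert limit > 0, f"limit must be positive, got {limit}"
        self.limit = limit
        self.value = 0

    def _check_invariant(self):
        assert 0 <= self.value <= self.limit, (
            f"value {self.value} outside [0, {self.limit}]"
        )

    @checks_invariant
    def increment(self, step=1):
        self.value += step  # deliberately unchecked: the invariant catches overshoot

    @checks_invariant
    def reset(self):
        self.value = 0


counter = BoundedCounter(limit=3)
counter.increment(2)
try:
    counter.increment(5)  # pushes value to 7, violating the invariant
except AssertionError as exc:
    print(exc)  # value 7 outside [0, 3]
```

The check runs after the method body, so it detects corrupted state but does not roll it back. After the failure above, counter.value is still 7.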

Assert vs. Raise: When to Use Which

This is one of the most important distinctions in Python error handling. assert and raise serve fundamentally different purposes.

| Feature | assert | raise |
|---|---|---|
| Purpose | Catch programmer errors (bugs) | Handle runtime conditions (expected failures) |
| Can be disabled | Yes, with -O flag | No, always active |
| Use for input validation | Never | Yes |
| Use for external data | Never | Yes |
| Typical exception | AssertionError | ValueError, TypeError, RuntimeError, etc. |
| When it fires | Something is wrong with the code | Something is wrong with the input/environment |
| Audience | The developer | The user or calling code |
| Presence in production | Should not be relied upon | Required |

Use assert for internal invariants and developer assumptions

def _calculate_tax(income, brackets):
    # Developer assumption: brackets are sorted
    assert all(
        brackets[i][0] <= brackets[i+1][0]
        for i in range(len(brackets) - 1)
    ), "Tax brackets must be sorted by threshold"
 
    # This is a bug in the code if brackets aren't sorted,
    # not a user input problem
    ...

Use raise for input validation and expected error conditions

def create_account(username, password):
    if not username or len(username) < 3:
        raise ValueError("Username must be at least 3 characters")
    if len(password) < 8:
        raise ValueError("Password must be at least 8 characters")
 
    # These are user input problems, not programmer bugs.
    # They must ALWAYS be checked, even in production.
    ...

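The key difference: if someone runs python -O your_script.py, every assert statement is removed entirely. If you used assert for input validation, that validation vanishes in optimized mode. This is not a theoretical risk. Whoever runs the code chooses the optimization level, and some deployment tools and production environments enable -O. For a deeper look at exception-handling patterns, see the Python try/except guide.

Rule of Thumb

Ask yourself: "If this check were removed completely, could a user cause a security problem or corrupt data?" If yes, use raise. If the check only catches developer mistakes (bugs in the code itself), use assert.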

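Assert in Tests

Assertions are the foundation of testing in Python. Both unittest and pytest rely on them to verify expected behavior.

pytest Assertions

pytest uses plain assert statements instead of special assertion methods. This is one of its biggest strengths: you write natural Python instead of memorizing a list of method names: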

# test_math.py
def test_addition():
    assert 2 + 2 == 4
 
def test_string_methods():
    greeting = "hello world"
    assert greeting.upper() == "HELLO WORLD"
    assert greeting.split() == ["hello", "world"]
 
def test_list_operations():
    items = [1, 2, 3]
    items.append(4)
    assert len(items) == 4
    assert items[-1] == 4

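pytest's Assert Rewriting

What sets pytest apart is assert rewriting. When a plain assert without a message fails, Python reports only AssertionError. pytest rewrites the assert statements in your test modules at import time to produce detailed failure reports: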

def test_comparison():
    result = {"name": "Alice", "age": 30}
    expected = {"name": "Alice", "age": 31}
    assert result == expected

pytest output:

FAILED test_example.py::test_comparison - AssertionError: assert {'age': 30, 'name': 'Alice'} == {'age': 31, 'name': 'Alice'}
  Differing items:
  {'age': 30} != {'age': 31}

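Without pytest's rewriting, you would typically see a bare AssertionError with no details. The rewriting works through an import hook. When a test module is imported, pytest turns its assert statements into more detailed checks that capture intermediate values.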
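You can see the difference by running a plain assert as an ordinary script rather than under pytest. A failing assert without a message raises an AssertionError that carries no arguments, so its string form is empty. This minimal sketch reuses the dictionaries from the pytest example.

```python
result = {"name": "Alice", "age": 30}
expected = {"name": "Alice", "age": 31}

try:
    assert result == expected
except AssertionError as exc:
    error = exc

print(repr(error.args))  # ()
print(repr(str(error)))  # ''
```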

Common pytest Assertion Patterns

# Check that an exception is raised
import pytest
 
def test_division_by_zero():
    with pytest.raises(ZeroDivisionError):
        1 / 0
 
def test_invalid_input():
    with pytest.raises(ValueError, match="must be positive"):
        create_user(age=-5)
 
# Check approximate equality (for floats)
def test_float_calculation():
    result = 0.1 + 0.2
    assert result == pytest.approx(0.3)
 
# Check that a value is in a collection
def test_membership():
    valid_statuses = {"active", "inactive", "pending"}
    user_status = get_user_status(user_id=42)
    assert user_status in valid_statuses
 
# Check with custom message
def test_data_integrity():
    records = load_records()
    assert len(records) > 0, "No records loaded -- check database connection"

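Assertions in unittest

The unittest module provides method-based assertions instead of plain assert. Each method knows what kind of comparison it performs, so it can produce a detailed error message without pytest-style rewriting: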

import unittest
 
class TestStringMethods(unittest.TestCase):
    def test_upper(self):
        self.assertEqual("hello".upper(), "HELLO")
 
    def test_contains(self):
        self.assertIn("world", "hello world")
 
    def test_raises(self):
        with self.assertRaises(TypeError):
            "hello" + 5

Both approaches work. pytest's plain assert is more readable and more Pythonic. unittest's assertion methods give detailed messages without relying on import-time rewriting.
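To see what unittest's messages look like without any test runner, call an assertion method on a bare unittest.TestCase() instance and capture the error. Creating an instance without a method name is allowed since Python 3.2. This small sketch reuses the dictionaries from the pytest example.

```python
import unittest

# A bare TestCase instance is enough to call assertion methods directly
checker = unittest.TestCase()

message = None
try:
    checker.assertEqual({"name": "Alice", "age": 30}, {"name": "Alice", "age": 31})
except AssertionError as exc:
    message = str(exc)

# The first line shows both dicts separated by "!=", followed by a diff
print(message)
```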

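Interactive Testing of Assertions

When developing and debugging test assertions interactively, tools like RunCell let you run individual test cells in Jupyter notebooks and get immediate feedback. This is especially useful when you build up complex assertion conditions step by step. You can test each assertion on its own before combining them into a full test suite.

When Not to Use Assert

This section is critical. Misusing assert creates subtle, dangerous bugs that only appear in production.

Never Use Assert for Input Validation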

# Both versions are methods of an account class with a balance attribute

# WRONG: This check disappears with python -O
def withdraw(self, amount):
    assert amount > 0, "Amount must be positive"
    self.balance -= amount

# RIGHT: This check always runs
def withdraw(self, amount):
    if amount <= 0:
        raise ValueError("Amount must be positive")
    self.balance -= amount

Never Use Assert to Validate External Data

Data from users, files, networks, databases, or APIs can be malformed. These checks must always run:

# WRONG: Network data validation with assert
def handle_api_response(response):
    assert response.status_code == 200
    data = response.json()
    assert "results" in data
 
# RIGHT: Proper error handling for external data
def handle_api_response(response):
    if response.status_code != 200:
        raise RuntimeError(f"API returned status {response.status_code}")
    data = response.json()
    if "results" not in data:
        raise ValueError("API response missing 'results' field")

Never Use Assert for Security Checks

# CATASTROPHICALLY WRONG: Security check with assert
def delete_user(requesting_user, target_user_id):
    assert requesting_user.is_admin, "Only admins can delete users"
    database.delete(target_user_id)
 
# RIGHT: Security check that cannot be disabled
def delete_user(requesting_user, target_user_id):
    if not requesting_user.is_admin:
        raise PermissionError("Only admins can delete users")
    database.delete(target_user_id)

python -O 下,断言版会让任何用户删除任何其他用户。这是真实的安全漏洞。

绝不要在 Assert 中使用副作用

因为断言可以被禁用,断言中的表达式绝不应该包含副作用:

# WRONG: The pop() is a side effect that disappears with -O
assert items.pop() == expected_value
 
# RIGHT: Separate the side effect from the assertion
value = items.pop()
assert value == expected_value
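You can watch the side effect vanish without leaving the interpreter. The built-in compile() accepts an optimize argument, and optimize=1 compiles code the way python -O does. The following is a minimal sketch. SOURCE and run are illustrative names.

```python
SOURCE = """
items = [1, 2, 3]
assert items.pop() == 3
"""


def run(optimize):
    """Compile and execute SOURCE at the given optimization level; return its namespace."""
    namespace = {}
    exec(compile(SOURCE, "<side_effect_demo>", "exec", optimize=optimize), namespace)
    return namespace


print(run(optimize=0)["items"])  # [1, 2]     the assert ran, so pop() ran
print(run(optimize=1)["items"])  # [1, 2, 3]  the assert was removed, so pop() never ran
```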

The -O Flag: How Assertions Disappear

Python has two optimization levels that affect assertions:

python script.py        # Normal: __debug__ is True, assertions active
python -O script.py     # Optimize: __debug__ is False, assertions removed
python -OO script.py    # Extra optimize: assertions removed + docstrings removed

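When Python runs with -O, the interpreter sets __debug__ to False and leaves every assert statement out of the compiled bytecode. The statements are not merely skipped; they do not exist at all. The condition is never evaluated, and the message is never built.

You can verify this: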

# check_debug.py
print(f"__debug__ = {__debug__}")

if __debug__:
    print("Assertions are ACTIVE")
else:
    print("Assertions are DISABLED")

assert False, "This should raise an error"

Running it with and without -O:

$ python check_debug.py
__debug__ = True
Assertions are ACTIVE
Traceback (most recent call last):
  File "check_debug.py", line 9, in <module>
    assert False, "This should raise an error"
AssertionError: This should raise an error

$ python -O check_debug.py
__debug__ = False
Assertions are DISABLED
# No error! The assert was completely removed.

Where -O Is Used in Practice

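  • Docker images: some production Dockerfiles set PYTHONOPTIMIZE=1 or run python -O
  • Deployment tools: WSGI servers and process managers can be configured to run Python in optimized mode
  • Performance-sensitive applications: removing assertions can speed up tight loops
  • Library code: a library must never assume assertions are enabled, because the optimization level is controlled by whoever runs it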
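If a script or test suite depends on assertions, it can check at runtime whether they are active. The following is a minimal sketch. sys.flags.optimize and the __debug__ constant are standard Python; the two helper functions are hypothetical names for illustration.

```python
import sys


def optimization_level():
    """Return 0 normally, 1 under -O (or PYTHONOPTIMIZE=1), 2 under -OO."""
    return sys.flags.optimize


def assertions_enabled():
    """__debug__ is True exactly when the interpreter is not optimizing."""
    return __debug__


print(f"optimize level: {optimization_level()}, assertions enabled: {assertions_enabled()}")
```

A test runner's startup code could, for example, refuse to run when assertions_enabled() returns False. Otherwise every assert-based check would silently pass.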

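What This Means for Your Code

Think of assertions as scaffolding on a construction site. They support the structure while you build, but they come down once the building is finished. Your code must run correctly whether the assertions are present or not.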

# This code works correctly with or without assertions:
def safe_divide(a, b):
    assert isinstance(a, (int, float)), f"Expected number, got {type(a)}"
    assert isinstance(b, (int, float)), f"Expected number, got {type(b)}"
 
    if b == 0:
        raise ValueError("Division by zero")
    return a / b

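The assertions help you catch bugs during development. The raise handles an expected error in production. Each has its own job.

Custom Assertion Helpers

When you find yourself writing the same assertion pattern over and over, extract it into a reusable helper function.

Simple Assertion Functions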

def assert_positive(value, name="value"):
    """Assert that a value is a positive number."""
    assert isinstance(value, (int, float)), (
        f"{name} must be a number, got {type(value).__name__}"
    )
    assert value > 0, f"{name} must be positive, got {value}"
 
 
def assert_valid_probability(p, name="probability"):
    """Assert that a value is a valid probability (0 to 1)."""
    assert isinstance(p, (int, float)), (
        f"{name} must be a number, got {type(p).__name__}"
    )
    assert 0 <= p <= 1, f"{name} must be between 0 and 1, got {p}"
 
 
def assert_same_length(*sequences, names=None):
    """Assert that all sequences have the same length."""
    lengths = [len(s) for s in sequences]
    if names:
        details = ", ".join(f"{n}={l}" for n, l in zip(names, lengths))
    else:
        details = ", ".join(str(l) for l in lengths)
    assert len(set(lengths)) == 1, (
        f"Length mismatch: {details}"
    )
 
 
# Usage
def calculate_weighted_average(values, weights):
    assert_same_length(values, weights, names=["values", "weights"])
    assert_valid_probability(sum(weights) / len(weights), "average weight")
 
    return sum(v * w for v, w in zip(values, weights)) / sum(weights)

Decorator-Based Assertions

You can use decorators to add precondition and postcondition checks to a function without cluttering its body:

import functools
import inspect

def preconditions(**checks):
    """Decorator that asserts preconditions on function arguments."""
    def decorator(func):
        sig = inspect.signature(func)  # computed once, at decoration time

        @functools.wraps(func)  # keep the wrapped function's name and docstring
        def wrapper(*args, **kwargs):
            # Map positional and keyword arguments to parameter names
            bound = sig.bind(*args, **kwargs)
            bound.apply_defaults()

            for param_name, check_func in checks.items():
                value = bound.arguments[param_name]
                assert check_func(value), (
                    f"Precondition failed for '{param_name}': "
                    f"got {value!r}"
                )
            return func(*args, **kwargs)
        return wrapper
    return decorator
 
 
@preconditions(
    x=lambda v: isinstance(v, (int, float)) and v >= 0,
    n=lambda v: isinstance(v, int) and v > 0
)
def nth_root(x, n):
    """Calculate the nth root of x."""
    return x ** (1 / n)
 
 
# This passes
print(nth_root(27, 3))  # 3.0
 
# This fails with a clear message
print(nth_root(-1, 2))  # AssertionError: Precondition failed for 'x': got -1
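The same idea works for postconditions on return values. Below is a sketch of a companion decorator. `postcondition`, `absolute_difference`, and `buggy_difference` are hypothetical names, not part of any library.

```python
import functools


def postcondition(check, description):
    """Decorator that asserts check(result) on the function's return value."""
    def decorator(func):
        @functools.wraps(func)  # keep the wrapped function's name and docstring
        def wrapper(*args, **kwargs):
            result = func(*args, **kwargs)
            assert check(result), (
                f"Postcondition failed for '{func.__name__}' ({description}): "
                f"got {result!r}"
            )
            return result
        return wrapper
    return decorator


@postcondition(lambda r: r >= 0, "result must be non-negative")
def absolute_difference(a, b):
    return abs(a - b)


@postcondition(lambda r: r >= 0, "result must be non-negative")
def buggy_difference(a, b):
    return a - b  # deliberate bug: forgets abs()


print(absolute_difference(3, 10))  # 7

failure_message = None
try:
    buggy_difference(3, 10)
except AssertionError as exc:
    failure_message = str(exc)
print(failure_message)
```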

Context Manager for Assertion Groups

Use this when you need to check several related conditions and report every failure at once:

class AssertionGroup:
    """Collect multiple assertion failures and report them together."""

    def __init__(self, description=""):
        self.description = description
        self.failures = []

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # Verify only if the block finished normally; let other exceptions propagate
        if exc_type is None:
            self.verify()
        return False

    def check(self, condition, message):
        if not condition:
            self.failures.append(message)

    def verify(self):
        if self.failures:
            header = f"{self.description}: " if self.description else ""
            details = "\n  - ".join(self.failures)
            assert False, f"{header}{len(self.failures)} checks failed:\n  - {details}"


# Usage
def validate_user_record(record):
    with AssertionGroup("User record validation") as checks:
        checks.check("name" in record, "Missing 'name' field")
        checks.check("email" in record, "Missing 'email' field")
        checks.check(
            record.get("age", 0) > 0,
            f"Invalid age: {record.get('age')}"
        )
        checks.check(
            "@" in record.get("email", ""),
            f"Invalid email: {record.get('email')}"
        )
    # Leaving the with block raises one AssertionError listing every failure

Handling AssertionError with Try/Except

You can catch AssertionError like any other exception, although in application code this is usually not a good idea:

try:
    assert len(data) > 0, "Data is empty"
    process(data)
except AssertionError as e:
    print(f"Assertion failed: {e}")
    # Handle the failure...

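When Catching AssertionError Makes Sense

There are a few legitimate uses:

1. Test frameworks: pytest and unittest catch AssertionError so they can report a test failure instead of crashing.

2. Logging assertion failures in long-running processes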

import logging
 
logger = logging.getLogger(__name__)
 
def process_records(records):
    failed = []
    for record in records:
        try:
            assert_valid_record(record)
            process(record)
        except AssertionError as e:
            logger.error(f"Skipping invalid record: {e}")
            failed.append(record)
 
    if failed:
        logger.warning(f"{len(failed)} records failed validation")
    return failed

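In production logging patterns like this, catching AssertionError should go hand in hand with proper exception handling. That way failures stay visible instead of bringing down the whole pipeline. Keep in mind that under python -O any assert statements inside assert_valid_record disappear, so invalid records would no longer be skipped. If skipping them matters in production, validate with raise instead.

3. Graceful degradation in non-critical paths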

def generate_report(data):
    report = {"data": data, "charts": []}
 
    try:
        assert len(data) >= 10, "Not enough data for chart"
        chart = create_chart(data)
        report["charts"].append(chart)
    except AssertionError:
        report["charts_note"] = "Insufficient data for visualization"
 
    return report

When Not to Catch AssertionError

Don't catch AssertionError to quietly suppress bugs. The whole point of an assertion is to make bugs loud:

# WRONG: Silencing assertions defeats their purpose
try:
    assert user.is_valid()
except AssertionError:
    pass  # Who cares?

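Real-World Examples

Data Pipeline Validation

Assertions are especially valuable in data pipelines, because each transformation should preserve certain properties: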

import pandas as pd
 
def clean_sales_data(df):
    """Clean and validate sales data."""
    assert isinstance(df, pd.DataFrame), f"Expected DataFrame, got {type(df)}"
    assert len(df) > 0, "DataFrame is empty"
 
    initial_rows = len(df)

    # Validate required columns first: drop_duplicates below needs "order_id"
    required = {"order_id", "product", "quantity", "price"}
    assert required.issubset(df.columns), (
        f"Missing columns: {required - set(df.columns)}"
    )

    # Remove duplicates
    df = df.drop_duplicates(subset=["order_id"])
    assert len(df) > 0, "All rows were duplicates"
 
    # Clean numeric columns
    df["quantity"] = pd.to_numeric(df["quantity"], errors="coerce")
    df["price"] = pd.to_numeric(df["price"], errors="coerce")
 
    # Drop rows with invalid numbers
    df = df.dropna(subset=["quantity", "price"])
 
    # Postcondition: all prices and quantities are positive
    assert (df["price"] > 0).all(), (
        f"Found {(df['price'] <= 0).sum()} non-positive prices"
    )
    assert (df["quantity"] > 0).all(), (
        f"Found {(df['quantity'] <= 0).sum()} non-positive quantities"
    )
 
    # Calculate total
    df["total"] = df["quantity"] * df["price"]
    assert (df["total"] > 0).all(), "Totals must be positive"
 
    print(f"Cleaned {initial_rows} -> {len(df)} rows")
    return df

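When you work with DataFrames in data science workflows, PyGWalker can turn a validated DataFrame into an interactive visualization for further exploration. That is a natural next step once your pipeline assertions confirm the data is clean and ready for analysis.

API Response Checks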

import requests
 
def fetch_user_profile(user_id):
    """Fetch user profile from API with defensive assertions."""
    response = requests.get(f"https://api.example.com/users/{user_id}", timeout=10)
 
    # Use raise for external data validation (not assert!)
    if response.status_code != 200:
        raise RuntimeError(f"API error: {response.status_code}")
 
    data = response.json()
    if "user" not in data:
        raise ValueError("API response missing 'user' field")
 
    user = data["user"]
 
    # Use assert for internal invariants -- things that should
    # always be true if the API contract is correct
    assert "id" in user, "API contract violation: user missing 'id'"
    assert user["id"] == user_id, (
        f"API returned wrong user: requested {user_id}, got {user['id']}"
    )
 
    return user

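Note the split. raise handles expected error conditions such as network problems and bad status codes. assert catches problems that point to a bug in the API itself, or to a wrong assumption your code makes about the API.

Sanity Checks for Machine Learning Models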

import numpy as np
 
def train_model(X_train, y_train, X_test, y_test):
    """Train a model with sanity checks at each stage."""
 
    # Data shape assertions
    assert X_train.ndim == 2, f"X_train must be 2D, got {X_train.ndim}D"
    assert y_train.ndim == 1, f"y_train must be 1D, got {y_train.ndim}D"
    assert X_train.shape[0] == y_train.shape[0], (
        f"Sample count mismatch: X={X_train.shape[0]}, y={y_train.shape[0]}"
    )
    assert X_train.shape[1] == X_test.shape[1], (
        f"Feature count mismatch: train={X_train.shape[1]}, test={X_test.shape[1]}"
    )
 
    # No NaN or Inf in data
    assert not np.isnan(X_train).any(), "X_train contains NaN values"
    assert not np.isinf(X_train).any(), "X_train contains Inf values"
 
    # Labels are valid
    unique_labels = np.unique(y_train)
    assert len(unique_labels) >= 2, (
        f"Need at least 2 classes, got {len(unique_labels)}"
    )
 
    # Train the model (fit() is a placeholder for your own training routine)
    model = fit(X_train, y_train)
 
    # Predictions sanity check
    predictions = model.predict(X_test)
    assert predictions.shape == y_test.shape, (
        f"Prediction shape {predictions.shape} != target shape {y_test.shape}"
    )
    assert set(predictions).issubset(set(unique_labels)), (
        f"Model predicted unknown labels: {set(predictions) - set(unique_labels)}"
    )
 
    # Accuracy sanity check: fail only if clearly worse than random guessing (20% margin)
    accuracy = np.mean(predictions == y_test)
    random_baseline = 1 / len(unique_labels)
    assert accuracy > random_baseline * 0.8, (
        f"Accuracy {accuracy:.2%} is worse than random ({random_baseline:.2%})"
    )
 
    return model

State Machine Transitions

class OrderStateMachine:
    VALID_TRANSITIONS = {
        "created": {"confirmed", "cancelled"},
        "confirmed": {"shipped", "cancelled"},
        "shipped": {"delivered", "returned"},
        "delivered": {"returned"},
        "cancelled": set(),
        "returned": set(),
    }
 
    def __init__(self):
        self.state = "created"
        self.history = ["created"]
 
    def transition(self, new_state):
        assert new_state in self.VALID_TRANSITIONS.get(self.state, set()), (
            f"Invalid transition: {self.state} -> {new_state}. "
            f"Valid transitions: {self.VALID_TRANSITIONS[self.state]}"
        )
 
        self.state = new_state
        self.history.append(new_state)
 
        # Invariant: history should always start with "created"
        assert self.history[0] == "created", "History corrupted"
        # Invariant: current state should match last history entry
        assert self.state == self.history[-1], "State/history mismatch"

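Performance Considerations

How Much Do Assertions Cost?

Assertions have a small but real overhead. Every time an assertion executes, its condition is evaluated. For a simple check like assert x > 0, the cost is negligible. For expensive checks, the cost can add up: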

import time
 
data = list(range(1_000_000))
 
# Fast assertion: O(1)
start = time.perf_counter()
for _ in range(10_000):
    assert len(data) > 0
fast_time = time.perf_counter() - start
print(f"Simple assertion: {fast_time:.4f}s")
 
# Slow assertion: O(n) -- checks every element
start = time.perf_counter()
for _ in range(100):
    assert all(isinstance(x, int) for x in data)
slow_time = time.perf_counter() - start
print(f"Expensive assertion: {slow_time:.4f}s")

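Strategies for Expensive Assertions

If an assertion is too slow inside a tight loop, consider these approaches:

1. Check a sample of the data instead of the whole dataset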

import random
 
def process_large_dataset(records):
    # Check a random sample instead of all records
    sample = random.sample(records, min(100, len(records)))
    assert all(is_valid(r) for r in sample), "Invalid records found in sample"
    # Process all records...

2. Use the __debug__ flag to guard expensive checks

import numpy as np

def matrix_multiply(a, b):
    if __debug__:
        # This entire block is removed with python -O
        assert a.shape[1] == b.shape[0], (
            f"Incompatible shapes: {a.shape} x {b.shape}"
        )
        # Expensive but helpful during development
        assert not np.isnan(a).any(), "Matrix a contains NaN"
        assert not np.isnan(b).any(), "Matrix b contains NaN"
 
    return a @ b
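Under -O the compiler drops whole `if __debug__:` blocks, not just individual assert statements. The sketch below demonstrates this with the built-in compile() and optimize=1, which compiles code as python -O would. SOURCE and run are illustrative names.

```python
SOURCE = """
checks_run = []
if __debug__:
    checks_run.append("expensive check")
"""


def run(optimize):
    """Compile and execute SOURCE; return the list of checks that actually ran."""
    namespace = {}
    exec(compile(SOURCE, "<debug_block_demo>", "exec", optimize=optimize), namespace)
    return namespace["checks_run"]


print(run(optimize=0))  # ['expensive check']
print(run(optimize=1))  # []
```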

3. Assert at the boundaries, not inside inner loops

def process_batch(items):
    # Assert once at the boundary
    assert all(item.is_valid() for item in items), "Invalid items in batch"
 
    # Inner loop without assertions for performance
    results = []
    for item in items:
        # No assertions here -- we validated above
        result = transform(item)
        results.append(result)
 
    # Assert once at the output boundary
    assert len(results) == len(items), "Result count mismatch"
    return results

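Best Practices Summary

Here are the key principles for using assertions effectively:

1. Always include a message. A failing assert x > 0 tells you very little. assert x > 0, f"Expected positive value, got {x}" tells you everything you need.

2. Never use assert for input validation. User input, file contents, API responses, and database query results can all be wrong. Validate them with if/raise.

3. Use assert for internal invariants: things that must be true if the code is correct, such as function preconditions, postconditions, loop invariants, and class invariants.

4. Never put side effects in an assert. The pop in assert items.pop() == expected removes an element from the list, but only when assertions are enabled.

5. Use assert liberally during development. Under -O the assertions are removed entirely and cost nothing at runtime, and while debugging they can save you a great deal of time.

6. Make assertion messages actionable. Include the actual value, the expected value, and enough context to understand the problem.

7. Test your assertions. Write tests that confirm your assertions actually catch the bugs they are meant to catch.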

import pytest
 
def test_transfer_rejects_negative_amount():
    with pytest.raises(AssertionError, match="positive"):
        transfer_money(account_a, account_b, amount=-100)

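Conclusion

Python's assert statement is a lightweight but powerful tool for defensive programming. It turns implicit assumptions into explicit, executable checks and catches bugs where they happen, instead of letting bad data spread through your code. Used well, assertions speed up debugging, make code easier to read, and document your invariants.

The core rules are simple:

  • Use assert for internal invariants and developer assumptions.
  • Use raise for input validation and expected error conditions.
  • Always include a descriptive message.
  • Never put side effects in an assert expression.

In tests, both pytest and unittest rely heavily on assertions to verify expected behavior.

For data science and analysis workflows, assertions pair naturally with tools like PyGWalker for validating DataFrames before visualizing them. They also work well with interactive environments like RunCell for iteratively building and testing assertion-guarded data pipelines in Jupyter notebooks.

Master these patterns and your debugging will be faster, your code more robust, and your tests more expressive.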
