
Python Generators: The Complete Guide to yield, Generator Expressions, and Lazy Evaluation


Processing a 10GB log file or streaming millions of database records can bring a Python application to its knees. The traditional approach loads all of the data into memory at once, which leads to performance bottlenecks, memory errors, and frustrated users. This is where Python generators become indispensable: by producing values on demand instead of storing everything up front, they let you work through huge datasets with a tiny memory footprint.


What Python Generators Are and Why They Matter

Generators are a special kind of function: they produce a series of values over time instead of computing and returning everything at once. Unlike a regular function, which uses return to hand back a single result, a generator uses the yield keyword to produce multiple values, pausing after each one until the next value is requested.

The fundamental advantage of generators is lazy evaluation: values are produced only when they are needed. This brings two key benefits:

  1. Memory efficiency: a generator never stores the whole sequence in memory. A generator that produces a billion numbers uses roughly the same amount of memory as one that produces ten.
  2. Performance: processing can start as soon as the first value is produced, instead of waiting for the entire dataset to be ready.

Here is a simple comparison that illustrates the difference:

# Traditional approach - loads entire list into memory
def get_squares_list(n):
    result = []
    for i in range(n):
        result.append(i * i)
    return result
 
# Generator approach - produces values one at a time
def get_squares_generator(n):
    for i in range(n):
        yield i * i
 
# Memory impact comparison
import sys
 
# List approach
squares_list = get_squares_list(1000000)
print(f"List memory: {sys.getsizeof(squares_list):,} bytes")  # ~8,000,000 bytes
 
# Generator approach
squares_gen = get_squares_generator(1000000)
print(f"Generator memory: {sys.getsizeof(squares_gen):,} bytes")  # ~112 bytes

The memory gap is striking: in this example the generator uses more than 99.99% less memory than the list, and the difference only grows as the dataset gets larger.
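
The second benefit, time to first result, is just as tangible. Here is a minimal sketch (with a hypothetical slow_items generator simulating slow per-item I/O, not part of the examples above): a consumer can start working on the first value almost immediately, whereas a list-building version would have to produce every item before returning anything.

import time

def slow_items(n):
    """Simulate a source where each item takes ~10 ms of I/O to produce."""
    for i in range(n):
        time.sleep(0.01)
        yield i

start = time.time()
first = next(slow_items(1_000_000))  # First value arrives after a single item's delay
print(f"First value after {time.time() - start:.3f}s: {first}")
# Building a full list first would take roughly 1_000_000 * 0.01s, i.e. hours.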

The yield Keyword: The Heart of Generator Functions

The yield keyword turns an ordinary function into a generator function. Because the function body contains yield, calling it returns a generator object instead of executing the body and returning a result immediately.

def countdown(n):
    print(f"Starting countdown from {n}")
    while n > 0:
        yield n
        n -= 1
    print("Countdown complete!")
 
# Creating the generator doesn't execute the function
gen = countdown(3)
print(type(gen))  # <class 'generator'>
 
# Values are produced on-demand
print(next(gen))  # Starting countdown from 3 -> 3
print(next(gen))  # 2
print(next(gen))  # 1
# next(gen)  # Countdown complete! -> Raises StopIteration

Key behaviors to understand:

  • Execution pauses at every yield and resumes from that point on the next request
  • Local variables keep their state across yields
  • When the generator function finishes (no more values), StopIteration is raised (see the short sketch below)
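
A small sketch of that last behavior, reusing the countdown generator defined above: next() raises StopIteration once the generator is exhausted, or returns a default value if you pass one.

gen = countdown(1)
print(next(gen))          # Starting countdown from 1 -> 1
print(next(gen, "done"))  # Countdown complete! -> "done" instead of StopIteration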

A single generator can contain multiple yield statements:

def data_pipeline():
    # Phase 1: Loading
    yield "Loading data..."
 
    # Phase 2: Processing
    yield "Processing records..."
 
    # Phase 3: Validation
    yield "Validating results..."
 
    # Phase 4: Complete
    yield "Pipeline complete!"
 
for status in data_pipeline():
    print(status)

The Generator Protocol: Understanding iter() and next()

Generators implement the iterator protocol through two special methods:

  • __iter__(): returns the iterator object itself (the generator)
  • __next__(): returns the next value from the generator

This is why generators fit naturally into for loops and other iteration contexts. Understanding the protocol clarifies how generators work under the hood:

def simple_gen():
    yield 1
    yield 2
    yield 3
 
gen = simple_gen()
 
# These are equivalent
print(gen.__next__())  # 1
print(next(gen))       # 2
 
# for loops call __next__() automatically until StopIteration
for value in simple_gen():
    print(value)  # 1, 2, 3

You can also implement the iterator protocol by hand to get generator-like behavior:

class CountDown:
    def __init__(self, start):
        self.current = start
 
    def __iter__(self):
        return self
 
    def __next__(self):
        if self.current <= 0:
            raise StopIteration
        self.current -= 1
        return self.current + 1
 
# Behaves like a generator
for num in CountDown(3):
    print(num)  # 3, 2, 1

That said, generator functions are more concise and readable than hand-written iterator classes.

Generator Expressions vs List Comprehensions

Generator expressions offer a concise syntax for creating generators. They look like list comprehensions but use parentheses instead of square brackets:

# List comprehension - creates entire list in memory
squares_list = [x * x for x in range(10)]
print(type(squares_list))  # <class 'list'>
print(squares_list)  # [0, 1, 4, 9, 16, 25, 36, 49, 64, 81]
 
# Generator expression - creates generator object
squares_gen = (x * x for x in range(10))
print(type(squares_gen))  # <class 'generator'>
print(squares_gen)  # <generator object at 0x...>
 
# Consume the generator
print(list(squares_gen))  # [0, 1, 4, 9, 16, 25, 36, 49, 64, 81]

Syntax comparison:

| Feature  | List Comprehension          | Generator Expression          |
|----------|-----------------------------|-------------------------------|
| Syntax   | [expr for item in iterable] | (expr for item in iterable)   |
| Returns  | List object                 | Generator object              |
| Memory   | Stores all values           | Produces values on demand     |
| Speed    | Faster for small data       | Faster for large data         |
| Reusable | Yes (multiple iterations)   | No (exhausted after one pass) |

A practical example showing the memory difference:

import sys
 
# List comprehension for 1 million numbers
list_comp = [x for x in range(1000000)]
print(f"List comprehension: {sys.getsizeof(list_comp):,} bytes")
 
# Generator expression for the same range
gen_exp = (x for x in range(1000000))
print(f"Generator expression: {sys.getsizeof(gen_exp):,} bytes")
 
# Output:
# List comprehension: 8,000,056 bytes
# Generator expression: 112 bytes

Generator expressions are ideal when you only need to iterate over the data once and want to keep memory usage to a minimum.
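
One-pass aggregation is where this shines. As a small additional illustration (not from the examples above), a generator expression passed as the sole argument to a function such as sum() or any() needs no extra parentheses and never builds an intermediate list:

total = sum(x * x for x in range(1_000_000))               # Aggregates lazily, no list in memory
has_large = any(x * x > 10**10 for x in range(1_000_000))  # Stops as soon as a match is found
print(total, has_large)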

yield from: Delegating to Subgenerators

yield from delegates iteration to a subgenerator or another iterable. Instead of writing a loop that yields each item by hand, it performs the delegation for you:

# Without yield from
def get_numbers_manual():
    for i in range(3):
        yield i
    for i in range(10, 13):
        yield i
 
# With yield from
def get_numbers_delegated():
    yield from range(3)
    yield from range(10, 13)
 
print(list(get_numbers_manual()))      # [0, 1, 2, 10, 11, 12]
print(list(get_numbers_delegated()))   # [0, 1, 2, 10, 11, 12]

It is especially useful when flattening nested structures:

def flatten(nested_list):
    for item in nested_list:
        if isinstance(item, list):
            yield from flatten(item)  # Recursive delegation
        else:
            yield item
 
nested = [1, [2, 3, [4, 5]], 6, [7, [8, 9]]]
print(list(flatten(nested)))  # [1, 2, 3, 4, 5, 6, 7, 8, 9]

yield from also handles the subgenerator's exceptions and return value correctly, which makes it essential in complex generator pipelines.
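
The return-value part is easy to miss, so here is a minimal sketch (a hypothetical example, not from the article's pipeline code): whatever the subgenerator returns becomes the value of the yield from expression in the delegating generator.

def read_chunk():
    total = 0
    for n in [1, 2, 3]:
        yield n
        total += n
    return total  # Becomes the value of `yield from` in the caller

def consume():
    subtotal = yield from read_chunk()  # Receives 6 once read_chunk is exhausted
    yield f"Subtotal: {subtotal}"

print(list(consume()))  # [1, 2, 3, 'Subtotal: 6']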

Advanced: The send() and throw() Methods

Generators can do more than produce values: with send() and throw() they can receive input from outside and have exceptions injected into them, enabling coroutine-style two-way communication.

Sending Values into a Generator with send()

def running_average():
    total = 0
    count = 0
    average = None
 
    while True:
        value = yield average  # Yield current average, receive new value
        total += value
        count += 1
        average = total / count
 
# Create generator
avg = running_average()
next(avg)  # Prime the generator (advance to first yield)
 
# Send values and receive running averages
print(avg.send(10))   # 10.0
print(avg.send(20))   # 15.0
print(avg.send(30))   # 20.0
print(avg.send(40))   # 25.0

send() does two things at once: it delivers a value into the generator (it becomes the result of the yield expression) and it resumes execution until the next yield.

Injecting Exceptions with throw()

def error_handling_gen():
    try:
        while True:
            try:
                value = yield
                print(f"Received: {value}")
            except ValueError as e:
                print(f"Caught ValueError: {e}")
                yield "Recovered from error"
    except GeneratorExit:
        print("Generator is closing")
 
gen = error_handling_gen()
next(gen)  # Prime the generator
 
gen.send(10)              # Received: 10
gen.send(20)              # Received: 20
result = gen.throw(ValueError, "Invalid value")  # Caught ValueError: Invalid value
print(result)             # Recovered from error
gen.close()               # Generator is closing

These advanced features are useful for implementing state machines, coroutines, and complex asynchronous patterns.
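
As a compact sketch of the state-machine idea (a hypothetical turnstile, not an example from the sections above), each send() delivers an event and the generator's local variable carries the state between calls:

def turnstile():
    state = "LOCKED"
    while True:
        event = yield state                       # Report the current state, wait for an event
        if state == "LOCKED" and event == "coin":
            state = "UNLOCKED"                    # Paying unlocks the gate
        elif state == "UNLOCKED" and event == "push":
            state = "LOCKED"                      # Walking through locks it again

gate = turnstile()
print(next(gate))         # LOCKED (prime the generator)
print(gate.send("push"))  # LOCKED - pushing a locked gate does nothing
print(gate.send("coin"))  # UNLOCKED
print(gate.send("push"))  # LOCKED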

Infinite Generators: Unbounded Sequences

Generators excel at producing infinite sequences because they never need to materialize the whole sequence in memory:

# Infinite counter
def count_from(start=0, step=1):
    current = start
    while True:
        yield current
        current += step
 
# Fibonacci sequence
def fibonacci():
    a, b = 0, 1
    while True:
        yield a
        a, b = b, a + b
 
# Cycling through a sequence
def cycle(iterable):
    saved = []
    for item in iterable:
        yield item
        saved.append(item)
    while saved:
        for item in saved:
            yield item
 
# Usage examples
counter = count_from(10, 2)
for _ in range(5):
    print(next(counter))  # 10, 12, 14, 16, 18
 
fib = fibonacci()
print([next(fib) for _ in range(10)])  # [0, 1, 1, 2, 3, 5, 8, 13, 21, 34]
 
colors = cycle(['red', 'green', 'blue'])
print([next(colors) for _ in range(8)])  # ['red', 'green', 'blue', 'red', 'green', 'blue', 'red', 'green']

Infinite generators are particularly useful for event streams, continuous monitoring, and iteration patterns that carry state.

Chaining Generators: Building Data Processing Pipelines

One of the most powerful generator patterns is chaining them into efficient data processing pipelines. Each stage processes data lazily and hands its results to the next stage, with no intermediate collections stored:

# Stage 1: Read lines from a file (generator)
def read_log_file(filename):
    with open(filename, 'r') as f:
        for line in f:
            yield line.strip()
 
# Stage 2: Filter lines containing 'ERROR'
def filter_errors(lines):
    for line in lines:
        if 'ERROR' in line:
            yield line
 
# Stage 3: Extract timestamp and message
def parse_error_lines(lines):
    for line in lines:
        parts = line.split(' - ')
        if len(parts) >= 2:
            yield {'timestamp': parts[0], 'message': parts[1]}
 
# Stage 4: Count errors by hour
def group_by_hour(errors):
    from collections import defaultdict
    hourly_counts = defaultdict(int)
 
    for error in errors:
        hour = error['timestamp'][:13]  # Extract hour portion
        hourly_counts[hour] += 1
 
    return hourly_counts
 
# Build pipeline
log_lines = read_log_file('app.log')
error_lines = filter_errors(log_lines)
parsed_errors = parse_error_lines(error_lines)
results = group_by_hour(parsed_errors)
 
print(results)

This pipeline can process an arbitrarily large log file with very little memory: until the final aggregation stage, only one line is held in memory at any given moment.

Another data transformation example:

# Pipeline: numbers -> square -> filter evens -> sum
def square_numbers(numbers):
    for n in numbers:
        yield n * n
 
def filter_even(numbers):
    for n in numbers:
        if n % 2 == 0:
            yield n
 
# Chain the pipeline
numbers = range(1, 11)  # 1-10
squared = square_numbers(numbers)
evens = filter_even(squared)
result = sum(evens)  # Only even squares
 
print(result)  # 220 (4 + 16 + 36 + 64 + 100)

Memory Comparison: Generator vs List Benchmark

Here is a realistic memory and performance benchmark that quantifies the benefits of generators:

import sys
import time
import tracemalloc
 
def process_with_list(n):
    """Traditional approach using lists"""
    tracemalloc.start()
    start_time = time.time()
 
    # Create list of squares
    squares = [x * x for x in range(n)]
 
    # Filter even squares
    even_squares = [x for x in squares if x % 2 == 0]
 
    # Sum results
    result = sum(even_squares)
 
    current, peak = tracemalloc.get_traced_memory()
    tracemalloc.stop()
    elapsed = time.time() - start_time
 
    return result, peak / 1024 / 1024, elapsed  # Convert to MB
 
def process_with_generator(n):
    """Generator approach"""
    tracemalloc.start()
    start_time = time.time()
 
    # Generator pipeline
    squares = (x * x for x in range(n))
    even_squares = (x for x in squares if x % 2 == 0)
    result = sum(even_squares)
 
    current, peak = tracemalloc.get_traced_memory()
    tracemalloc.stop()
    elapsed = time.time() - start_time
 
    return result, peak / 1024 / 1024, elapsed
 
# Benchmark with 1 million numbers
n = 1000000
 
list_result, list_memory, list_time = process_with_list(n)
gen_result, gen_memory, gen_time = process_with_generator(n)
 
print(f"Results match: {list_result == gen_result}")
print(f"\nList approach:")
print(f"  Memory: {list_memory:.2f} MB")
print(f"  Time: {list_time:.4f} seconds")
print(f"\nGenerator approach:")
print(f"  Memory: {gen_memory:.2f} MB")
print(f"  Time: {gen_time:.4f} seconds")
print(f"\nMemory savings: {((list_memory - gen_memory) / list_memory * 100):.1f}%")

Typical output:

Results match: True

List approach:
  Memory: 36.21 MB
  Time: 0.0892 seconds

Generator approach:
  Memory: 0.12 MB
  Time: 0.0624 seconds

Memory savings: 99.7%

The generator approach uses 99.7% less memory and runs roughly 30% faster, and the gap only widens as the data grows.

The itertools Module: A Generator Toolkit

Python's itertools module provides a collection of powerful, generator-based tools for efficient iteration. They are implemented in C and heavily optimized:

Essential itertools Functions

import itertools
 
# chain - concatenate multiple iterables
combined = itertools.chain([1, 2], [3, 4], [5, 6])
print(list(combined))  # [1, 2, 3, 4, 5, 6]
 
# islice - slice an iterable (like list slicing but for generators)
numbers = itertools.count()  # Infinite counter: 0, 1, 2, 3...
first_ten = itertools.islice(numbers, 10)
print(list(first_ten))  # [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
 
# count - infinite counter with start and step
counter = itertools.count(start=10, step=2)
print([next(counter) for _ in range(5)])  # [10, 12, 14, 16, 18]
 
# cycle - infinite repetition of an iterable
colors = itertools.cycle(['red', 'green', 'blue'])
print([next(colors) for _ in range(7)])  # ['red', 'green', 'blue', 'red', 'green', 'blue', 'red']
 
# accumulate - cumulative sums or other operations
numbers = [1, 2, 3, 4, 5]
cumulative = itertools.accumulate(numbers)
print(list(cumulative))  # [1, 3, 6, 10, 15]
 
# accumulate with custom function
import operator
products = itertools.accumulate(numbers, operator.mul)
print(list(products))  # [1, 2, 6, 24, 120]
 
# groupby - group consecutive elements by key
data = [('A', 1), ('A', 2), ('B', 3), ('B', 4), ('C', 5)]
for key, group in itertools.groupby(data, key=lambda x: x[0]):
    print(f"{key}: {list(group)}")
# A: [('A', 1), ('A', 2)]
# B: [('B', 3), ('B', 4)]
# C: [('C', 5)]

Practical itertools Combinations

# Paginating results with islice
def paginate(iterable, page_size):
    iterator = iter(iterable)
    while True:
        page = list(itertools.islice(iterator, page_size))
        if not page:
            break
        yield page
 
# Usage
data = range(25)
for page_num, page in enumerate(paginate(data, 10), 1):
    print(f"Page {page_num}: {page}")
# Page 1: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
# Page 2: [10, 11, 12, 13, 14, 15, 16, 17, 18, 19]
# Page 3: [20, 21, 22, 23, 24]
 
# Windowed iteration (sliding window)
def window(iterable, size):
    it = iter(iterable)
    win = list(itertools.islice(it, size))
    if len(win) == size:
        yield tuple(win)
    for item in it:
        win = win[1:] + [item]
        yield tuple(win)
 
print(list(window([1, 2, 3, 4, 5], 3)))
# [(1, 2, 3), (2, 3, 4), (3, 4, 5)]

Real-World Use Cases

Reading Large Files Line by Line

def process_large_csv(filename):
    """Process a multi-GB CSV file efficiently"""
    with open(filename, 'r') as f:
        # Skip header
        next(f)
 
        for line in f:
            # Parse and yield record
            fields = line.strip().split(',')
            yield {
                'user_id': fields[0],
                'action': fields[1],
                'timestamp': fields[2]
            }
 
# Process millions of records with minimal memory
for record in process_large_csv('user_events.csv'):
    # Process one record at a time
    if record['action'] == 'purchase':
        print(f"Purchase by user {record['user_id']}")

Streaming Data Processing

import requests

def stream_api_data(url, batch_size=100):
    """Stream paginated API data without loading all results"""
    offset = 0
 
    while True:
        response = requests.get(url, params={'offset': offset, 'limit': batch_size})
        data = response.json()
 
        if not data:
            break
 
        for item in data:
            yield item
 
        offset += batch_size
 
# Process unlimited API results
for item in stream_api_data('https://api.example.com/records'):
    process_item(item)

Iterating over Database Query Results

def fetch_users_batch(cursor, batch_size=1000):
    """Fetch database records in batches without loading all into memory"""
    while True:
        results = cursor.fetchmany(batch_size)
        if not results:
            break
        for row in results:
            yield row
 
# Database query
cursor.execute("SELECT * FROM users WHERE active = 1")
 
# Process millions of users efficiently
for user in fetch_users_batch(cursor):
    send_email(user['email'], generate_report(user))

An ETL Pipeline Example

# Extract: Read from source
def extract_from_csv(filename):
    with open(filename, 'r') as f:
        for line in f:
            yield line.strip().split(',')
 
# Transform: Clean and convert data
def transform_records(records):
    for record in records:
        yield {
            'id': int(record[0]),
            'name': record[1].title(),
            'email': record[2].lower(),
            'age': int(record[3]) if record[3] else None
        }
 
# Load: Write to database
def load_to_database(records, db_connection):
    for record in records:
        db_connection.execute(
            "INSERT INTO users VALUES (?, ?, ?, ?)",
            (record['id'], record['name'], record['email'], record['age'])
        )
        yield record  # Pass through for logging
 
# Build ETL pipeline
raw_data = extract_from_csv('users.csv')
transformed = transform_records(raw_data)
loaded = load_to_database(transformed, db_conn)
 
# Execute pipeline and count processed records
processed_count = sum(1 for _ in loaded)
print(f"Processed {processed_count} records")

Generator Best Practices and Common Pitfalls

Best Practices

  1. Use generator expressions for simple cases

    # Simple transformation - use generator expression
    squares = (x * x for x in range(1000))
     
    # Complex logic - use generator function
    def complex_processing(data):
        for item in data:
            # Multi-step processing
            result = step1(item)
            result = step2(result)
            if validate(result):
                yield result
  2. Chain generators to build data pipelines

    # Each stage processes lazily
    data = read_source()
    filtered = filter_stage(data)
    transformed = transform_stage(filtered)
    results = aggregate_stage(transformed)
  3. Use yield from for delegation

    def process_all_files(directory):
        for filename in os.listdir(directory):
            yield from process_file(filename)

Common Pitfalls

  1. Generators are exhausted after a single iteration

    gen = (x for x in range(3))
    print(list(gen))  # [0, 1, 2]
    print(list(gen))  # [] - exhausted!
     
    # Solution: Convert to list or recreate generator
    data = list(gen)  # If data fits in memory
    # OR
    gen = (x for x in range(3))  # Recreate
  2. Generators do not support len() or indexing

    gen = (x for x in range(10))
    # len(gen)  # TypeError
    # gen[5]    # TypeError
     
    # Solution: Convert to list if you need these operations
    items = list(gen)
    print(len(items))
    print(items[5])
  3. Watch out for scope and closure capture

    # Wrong - all generators will use final value of i
    generators = [lambda: i for i in range(3)]
    print([g() for g in generators])  # [2, 2, 2]
     
    # Correct - capture i in default argument
    generators = [lambda i=i: i for i in range(3)]
    print([g() for g in generators])  # [0, 1, 2]
  4. Exception handling in generator chains

    def stage1():
        for i in range(5):
            if i == 3:
                raise ValueError("Error in stage1")
            yield i
     
    def stage2(data):
        try:
            for item in data:
                yield item * 2
        except ValueError as e:
            print(f"Caught: {e}")
            yield -1  # Error marker
     
    # Exception is caught in stage2
    for result in stage2(stage1()):
        print(result)

Comparison: Generators vs Lists vs Iterators vs map/filter

| Feature            | Generators                | Lists                         | Iterators               | map/filter                 |
|--------------------|---------------------------|-------------------------------|-------------------------|----------------------------|
| Memory usage       | Minimal (lazy)            | Full dataset                  | Minimal (lazy)          | Minimal (lazy)             |
| Creation speed     | Instant                   | Depends on size               | Instant                 | Instant                    |
| Reusable           | No                        | Yes                           | No                      | No                         |
| Indexable          | No                        | Yes                           | No                      | No                         |
| len() support      | No                        | Yes                           | No                      | No                         |
| Modification       | Read-only                 | Mutable                       | Read-only               | Read-only                  |
| Infinite sequences | Yes                       | No                            | Yes                     | Yes                        |
| Syntax             | yield or ()               | []                            | iter()                  | map(), filter()            |
| Best for           | Large datasets, pipelines | Small datasets, random access | Protocol implementation | Functional transformations |

A side-by-side example:

# All produce same results but with different characteristics
data = range(1000000)
 
# Generator - memory efficient, not reusable
gen = (x * 2 for x in data)
 
# List - memory intensive, reusable, indexable
lst = [x * 2 for x in data]
 
# map - memory efficient, functional style
mapped = map(lambda x: x * 2, data)
 
# Iterator - explicit protocol implementation
class Doubler:
    def __init__(self, data):
        self.data = iter(data)
 
    def __iter__(self):
        return self
 
    def __next__(self):
        return next(self.data) * 2
 
iterator = Doubler(data)

Experimenting with Generators in Jupyter

When exploring generator patterns and their performance characteristics, an interactive notebook environment speeds things up considerably. RunCell brings AI assistance directly into Jupyter notebooks, making it a good fit for data scientists experimenting with generator-based data processing pipelines.

With RunCell you can:

  • Prototype generator functions quickly and test their memory characteristics
  • Benchmark generators against lists on real datasets
  • Build and debug complex generator pipelines interactively
  • Get AI suggestions for optimizing generator-based ETL workflows

Here is what exploring generators in a notebook might look like:

# Cell 1: Define generator pipeline
def read_data():
    for i in range(1000000):
        yield {'id': i, 'value': i * 2}
 
def filter_large(records):
    for record in records:
        if record['value'] > 1000:
            yield record
 
def transform(records):
    for record in records:
        record['squared'] = record['value'] ** 2
        yield record
 
# Cell 2: Execute pipeline and measure
import itertools
import time
start = time.time()
 
pipeline = transform(filter_large(read_data()))
results = list(itertools.islice(pipeline, 100))  # Take first 100
 
print(f"Time: {time.time() - start:.4f}s")
print(f"Results: {len(results)}")
 
# Cell 3: Visualize with PyGWalker (walk expects a DataFrame-like dataset)
import pandas as pd
import pygwalker as pyg

pyg.walk(pd.DataFrame(results))


Conclusion

Python generators represent a fundamental shift from eager to lazy evaluation, letting you handle datasets from thousands to billions of records with far better memory efficiency. By understanding yield, generator expressions, the iterator protocol, and advanced features such as send() and yield from, you can build sophisticated data processing pipelines that scale gracefully.

Key takeaways:

  • Generators use lazy evaluation to minimize memory footprint, often saving 99%+ memory compared with lists
  • Use generator expressions for simple transformations and generator functions for complex logic
  • Build memory-efficient data processing pipelines by chaining generators
  • Reach for itertools when you need powerful, generator-based iteration tools
  • Choose generators for large data and single passes; choose lists for small data that needs random access

Whether you are processing massive log files, streaming API data, or building ETL pipelines, generators deliver the performance and memory efficiency that production data processing demands. Master these patterns and you will write elegant, efficient Python that handles datasets of any size.
