Python defaultdict:用默认值简化字典操作
Updated on
每个Python开发者都遇到过这个问题:你编写了一个干净的循环来使用字典对项目进行分组或计数,运行代码后,由于某个键尚不存在,KeyError导致整个脚本崩溃。标准的解决方法是到处添加if key in dict检查或try/except KeyError块。你用于分组十行数据的逻辑突然膨胀到二十行防御性的样板代码。
在大规模场景下情况更糟。当你构建图的邻接列表、聚合日志数据或在数百万条记录中统计词频时,这些保护子句会不断累积。它们拖慢你的开发速度,使代码更难审查,并且当你在某个分支中忘记检查时会引入微妙的bug。
Python的collections.defaultdict消除了这整类问题。它是一个字典子类,通过调用工厂函数自动提供缺失的值。不再有KeyError,不再有保护子句,不再有样板代码。
什么是defaultdict?
defaultdict是Python内置dict的子类。关键区别:当你访问一个不存在的键时,defaultdict会自动用默认值创建它,而不是抛出KeyError。
from collections import defaultdict
# 普通dict会抛出KeyError
regular = {}
# regular['missing'] # KeyError: 'missing'
# defaultdict自动创建值
dd = defaultdict(int)
dd['missing'] # 返回0,现在'missing'是一个键
print(dd) # defaultdict(<class 'int'>, {'missing': 0})构造函数接受一个工厂函数作为第一个参数。常见的工厂函数:
int-- 返回0list-- 返回[]set-- 返回set()str-- 返回""lambda: value-- 返回任何自定义默认值
defaultdict(int) -- 计数模式
最常见的用法。每个新键从0开始,因此可以直接递增。
from collections import defaultdict
words = ['apple', 'banana', 'apple', 'cherry', 'banana', 'apple']
# 不使用defaultdict
counts_regular = {}
for word in words:
if word in counts_regular:
counts_regular[word] += 1
else:
counts_regular[word] = 1
# 使用defaultdict(int) -- 简洁直接
counts = defaultdict(int)
for word in words:
counts[word] += 1
print(dict(counts))
# {'apple': 3, 'banana': 2, 'cherry': 1}defaultdict(list) -- 分组模式
将相关项目归为一组。每个新键以空列表开始。
from collections import defaultdict
students = [
('Math', 'Alice'),
('Science', 'Bob'),
('Math', 'Charlie'),
('Science', 'Diana'),
('Math', 'Eve'),
('History', 'Frank'),
]
groups = defaultdict(list)
for subject, student in students:
groups[subject].append(student)
for subject, names in groups.items():
print(f"{subject}: {', '.join(names)}")
# Math: Alice, Charlie, Eve
# Science: Bob, Diana
# History: Frank按多个字段分组记录
from collections import defaultdict
sales = [
{'region': 'East', 'product': 'Widget', 'amount': 100},
{'region': 'West', 'product': 'Gadget', 'amount': 200},
{'region': 'East', 'product': 'Widget', 'amount': 150},
{'region': 'West', 'product': 'Widget', 'amount': 300},
]
by_region_product = defaultdict(list)
for sale in sales:
key = (sale['region'], sale['product'])
by_region_product[key].append(sale['amount'])
for (region, product), amounts in by_region_product.items():
total = sum(amounts)
print(f"{region} - {product}: {amounts} (total: {total})")defaultdict(set) -- 唯一分组
自动按键收集唯一值。
from collections import defaultdict
edges = [
('Alice', 'Bob'), ('Alice', 'Charlie'),
('Bob', 'Alice'), ('Bob', 'Diana'),
('Alice', 'Bob'), # duplicate
]
connections = defaultdict(set)
for person, friend in edges:
connections[person].add(friend)
for person, friends in connections.items():
print(f"{person} is connected to: {friends}")
# Alice is connected to: {'Bob', 'Charlie'}
# Bob is connected to: {'Alice', 'Diana'}defaultdict(lambda: value) -- 自定义默认值
当内置类型不适用时,使用lambda返回任何默认值。
from collections import defaultdict
# 缺失条目的默认值为'N/A'
status = defaultdict(lambda: 'N/A')
status['server1'] = 'running'
status['server2'] = 'stopped'
print(status['server3']) # N/A
# 默认起始余额
accounts = defaultdict(lambda: 100.0)
accounts['alice'] += 50
accounts['bob'] -= 30
print(dict(accounts)) # {'alice': 150.0, 'bob': 70.0}带结构化值的默认字典
from collections import defaultdict
def default_profile():
return {'score': 0, 'level': 1, 'items': []}
profiles = defaultdict(default_profile)
profiles['player1']['score'] += 100
profiles['player1']['items'].append('sword')
profiles['player2']['level'] = 5
print(profiles['player1'])
# {'score': 100, 'level': 1, 'items': ['sword']}
print(profiles['player3'])
# {'score': 0, 'level': 1, 'items': []}嵌套defaultdict -- 树形结构
最强大的模式之一是递归使用defaultdict来创建自动生成的字典。
from collections import defaultdict
def tree():
return defaultdict(tree)
taxonomy = tree()
taxonomy['Animal']['Mammal']['Dog'] = 'Canis lupus familiaris'
taxonomy['Animal']['Mammal']['Cat'] = 'Felis catus'
taxonomy['Animal']['Bird']['Eagle'] = 'Aquila chrysaetos'
taxonomy['Plant']['Tree']['Oak'] = 'Quercus'
print(taxonomy['Animal']['Mammal']['Dog']) # Canis lupus familiaris多级聚合
from collections import defaultdict
sales_data = [
(2025, 'Q1', 'Widget', 500),
(2025, 'Q1', 'Gadget', 300),
(2025, 'Q2', 'Widget', 700),
(2026, 'Q1', 'Widget', 600),
]
report = defaultdict(lambda: defaultdict(lambda: defaultdict(int)))
for year, quarter, product, amount in sales_data:
report[year][quarter][product] += amount
print(report[2025]['Q1']['Widget']) # 500
print(report[2026]['Q1']['Widget']) # 600defaultdict vs dict.setdefault() vs get() -- 比较
| 特性 | defaultdict | dict.setdefault() | dict.get() |
|---|---|---|---|
| 需要导入 | 是(collections) | 否 | 否 |
| 自动创建键 | 是 | 是 | 否 |
| 访问时修改dict | 是 | 是 | 否 |
| 每次调用自定义默认值 | 否(全局工厂) | 是 | 是 |
| 性能(重复操作) | 最快 | 较慢(方法调用开销) | 最快(无变更) |
| 最适合 | 重复累积 | 一次性默认值 | 只读回退 |
何时使用:
defaultdict:在多次迭代中构建值(计数、分组)dict.setdefault():偶尔需要为特定键设置默认值dict.get():在不修改字典的情况下读取带回退的值
将defaultdict转换回普通dict
from collections import defaultdict
import json
def defaultdict_to_dict(d):
"""Recursively convert defaultdict to regular dict."""
if isinstance(d, defaultdict):
d = {k: defaultdict_to_dict(v) for k, v in d.items()}
return d
nested = defaultdict(lambda: defaultdict(int))
nested['x']['y'] = 10
nested['a']['b'] = 20
regular = defaultdict_to_dict(nested)
print(json.dumps(regular)) # {"x": {"y": 10}, "a": {"b": 20}}你也可以通过将默认工厂设置为None来禁用它:
dd = defaultdict(int)
dd['a'] += 1
dd.default_factory = None
# dd['missing'] # 现在会抛出KeyError实用示例
图的邻接列表
from collections import defaultdict, deque
edges = [('A', 'B'), ('A', 'C'), ('B', 'D'), ('C', 'D'), ('D', 'E')]
graph = defaultdict(list)
for src, dst in edges:
graph[src].append(dst)
graph[dst].append(src) # undirected graph
def bfs(graph, start):
visited = set()
queue = deque([start])
order = []
while queue:
node = queue.popleft()
if node not in visited:
visited.add(node)
order.append(node)
queue.extend(graph[node])
return order
print(bfs(graph, 'A')) # ['A', 'B', 'C', 'D', 'E']文本搜索的倒排索引
from collections import defaultdict
documents = {
'doc1': 'python is a great programming language',
'doc2': 'data science uses python extensively',
'doc3': 'machine learning with python and data',
}
index = defaultdict(set)
for doc_id, text in documents.items():
for word in text.split():
index[word.lower()].add(doc_id)
def search(query):
return index.get(query.lower(), set())
print(search('python')) # {'doc1', 'doc2', 'doc3'}
print(search('data')) # {'doc2', 'doc3'}使用PyGWalker可视化分组数据
在使用defaultdict分组和聚合数据后,你通常想要可视化结果。PyGWalker (opens in a new tab)可以将你的pandas DataFrame直接在Jupyter中转换为交互式可视化界面:
from collections import defaultdict
import pandas as pd
import pygwalker as pyg
sales = [
('Electronics', 'Laptop', 1200),
('Electronics', 'Phone', 800),
('Clothing', 'Shirt', 45),
('Clothing', 'Jacket', 120),
]
totals = defaultdict(lambda: defaultdict(int))
for category, product, amount in sales:
totals[category][product] += amount
rows = []
for category, products in totals.items():
for product, total in products.items():
rows.append({'category': category, 'product': product, 'total': total})
df = pd.DataFrame(rows)
walker = pyg.walk(df)FAQ
Python中的defaultdict是什么?
defaultdict是collections中的字典子类,为缺失的键提供默认值。它不会抛出KeyError,而是调用工厂函数(如int、list或set)自动创建并存储默认值。
dict和defaultdict有什么区别?
唯一的功能区别是它们处理缺失键的方式。普通的dict会抛出KeyError。defaultdict会调用其default_factory函数来创建默认值。在其他所有方面,它们的行为完全相同。
什么时候应该使用defaultdict(list)和defaultdict(set)?
当你想要分组项目并保留重复项和插入顺序时,使用defaultdict(list)。当你想要只收集每个键的唯一项目时,使用defaultdict(set)。
可以将defaultdict序列化为JSON吗?
可以,但对于嵌套的defaultdict对象,需要先使用递归转换函数将其转换为普通的dict。你也可以设置default_factory = None来防止序列化前意外创建键。
如何创建嵌套的defaultdict?
定义一个递归工厂函数:def tree(): return defaultdict(tree)。对于更简单的两级嵌套,使用defaultdict(lambda: defaultdict(int))。
总结
Python的collections.defaultdict是标准库中最实用的工具之一。它将冗长且容易出错的字典累积模式转换为简洁的单行代码。使用defaultdict(int)进行计数,defaultdict(list)进行分组,defaultdict(set)进行唯一收集,嵌套的defaultdict处理层级数据。
关键要点:如果你发现自己在每次字典操作前都在写if key not in dict,就用defaultdict替换那个字典。你的代码会更短、更快、更容易维护。