Skip to content
话题
Seaborn
Seaborn 热图:Python 中创建热图的完整指南

Seaborn 热图:Python 中创建热图的完整指南

Updated on

你有一个包含数十个变量的数据集。你需要了解哪些特征相关联,模式隐藏在哪里,或者为什么你的机器学习模型一直表现异常。盯着一行行一列列的数字几乎告诉你不了什么。这正是 seaborn 热图 所解决的问题——它将密集的数值矩阵转换为你的大脑能在几秒钟内解析的彩色编码网格。

热图是数据科学中使用最广泛的可视化类型之一,Python 的 seaborn 库使创建热图变得非常简单。无论你是构建相关矩阵、分析混淆矩阵,还是可视化时间序列模式,sns.heatmap() 只需几行代码就能为你提供可发表的图表。

本指南将带你了解所有内容:基本语法、自定义选项、聚类热图等高级技术,以及完整的参数参考表。每个代码示例都可以直接复制粘贴使用。

📚

基本 Seaborn 热图语法

核心函数是 sns.heatmap()。它接受一个二维数据集——通常是 pandas DataFrame 或 NumPy 数组——并将其渲染为彩色网格。

import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np
 
# Create sample data
data = np.random.rand(5, 7)
ax = sns.heatmap(data)
plt.title("Basic Seaborn Heatmap")
plt.show()

这是最简单的热图。每个单元格的颜色代表其数值,seaborn 会自动在右侧添加一个颜色条。但实际应用几乎总是涉及更多配置,我们接下来将介绍这些内容。

创建相关矩阵热图

seaborn 热图最常见的用例是可视化相关矩阵。这告诉你数据集中每对变量之间的关联强度。

import seaborn as sns
import matplotlib.pyplot as plt
import pandas as pd
 
# Load a built-in dataset
df = sns.load_dataset("mpg").select_dtypes(include="number")
 
# Compute the correlation matrix
corr = df.corr()
 
# Plot the heatmap
plt.figure(figsize=(10, 8))
sns.heatmap(
    corr,
    annot=True,
    fmt=".2f",
    cmap="coolwarm",
    center=0,
    square=True,
    linewidths=0.5
)
plt.title("Correlation Matrix - MPG Dataset")
plt.tight_layout()
plt.show()

这里发生的关键事项:

  • annot=True 在每个单元格内打印相关系数。
  • fmt=".2f" 将这些数字格式化为两位小数。
  • cmap="coolwarm" 使用发散型调色板,其中负相关为蓝色,正相关为红色。
  • center=0 确保零相关映射到中性中点颜色。
  • square=True 强制每个单元格为完美的正方形,使视觉效果更清晰。

自定义选项

调色板(cmap 参数)

cmap 参数控制配色方案。选择正确的调色板取决于你的数据类型。

调色板类型示例名称最适合
顺序型"YlOrRd""Blues""viridis"从低到高的数据(计数、幅度)
发散型"coolwarm""RdBu_r""seismic"有意义中心点的数据(相关性、残差)
分类型"Set2""Paired"分类数据(热图不太常见)
感知均匀型"viridis""magma""inferno"确保可访问性和准确感知
import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np
 
data = np.random.rand(6, 6)
fig, axes = plt.subplots(1, 3, figsize=(18, 5))
 
cmaps = ["viridis", "coolwarm", "YlOrRd"]
for ax, cmap in zip(axes, cmaps):
    sns.heatmap(data, cmap=cmap, ax=ax, annot=True, fmt=".2f")
    ax.set_title(f'cmap="{cmap}"')
 
plt.tight_layout()
plt.show()

注释(annot 和 fmt 参数)

注释在每个单元格内显示数值。你可以控制它们的格式:

import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np
 
data = np.random.randint(0, 1000, size=(4, 5))
 
plt.figure(figsize=(8, 5))
sns.heatmap(
    data,
    annot=True,
    fmt="d",            # integer format
    cmap="Blues",
    annot_kws={"size": 14, "weight": "bold"}  # customize font
)
plt.title("Heatmap with Integer Annotations")
plt.show()

常见的 fmt 值:".2f" 表示两位小数,"d" 表示整数,".1%" 表示百分比,".1e" 表示科学计数法。

图形大小和纵横比

Seaborn 热图从 matplotlib 图形继承其大小。在调用 sns.heatmap() 之前设置它:

plt.figure(figsize=(12, 8))  # width=12, height=8 inches
sns.heatmap(data, cmap="viridis")
plt.show()

对于正方形单元格,将 square=True 传递给 sns.heatmap()。这会覆盖图形的纵横比,使每个单元格大小相等。

遮蔽上三角或下三角

相关矩阵是对称的。显示两个半部分是冗余的。使用 NumPy 的 triutril 来遮蔽一半:

import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np
 
df = sns.load_dataset("mpg").select_dtypes(include="number")
corr = df.corr()
 
# Create a mask for the upper triangle
mask = np.triu(np.ones_like(corr, dtype=bool))
 
plt.figure(figsize=(10, 8))
sns.heatmap(
    corr,
    mask=mask,
    annot=True,
    fmt=".2f",
    cmap="coolwarm",
    center=0,
    square=True,
    linewidths=0.5
)
plt.title("Lower Triangle Correlation Heatmap")
plt.tight_layout()
plt.show()

mask 参数接受与数据形状相同的布尔数组。mask=True 的单元格将被隐藏。

热图参数参考表

参数描述默认值
data二维数据集(DataFrame、ndarray)必需
vmin / vmax颜色映射缩放的最小值/最大值从数据自动获取
cmap颜色映射名称或对象None(seaborn 默认值)
center颜色映射的中心值None
annot在单元格中显示数值False
fmt注释的格式字符串".2g"
annot_kws注释文本的关键字参数字典{}
linewidths分隔单元格的线条宽度0
linecolor单元格边框线的颜色"white"
cbar显示颜色条True
cbar_kws颜色条的关键字参数字典{}
square强制单元格为正方形False
mask布尔数组;True 单元格不显示None
xticklabelsx 轴刻度标签自动
yticklabelsy 轴刻度标签自动
ax要绘制的 Matplotlib Axes 对象当前 Axes

高级示例

使用 sns.clustermap 的聚类热图

当你想将相似的行和列分组在一起时,sns.clustermap() 应用层次聚类并自动重新排序轴:

import seaborn as sns
import matplotlib.pyplot as plt
 
df = sns.load_dataset("mpg").select_dtypes(include="number").dropna()
corr = df.corr()
 
g = sns.clustermap(
    corr,
    annot=True,
    fmt=".2f",
    cmap="vlag",
    center=0,
    linewidths=0.5,
    figsize=(8, 8),
    dendrogram_ratio=0.15
)
g.ax_heatmap.set_title("Clustered Correlation Heatmap", pad=60)
plt.show()

左侧和顶部的树状图显示聚类层次结构。最密切相关的变量被放置在一起,使模式更容易发现。

自定义颜色范围(vmin、vmax)

默认情况下,seaborn 将颜色缩放到数据的最小值和最大值。你可以覆盖此设置,以在同一比例上比较多个热图或突出显示特定范围:

import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np
 
np.random.seed(42)
data = np.random.uniform(-1, 1, size=(8, 8))
 
plt.figure(figsize=(8, 6))
sns.heatmap(
    data,
    vmin=-1,
    vmax=1,
    center=0,
    cmap="RdBu_r",
    annot=True,
    fmt=".2f"
)
plt.title("Heatmap with Fixed Color Range (-1 to 1)")
plt.show()

当绘制相关矩阵或归一化数据(理论范围已知)时,设置 vmin=-1vmax=1 特别有用。

混淆矩阵热图

另一个实际应用是可视化分类模型的混淆矩阵:

import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np
from sklearn.metrics import confusion_matrix
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier
 
# Train a quick model
iris = load_iris()
X_train, X_test, y_train, y_test = train_test_split(
    iris.data, iris.target, test_size=0.3, random_state=42
)
clf = RandomForestClassifier(random_state=42)
clf.fit(X_train, y_train)
y_pred = clf.predict(X_test)
 
# Build and plot the confusion matrix
cm = confusion_matrix(y_test, y_pred)
plt.figure(figsize=(7, 5))
sns.heatmap(
    cm,
    annot=True,
    fmt="d",
    cmap="Blues",
    xticklabels=iris.target_names,
    yticklabels=iris.target_names
)
plt.xlabel("Predicted")
plt.ylabel("Actual")
plt.title("Confusion Matrix - Iris Classification")
plt.tight_layout()
plt.show()

对角线显示正确的预测。非对角线单元格显示模型将一个类别混淆为另一个类别的位置。

时间序列热图

热图也非常适合发现跨时间维度的模式。这是一个按星期几和小时显示活动的示例:

import seaborn as sns
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
 
np.random.seed(0)
hours = list(range(24))
days = ["Mon", "Tue", "Wed", "Thu", "Fri", "Sat", "Sun"]
data = pd.DataFrame(
    np.random.poisson(lam=20, size=(7, 24)),
    index=days,
    columns=hours
)
 
plt.figure(figsize=(14, 5))
sns.heatmap(data, cmap="YlOrRd", linewidths=0.3, annot=False)
plt.xlabel("Hour of Day")
plt.ylabel("Day of Week")
plt.title("Activity Heatmap by Day and Hour")
plt.tight_layout()
plt.show()

Seaborn 热图 vs Matplotlib imshow

你也可以使用 matplotlib.pyplot.imshow() 创建类似热图的可视化。以下是两者的比较:

特性sns.heatmap()plt.imshow()
内置颜色条是,自动手动(plt.colorbar()
单元格注释annot=True手动文本放置
接受 DataFrame是,自动标签否,需要数组
刻度标签处理从 DataFrame 索引/列自动手动设置
遮蔽支持内置 mask 参数使用 np.ma 手动
聚类通过 sns.clustermap()不内置
单元格间距linewidths 参数不直接支持
学习曲线常见用例较低更低级,更手动
自定义上限高(继承 matplotlib)非常高(完全控制)

**底线:**当你想要用最少的代码获得干净、标签良好的热图时,使用 sns.heatmap()。当你需要像素级控制或处理图像数据而不是表格数据时,回退到 imshow()

交互式替代方案:PyGWalker

静态热图对报告和论文很有用,但在探索性数据分析期间,你通常希望与数据进行交互——过滤、透视、向下钻取,并在图表类型之间切换,而无需重写代码。

PyGWalker (opens in a new tab)(Graphic Walker 的 Python 绑定)将任何 pandas DataFrame 转换为直接在 Jupyter Notebook 内的类似 Tableau 的交互式 UI。你可以拖放字段来构建热图、散点图、条形图等,根本不需要编写可视化代码。

pip install pygwalker
import pandas as pd
import pygwalker as pyg
 
df = pd.read_csv("your_data.csv")
walker = pyg.walk(df)

一旦交互式界面启动,你可以:

  • 将分类变量拖到行,另一个拖到列,将度量拖到颜色以创建热图。
  • 立即切换到其他图表类型(条形图、折线图、散点图)。
  • 无需编写额外代码即可过滤和聚合。

当你仍在探索要包含在最终 seaborn 热图中的变量时,这特别有用。在探索阶段使用 PyGWalker,然后使用 sns.heatmap() 锁定最终的静态可视化以进行共享。

常见问题

如何更改 seaborn 热图的大小?

在调用 sns.heatmap() 之前使用 plt.figure(figsize=(width, height)) 设置图形大小。例如,plt.figure(figsize=(12, 8)) 创建一个 12×8 英寸的图形。如果你正在使用子图,还可以传递 ax 参数。

如何在 seaborn 热图中注释值?

annot=True 传递给 sns.heatmap()。使用 fmt 参数控制数字格式(例如,fmt=".2f" 表示两位小数)。使用 annot_kws 自定义字体属性,例如:annot_kws={"size": 12, "weight": "bold"}

sns.heatmap 和 sns.clustermap 有什么区别?

sns.heatmap() 以原始行和列顺序显示数据。sns.clustermap() 应用层次聚类来重新排序行和列,使相似的值分组在一起,并添加树状图以显示聚类结构。

如何遮蔽相关热图的一半?

使用 NumPy 创建布尔遮罩。对于上三角:mask = np.triu(np.ones_like(corr, dtype=bool))。然后将 mask=mask 传递给 sns.heatmap()。对于下三角,请改用 np.tril()

我可以将 seaborn 热图保存为图像文件吗?

可以。创建热图后,在 plt.show() 之前调用 plt.savefig("heatmap.png", dpi=300, bbox_inches="tight")。Seaborn 热图支持所有 matplotlib 输出格式,包括 PNG、SVG、PDF 和 EPS。

结论

seaborn 热图是数据科学家可视化工具包中最通用的工具之一。从相关分析到混淆矩阵再到时间序列模式检测,sns.heatmap() 以简洁的语法和发表质量的输出处理所有这些。

从基础开始——传递你的数据并选择一个颜色映射。然后根据分析需求添加注释、遮蔽、自定义范围和聚类。在锁定最终可视化之前的探索阶段,像 PyGWalker (opens in a new tab) 这样的工具可以通过交互式拖放图表加快你的工作流程。

本指南中的代码示例都可以直接复制粘贴使用。选择最接近你用例的示例,换入你的数据,你将在不到一分钟的时间内获得清晰、信息丰富的热图。

📚