Skip to content

Matplotlib 散点图:plt.scatter() 完全指南

Updated on

散点图是探索两个数值变量之间关系的首选可视化方式。但要创建有效的散点图——能够在不变得杂乱的情况下揭示模式、聚类和异常值——需要的不仅仅是一个简单的 plt.scatter() 调用。你需要用于分类的颜色映射、用于第三个变量的大小编码、恰当的轴标签以及对重叠点的处理。

Matplotlib 的 plt.scatter() 通过丰富的参数集来处理所有这些需求。本指南涵盖了从基础散点图到气泡图、回归线和多面板散点矩阵等高级技术的所有内容。

📚

基础散点图

import matplotlib.pyplot as plt
import numpy as np
 
np.random.seed(42)
x = np.random.randn(100)
y = 2 * x + np.random.randn(100) * 0.5
 
plt.figure(figsize=(8, 6))
plt.scatter(x, y)
plt.xlabel('X 值')
plt.ylabel('Y 值')
plt.title('基础散点图')
plt.show()

自定义标记

import matplotlib.pyplot as plt
import numpy as np
 
np.random.seed(42)
x = np.random.randn(50)
y = np.random.randn(50)
 
plt.figure(figsize=(8, 6))
plt.scatter(x, y,
    s=100,              # 标记大小
    c='steelblue',      # 颜色
    marker='o',         # 标记形状
    alpha=0.7,          # 透明度
    edgecolors='black', # 边框颜色
    linewidths=0.5,     # 边框宽度
)
plt.xlabel('X')
plt.ylabel('Y')
plt.title('自定义散点图')
plt.show()

常见标记形状

标记符号描述
'o'圆形默认
's'方形
'^'上三角
'D'菱形
'*'星形
'+'加号
'x'叉号
'.'小,适合密集数据

按类别着色

import matplotlib.pyplot as plt
import numpy as np
 
np.random.seed(42)
n = 50
 
# 三个类别
categories = ['A', 'B', 'C']
colors = ['#e74c3c', '#3498db', '#2ecc71']
 
plt.figure(figsize=(8, 6))
for cat, color in zip(categories, colors):
    x = np.random.randn(n) + (categories.index(cat) * 2)
    y = np.random.randn(n) + (categories.index(cat) * 1.5)
    plt.scatter(x, y, c=color, label=cat, alpha=0.7, s=60, edgecolors='white')
 
plt.xlabel('特征 1')
plt.ylabel('特征 2')
plt.title('按类别着色的散点图')
plt.legend()
plt.show()

颜色映射(连续变量)

使用 c 参数配合数值数组和颜色映射表,将第三个变量编码为颜色:

import matplotlib.pyplot as plt
import numpy as np
 
np.random.seed(42)
x = np.random.randn(200)
y = np.random.randn(200)
values = x ** 2 + y ** 2  # 到原点的距离
 
plt.figure(figsize=(8, 6))
scatter = plt.scatter(x, y, c=values, cmap='viridis', s=50, alpha=0.8)
plt.colorbar(scatter, label='到原点的距离')
plt.xlabel('X')
plt.ylabel('Y')
plt.title('带颜色映射的散点图')
plt.show()

常用颜色映射表

颜色映射类型最适合
'viridis'顺序默认,感知均匀
'plasma'顺序高对比度
'coolwarm'发散正/负值
'RdYlGn'发散好/坏范围
'Set1'定性分类数据

大小编码(气泡图)

将第三个变量编码为标记大小来创建气泡图:

import matplotlib.pyplot as plt
import numpy as np
 
np.random.seed(42)
countries = ['US', 'China', 'India', 'Germany', 'Japan', 'UK', 'Brazil', 'France']
gdp = np.array([21.43, 14.34, 2.87, 3.86, 5.08, 2.83, 1.87, 2.72])
population = np.array([331, 1402, 1380, 83, 126, 67, 213, 67])
growth = np.array([2.3, 5.8, 6.5, 1.1, 0.8, 1.4, 1.2, 1.5])
 
plt.figure(figsize=(10, 7))
scatter = plt.scatter(gdp, growth,
    s=population * 2,   # 缩放人口以获得可见大小
    c=range(len(countries)),
    cmap='tab10',
    alpha=0.6,
    edgecolors='black',
)
 
for i, country in enumerate(countries):
    plt.annotate(country, (gdp[i], growth[i]),
        textcoords="offset points", xytext=(10, 5), fontsize=9)
 
plt.xlabel('GDP(万亿美元)')
plt.ylabel('GDP 增长率(%)')
plt.title('GDP vs 增长率(气泡大小 = 人口)')
plt.show()

添加回归线

import matplotlib.pyplot as plt
import numpy as np
 
np.random.seed(42)
x = np.random.randn(100) * 3
y = 1.5 * x + np.random.randn(100) * 2
 
# 拟合线性回归
coefficients = np.polyfit(x, y, 1)
poly = np.poly1d(coefficients)
x_line = np.linspace(x.min(), x.max(), 100)
 
plt.figure(figsize=(8, 6))
plt.scatter(x, y, alpha=0.6, s=40, label='数据')
plt.plot(x_line, poly(x_line), 'r-', linewidth=2,
    label=f'y = {coefficients[0]:.2f}x + {coefficients[1]:.2f}')
 
plt.xlabel('X')
plt.ylabel('Y')
plt.title('带回归线的散点图')
plt.legend()
plt.show()

处理重叠点

当点大量重叠时,使用透明度、较小的标记或基于密度的技术:

import matplotlib.pyplot as plt
import numpy as np
 
np.random.seed(42)
x = np.random.randn(5000)
y = np.random.randn(5000)
 
fig, axes = plt.subplots(1, 3, figsize=(15, 5))
 
# 方法 1:透明度
axes[0].scatter(x, y, alpha=0.1, s=10)
axes[0].set_title('Alpha 透明度')
 
# 方法 2:小标记
axes[1].scatter(x, y, s=1, c='black')
axes[1].set_title('小标记')
 
# 方法 3:2D 直方图(hexbin)
axes[2].hexbin(x, y, gridsize=30, cmap='YlOrRd')
axes[2].set_title('Hexbin 密度')
 
plt.tight_layout()
plt.show()

多子图

import matplotlib.pyplot as plt
import numpy as np
 
np.random.seed(42)
n = 100
 
fig, axes = plt.subplots(2, 2, figsize=(10, 10))
 
# 图 1:线性
x1 = np.random.randn(n)
axes[0, 0].scatter(x1, 2 * x1 + np.random.randn(n) * 0.5, c='steelblue', s=30)
axes[0, 0].set_title('线性关系')
 
# 图 2:二次
x2 = np.linspace(-3, 3, n)
axes[0, 1].scatter(x2, x2**2 + np.random.randn(n) * 0.5, c='coral', s=30)
axes[0, 1].set_title('二次关系')
 
# 图 3:聚类
for i, c in enumerate(['red', 'blue', 'green']):
    cx = np.random.randn(30) + i * 3
    cy = np.random.randn(30) + i * 2
    axes[1, 0].scatter(cx, cy, c=c, s=30, alpha=0.7)
axes[1, 0].set_title('聚类数据')
 
# 图 4:无相关性
axes[1, 1].scatter(np.random.randn(n), np.random.randn(n), c='purple', s=30, alpha=0.5)
axes[1, 1].set_title('无相关性')
 
plt.tight_layout()
plt.show()

使用 PyGWalker 创建交互式散点图

对于探索性数据分析,静态散点图只是起点。PyGWalker (opens in a new tab) 将你的 pandas DataFrame 转换为 Jupyter 中的交互式 Tableau 风格界面。你可以将列拖放到轴上、添加颜色和大小编码、过滤数据——全部无需编写额外代码:

import pandas as pd
import pygwalker as pyg
 
df = pd.DataFrame({'x': x, 'y': y, 'category': np.random.choice(['A', 'B', 'C'], len(x))})
walker = pyg.walk(df)

plt.scatter() 参数参考

参数类型描述
x, y类数组数据位置
s标量或数组标记大小(点^2)
c颜色或数组标记颜色。数组用于颜色映射
markerstr标记样式('o', 's', '^' 等)
cmapstr 或 Colormapc 为数值时的颜色映射表
alphafloat (0-1)透明度
edgecolors颜色标记边框颜色
linewidthsfloat标记边框宽度
vmin, vmaxfloat颜色映射范围限制
labelstr图例标签

FAQ

如何在 Matplotlib 中创建散点图?

使用 plt.scatter(x, y),其中 x 和 y 是相同长度的数组。添加 plt.xlabel()plt.ylabel()plt.title() 作为标签。调用 plt.show() 来显示图表。

如何按类别给散点图点着色?

遍历类别,对每个类别使用不同的 c 参数和 label 调用 plt.scatter()。然后调用 plt.legend() 显示图例。或者,将数值数组传递给 c 并配合颜色映射表进行连续着色。

如何向散点图添加趋势线?

使用 np.polyfit(x, y, 阶数) 拟合多项式,从系数创建 np.poly1d(),然后用 plt.plot() 绘制。阶数=1 时给出线性回归线。

plt.scatter() 和 plt.plot() 有什么区别?

plt.scatter() 创建单独的标记,可以逐点控制大小、颜色和形状。plt.plot() 配合标记样式创建外观统一的连接点。当点需要单独样式时使用 scatter();折线图或统一标记使用 plot()

如何处理散点图中的重叠点?

使用 alpha(透明度)来显示密度,减小标记大小 s,使用 plt.hexbin() 创建密度热图,或用小的随机偏移量轻微抖动点。

总结

Matplotlib 的 plt.scatter() 是 Python 中创建散点图的标准工具。对于基础探索,简单的 plt.scatter(x, y) 就够了。对于出版级质量的图表,可利用颜色映射表示类别、大小编码表示第三变量、回归线表示趋势、透明度处理密集数据。掌握这些技术,你就能有效地可视化任何双变量关系。

📚