Matplotlib Scatter Plot: Complete Guide to plt.scatter()
Scatter plots are the go-to visualization for exploring relationships between two numerical variables. But creating effective scatter plots -- ones that reveal patterns, clusters, and outliers without becoming a cluttered mess -- requires more than a basic plt.scatter() call. You need color mapping for categories, size encoding for a third variable, proper axis labels, and handling for overlapping points.
Matplotlib's plt.scatter() handles all of this with a rich parameter set. This guide covers everything from basic scatter plots to advanced techniques like bubble charts, regression lines, density plots for overlapping points, and multi-panel subplot grids.
Basic Scatter Plot
```python
import matplotlib.pyplot as plt
import numpy as np

np.random.seed(42)
x = np.random.randn(100)
y = 2 * x + np.random.randn(100) * 0.5

plt.figure(figsize=(8, 6))
plt.scatter(x, y)
plt.xlabel('X Values')
plt.ylabel('Y Values')
plt.title('Basic Scatter Plot')
plt.show()
```

Customizing Markers
```python
import matplotlib.pyplot as plt
import numpy as np

np.random.seed(42)
x = np.random.randn(50)
y = np.random.randn(50)

plt.figure(figsize=(8, 6))
plt.scatter(x, y,
            s=100,               # marker size
            c='steelblue',       # color
            marker='o',          # marker shape
            alpha=0.7,           # transparency
            edgecolors='black',  # border color
            linewidths=0.5,      # border width
            )
plt.xlabel('X')
plt.ylabel('Y')
plt.title('Customized Scatter Plot')
plt.show()
```

Common Marker Shapes
| Marker | Symbol | Description |
|---|---|---|
| 'o' | Circle | Default |
| 's' | Square | |
| '^' | Triangle up | |
| 'D' | Diamond | |
| '*' | Star | |
| '+' | Plus | |
| 'x' | Cross | |
| '.' | Point | Small, for dense data |
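A quick way to compare these shapes is to draw one of each. The sketch below loops over the marker codes from the table and places each marker in its own column with its code printed underneath; the figure size and label offsets are arbitrary choices.

```python
import matplotlib.pyplot as plt
import numpy as np

# Marker codes from the table
markers = ['o', 's', '^', 'D', '*', '+', 'x', '.']

fig, ax = plt.subplots(figsize=(8, 2))
for i, m in enumerate(markers):
    # One marker per column; s is the marker area in points^2
    ax.scatter(i, 0, marker=m, s=150, c='steelblue')
    # Print the marker code below each marker
    ax.annotate(repr(m), (i, 0), textcoords='offset points',
                xytext=(0, -25), ha='center')

ax.set_xlim(-0.5, len(markers) - 0.5)
ax.set_xticks([])
ax.set_yticks([])
ax.set_title('Marker Shapes')
plt.show()
```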
Color by Category
```python
import matplotlib.pyplot as plt
import numpy as np

np.random.seed(42)
n = 50

# Three categories, each with its own color
categories = ['A', 'B', 'C']
colors = ['#e74c3c', '#3498db', '#2ecc71']

plt.figure(figsize=(8, 6))
for i, (cat, color) in enumerate(zip(categories, colors)):
    # Shift each category's cluster so the groups are visually distinct
    x = np.random.randn(n) + i * 2
    y = np.random.randn(n) + i * 1.5
    plt.scatter(x, y, c=color, label=cat, alpha=0.7, s=60, edgecolors='white')

plt.xlabel('Feature 1')
plt.ylabel('Feature 2')
plt.title('Scatter Plot Colored by Category')
plt.legend()
plt.show()
```

Color Mapping (Continuous Variable)
Use the c parameter with a numeric array and a colormap to encode a third variable as color:
```python
import matplotlib.pyplot as plt
import numpy as np

np.random.seed(42)
x = np.random.randn(200)
y = np.random.randn(200)
values = x ** 2 + y ** 2  # Distance from origin

plt.figure(figsize=(8, 6))
scatter = plt.scatter(x, y, c=values, cmap='viridis', s=50, alpha=0.8)
plt.colorbar(scatter, label='Distance from Origin')
plt.xlabel('X')
plt.ylabel('Y')
plt.title('Scatter Plot with Color Mapping')
plt.show()
```

Popular Colormaps
| Colormap | Type | Best For |
|---|---|---|
| 'viridis' | Sequential | Default, perceptually uniform |
| 'plasma' | Sequential | High contrast |
| 'coolwarm' | Diverging | Positive/negative values |
| 'RdYlGn' | Diverging | Good/bad ranges |
| 'Set1' | Qualitative | Categorical data |
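Diverging colormaps such as 'coolwarm' are only meaningful when the colormap's neutral midpoint corresponds to zero. A minimal sketch, assuming symmetric limits suit your data, passes the vmin and vmax parameters so that zero lands at the center of the color scale. The signed quantity x * y is an arbitrary example.

```python
import matplotlib.pyplot as plt
import numpy as np

np.random.seed(42)
x = np.random.randn(200)
y = np.random.randn(200)
# Arbitrary signed quantity: positive in quadrants I and III, negative in II and IV
values = x * y

# Symmetric limits put zero at the colormap's neutral midpoint
limit = np.abs(values).max()

plt.figure(figsize=(8, 6))
sc = plt.scatter(x, y, c=values, cmap='coolwarm', vmin=-limit, vmax=limit, s=40)
plt.colorbar(sc, label='x * y')
plt.xlabel('X')
plt.ylabel('Y')
plt.title('Diverging Colormap Centered at Zero')
plt.show()
```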
Size Encoding (Bubble Chart)
Encode a third variable as marker size to create a bubble chart:
```python
import matplotlib.pyplot as plt
import numpy as np

np.random.seed(42)
countries = ['US', 'China', 'India', 'Germany', 'Japan', 'UK', 'Brazil', 'France']
gdp = np.array([21.43, 14.34, 2.87, 3.86, 5.08, 2.83, 1.87, 2.72])
population = np.array([331, 1402, 1380, 83, 126, 67, 213, 67])
growth = np.array([2.3, 5.8, 6.5, 1.1, 0.8, 1.4, 1.2, 1.5])

plt.figure(figsize=(10, 7))
scatter = plt.scatter(gdp, growth,
                      s=population * 2,  # Scale population for visible sizes
                      c=range(len(countries)),
                      cmap='tab10',
                      alpha=0.6,
                      edgecolors='black',
                      )
for i, country in enumerate(countries):
    plt.annotate(country, (gdp[i], growth[i]),
                 textcoords="offset points", xytext=(10, 5), fontsize=9)

plt.xlabel('GDP (Trillion USD)')
plt.ylabel('GDP Growth Rate (%)')
plt.title('GDP vs Growth Rate (bubble size = population)')
plt.show()
```

Adding a Regression Line
```python
import matplotlib.pyplot as plt
import numpy as np

np.random.seed(42)
x = np.random.randn(100) * 3
y = 1.5 * x + np.random.randn(100) * 2

# Fit linear regression (coefficients: slope first, then intercept)
coefficients = np.polyfit(x, y, 1)
poly = np.poly1d(coefficients)
x_line = np.linspace(x.min(), x.max(), 100)

# Format the equation so a negative intercept reads "- 0.12" rather than "+ -0.12"
slope, intercept = coefficients
sign = '+' if intercept >= 0 else '-'

plt.figure(figsize=(8, 6))
plt.scatter(x, y, alpha=0.6, s=40, label='Data')
plt.plot(x_line, poly(x_line), 'r-', linewidth=2,
         label=f'y = {slope:.2f}x {sign} {abs(intercept):.2f}')
plt.xlabel('X')
plt.ylabel('Y')
plt.title('Scatter Plot with Regression Line')
plt.legend()
plt.show()
```

Handling Overlapping Points
When points overlap heavily, use transparency, smaller markers, or density-based techniques:
```python
import matplotlib.pyplot as plt
import numpy as np

np.random.seed(42)
x = np.random.randn(5000)
y = np.random.randn(5000)

fig, axes = plt.subplots(1, 3, figsize=(15, 5))

# Method 1: Transparency
axes[0].scatter(x, y, alpha=0.1, s=10)
axes[0].set_title('Alpha Transparency')

# Method 2: Small markers
axes[1].scatter(x, y, s=1, c='black')
axes[1].set_title('Small Markers')

# Method 3: 2D histogram (hexbin)
axes[2].hexbin(x, y, gridsize=30, cmap='YlOrRd')
axes[2].set_title('Hexbin Density')

plt.tight_layout()
plt.show()
```

Multiple Subplots
```python
import matplotlib.pyplot as plt
import numpy as np

np.random.seed(42)
n = 100
fig, axes = plt.subplots(2, 2, figsize=(10, 10))

# Plot 1: Linear
x1 = np.random.randn(n)
axes[0, 0].scatter(x1, 2 * x1 + np.random.randn(n) * 0.5, c='steelblue', s=30)
axes[0, 0].set_title('Linear Relationship')

# Plot 2: Quadratic
x2 = np.linspace(-3, 3, n)
axes[0, 1].scatter(x2, x2**2 + np.random.randn(n) * 0.5, c='coral', s=30)
axes[0, 1].set_title('Quadratic Relationship')

# Plot 3: Clusters
for i, c in enumerate(['red', 'blue', 'green']):
    cx = np.random.randn(30) + i * 3
    cy = np.random.randn(30) + i * 2
    axes[1, 0].scatter(cx, cy, c=c, s=30, alpha=0.7)
axes[1, 0].set_title('Clustered Data')

# Plot 4: No correlation
axes[1, 1].scatter(np.random.randn(n), np.random.randn(n), c='purple', s=30, alpha=0.5)
axes[1, 1].set_title('No Correlation')

plt.tight_layout()
plt.show()
```

Interactive Scatter Plots with PyGWalker
For exploratory data analysis, static scatter plots are just the starting point. PyGWalker turns your pandas DataFrame into an interactive Tableau-style interface directly in Jupyter. You can drag columns onto axes, add color and size encodings, and filter data -- all without writing additional code:
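```python
import numpy as np
import pandas as pd
import pygwalker as pyg
np.random.seed(42)
x, y = np.random.randn(200), np.random.randn(200)  # sample data for the DataFrame
df = pd.DataFrame({'x': x, 'y': y, 'category': np.random.choice(['A', 'B', 'C'], len(x))})
walker = pyg.walk(df)  # renders the interactive UI in a Jupyter notebook
```

plt.scatter() Parameter Reference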
| Parameter | Type | Description |
|---|---|---|
| x, y | array-like | Data positions |
| s | scalar or array | Marker size(s) as area in points² |
| c | color or array | Marker color(s). Numeric array uses the colormap |
| marker | str | Marker style ('o', 's', '^', etc.) |
| cmap | str or Colormap | Colormap when c is numeric |
| alpha | float (0-1) | Transparency |
| edgecolors | color or sequence of colors | Marker edge color |
| linewidths | float or array | Marker edge width |
| vmin, vmax | float | Colormap range limits |
| label | str | Legend label |
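Because s is a marker area rather than a radius, scaling it linearly with a variable makes bubble area proportional to that variable, which is what the bubble chart's population * 2 does. The sketch below reuses the bubble-chart data, normalizes sizes so the largest bubble has a fixed area (max_area is an arbitrary choice), and builds a size legend with the scatter collection's legend_elements(prop='sizes') method, whose func argument maps marker areas back to population values.

```python
import matplotlib.pyplot as plt
import numpy as np

# Data from the bubble-chart example
gdp = np.array([21.43, 14.34, 2.87, 3.86, 5.08, 2.83, 1.87, 2.72])
growth = np.array([2.3, 5.8, 6.5, 1.1, 0.8, 1.4, 1.2, 1.5])
population = np.array([331, 1402, 1380, 83, 126, 67, 213, 67])

max_area = 2000  # area (points^2) of the largest bubble; arbitrary choice
# Area proportional to population, scaled so the largest value gets max_area
sizes = population / population.max() * max_area

fig, ax = plt.subplots(figsize=(10, 7))
sc = ax.scatter(gdp, growth, s=sizes, alpha=0.6, edgecolors='black')

# Legend entries for a few representative sizes; func converts area -> population
handles, labels = sc.legend_elements(
    prop='sizes', num=4,
    func=lambda area: area / max_area * population.max(),
)
ax.legend(handles, labels, title='Population (millions)', loc='upper right')
ax.set_xlabel('GDP (Trillion USD)')
ax.set_ylabel('GDP Growth Rate (%)')
plt.show()
```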
FAQ
How do I create a scatter plot in Matplotlib?
Use plt.scatter(x, y) where x and y are arrays of the same length. Add plt.xlabel(), plt.ylabel(), and plt.title() for labels. Call plt.show() to display the plot.
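The same plot can also be built with the object-oriented interface, which tends to be easier to manage once a figure has several axes. This sketch reproduces the basic example using fig, ax = plt.subplots() and sets all three labels with a single ax.set() call.

```python
import matplotlib.pyplot as plt
import numpy as np

np.random.seed(42)
x = np.random.randn(100)
y = 2 * x + np.random.randn(100) * 0.5

# Object-oriented equivalent of plt.scatter / plt.xlabel / plt.ylabel / plt.title
fig, ax = plt.subplots(figsize=(8, 6))
ax.scatter(x, y)
ax.set(xlabel='X Values', ylabel='Y Values', title='Basic Scatter Plot')
plt.show()
```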
How do I color scatter plot points by category?
Loop over categories and call plt.scatter() for each one with a different c parameter and a label. Then call plt.legend() to show the legend. Alternatively, pass a numeric array to c with a colormap for continuous coloring.
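If the categories live in a single array, one scatter() call can color them all. A minimal sketch, assuming the category names are first encoded as integer codes: a ListedColormap with one color per category maps each code to its color, and legend_elements() returns one legend handle per unique code, in ascending order. The random category assignment is hypothetical example data.

```python
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.colors import ListedColormap

np.random.seed(42)
categories = ['A', 'B', 'C']
point_cats = np.random.choice(categories, 150)  # hypothetical category per point
x = np.random.randn(150)
y = np.random.randn(150)

# Encode each category name as an integer code 0, 1, 2
codes = np.array([categories.index(cat) for cat in point_cats])
cmap = ListedColormap(['#e74c3c', '#3498db', '#2ecc71'])  # one color per category

fig, ax = plt.subplots(figsize=(8, 6))
sc = ax.scatter(x, y, c=codes, cmap=cmap, s=60, alpha=0.7)

# One handle per unique code (0, 1, 2), so it pairs with the category names
handles, _ = sc.legend_elements()
ax.legend(handles, categories, title='Category')
plt.show()
```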
How do I add a trend line to a scatter plot?
Use np.polyfit(x, y, degree) to fit a polynomial, create a np.poly1d() from the coefficients, and plot it with plt.plot(). For degree=1, this gives a linear regression line.
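NumPy also provides a newer polynomial API, numpy.polynomial.Polynomial.fit, which NumPy's documentation prefers for new code. The sketch below fits the same line both ways. Polynomial.fit works on a scaled domain, so convert() is used to recover ordinary coefficients, which are ordered from the lowest degree up (the reverse of np.polyfit).

```python
import numpy as np
from numpy.polynomial import Polynomial

np.random.seed(42)
x = np.random.randn(100) * 3
y = 1.5 * x + np.random.randn(100) * 2

# Classic API: highest-degree coefficient first
slope, intercept = np.polyfit(x, y, 1)

# Newer API: convert() maps the fit back to the unscaled domain;
# coef is ordered lowest degree first (intercept, then slope)
p = Polynomial.fit(x, y, 1)
intercept2, slope2 = p.convert().coef

print(f'polyfit:    y = {slope:.2f}x + {intercept:.2f}')
print(f'Polynomial: y = {slope2:.2f}x + {intercept2:.2f}')
```

The fitted object p is callable, so p(x_line) can be passed to plt.plot() in place of poly(x_line).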
What is the difference between plt.scatter() and plt.plot()?
plt.scatter() creates individual markers with per-point control over size and color. plt.plot() draws every point in one uniform style and, by default, connects the points with a line; a marker-only format such as 'o' produces unconnected markers. Use scatter() when points need individual styling; use plot() for line charts or uniform markers, where it is also faster on large datasets.
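The difference is easy to see side by side. In this sketch, scatter() colors each point by its x value, while plot() with the marker-only format 'o' draws every point in one style with no connecting line. The point count and styling are arbitrary.

```python
import matplotlib.pyplot as plt
import numpy as np

np.random.seed(42)
x = np.random.randn(10_000)
y = np.random.randn(10_000)

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 4))

# scatter(): per-point styling is possible; here color encodes x
sc = ax1.scatter(x, y, c=x, s=5, cmap='viridis')
ax1.set_title('scatter(): per-point color')

# plot() with a marker and no line style: uniform markers, no connecting line
(line,) = ax2.plot(x, y, 'o', markersize=2, alpha=0.3)
ax2.set_title("plot(x, y, 'o'): uniform markers")

plt.tight_layout()
plt.show()
```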
How do I handle overlapping points in a scatter plot?
Use alpha (transparency) to reveal density, reduce marker s (size), use plt.hexbin() for density heatmaps, or jitter the points slightly with small random offsets.
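Jitter is most useful for discrete data, where many points share exactly the same coordinates. The sketch below uses a small hypothetical helper, jitter(), that adds uniform noise within a fixed range. The amount of 0.15 is an assumption; it should stay well below half the spacing between discrete values so jittered points never drift into a neighboring level.

```python
import matplotlib.pyplot as plt
import numpy as np

def jitter(values, amount, rng):
    """Hypothetical helper: add uniform noise in [-amount, amount) to spread out tied values."""
    return values + rng.uniform(-amount, amount, size=len(values))

rng = np.random.default_rng(42)
# Hypothetical discrete data: ratings 1-5 produce many exactly overlapping points
x = rng.integers(1, 6, size=500)
y = rng.integers(1, 6, size=500)

x_j = jitter(x, 0.15, rng)
y_j = jitter(y, 0.15, rng)

plt.figure(figsize=(6, 6))
plt.scatter(x_j, y_j, s=15, alpha=0.4)
plt.xticks(range(1, 6))
plt.yticks(range(1, 6))
plt.xlabel('Rating A')
plt.ylabel('Rating B')
plt.title('Jittered Discrete Data')
plt.show()
```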
Conclusion
Matplotlib's plt.scatter() is the standard tool for creating scatter plots in Python. For basic exploration, a simple plt.scatter(x, y) suffices. For publication-quality figures, use distinct colors for categories or a colormap for a continuous third variable, size encoding for another variable, regression lines for trends, and transparency or hexbin plots for dense data. Together, these techniques cover most bivariate visualization tasks.