
Matplotlib Scatter Plot: Complete Guide to plt.scatter()


Scatter plots are the go-to visualization for exploring relationships between two numerical variables. But creating effective scatter plots -- ones that reveal patterns, clusters, and outliers without becoming a cluttered mess -- requires more than a basic plt.scatter() call. You need color mapping for categories, size encoding for a third variable, proper axis labels, and handling for overlapping points.

Matplotlib's plt.scatter() handles all of this with a rich parameter set. This guide covers everything from basic scatter plots to advanced techniques like bubble charts, regression lines, and multi-panel subplot grids.


Basic Scatter Plot

import matplotlib.pyplot as plt
import numpy as np
 
np.random.seed(42)
x = np.random.randn(100)
y = 2 * x + np.random.randn(100) * 0.5
 
plt.figure(figsize=(8, 6))
plt.scatter(x, y)
plt.xlabel('X Values')
plt.ylabel('Y Values')
plt.title('Basic Scatter Plot')
plt.show()
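The same plot can also be written with Matplotlib's object-oriented interface, which is easier to reuse inside functions and multi-panel figures. This is a minimal sketch; the helper name `basic_scatter` is hypothetical. `Axes.scatter()` returns a `PathCollection`, and its `get_offsets()` method gives back the plotted (x, y) positions.

```python
import matplotlib.pyplot as plt
import numpy as np


def basic_scatter(x, y):
    """Hypothetical helper: draw a labeled scatter plot and return its parts."""
    fig, ax = plt.subplots(figsize=(8, 6))
    points = ax.scatter(x, y)  # returns a PathCollection
    ax.set_xlabel('X Values')
    ax.set_ylabel('Y Values')
    ax.set_title('Basic Scatter Plot')
    return fig, ax, points


np.random.seed(42)
x = np.random.randn(100)
y = 2 * x + np.random.randn(100) * 0.5
fig, ax, points = basic_scatter(x, y)
```

Call plt.show() afterwards to display the figure.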

Customizing Markers

import matplotlib.pyplot as plt
import numpy as np
 
np.random.seed(42)
x = np.random.randn(50)
y = np.random.randn(50)
 
plt.figure(figsize=(8, 6))
plt.scatter(x, y,
    s=100,              # marker size
    c='steelblue',      # color
    marker='o',         # marker shape
    alpha=0.7,          # transparency
    edgecolors='black', # border color
    linewidths=0.5,     # border width
)
plt.xlabel('X')
plt.ylabel('Y')
plt.title('Customized Scatter Plot')
plt.show()
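Keep in mind that s is an area in points², not a width: for the default circle marker the diameter is roughly sqrt(s) points, so doubling the width means quadrupling s. The sketch below passes an array to s so that each point gets its own size; the `width_to_s` helper is hypothetical and the widths (5, 10, 20 points) are arbitrary.

```python
import matplotlib.pyplot as plt
import numpy as np


def width_to_s(width_pts):
    """Hypothetical helper: convert a desired marker width (points) to s (points^2)."""
    return np.asarray(width_pts, dtype=float) ** 2


widths = [5, 10, 20]          # desired marker widths in points
sizes = width_to_s(widths)    # areas in points^2: 25, 100, 400

fig, ax = plt.subplots(figsize=(6, 3))
points = ax.scatter([1, 2, 3], [1, 1, 1], s=sizes, c='steelblue', edgecolors='black')
for xi, w in zip([1, 2, 3], widths):
    # Label each marker with the width it was requested at
    ax.annotate(f'width {w} pt', (xi, 1), textcoords='offset points',
                xytext=(0, 15), ha='center')
ax.set_xlim(0, 4)
```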

Common Marker Shapes

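Marker | Shape       | Notes
------ | ----------- | ---------------------------
'o'    | Circle      | Default
's'    | Square      |
'^'    | Triangle up |
'D'    | Diamond     |
'*'    | Star        |
'+'    | Plus        | Unfilled (lines only)
'x'    | Cross       | Unfilled (lines only)
'.'    | Point       | Small; suited to dense data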
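To compare the shapes side by side, a short loop can draw each marker code from the table in its own column. This is a sketch; the figure size, marker size, and label offsets are arbitrary choices.

```python
import matplotlib.pyplot as plt

# Marker codes from the table, in the same order
markers = ['o', 's', '^', 'D', '*', '+', 'x', '.']

fig, ax = plt.subplots(figsize=(8, 2))
for i, m in enumerate(markers):
    ax.scatter(i, 0, marker=m, s=200, c='steelblue')  # one scatter call per marker
    # Print the marker code under its shape
    ax.annotate(repr(m), (i, 0), textcoords='offset points',
                xytext=(0, -25), ha='center')
ax.set_xlim(-0.5, len(markers) - 0.5)
ax.set_xticks([])
ax.set_yticks([])
ax.set_title('Matplotlib marker shapes')
```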

Color by Category

import matplotlib.pyplot as plt
import numpy as np
 
np.random.seed(42)
n = 50
 
# Three categories
categories = ['A', 'B', 'C']
colors = ['#e74c3c', '#3498db', '#2ecc71']
 
plt.figure(figsize=(8, 6))
for i, (cat, color) in enumerate(zip(categories, colors)):
    # Shift each category's cluster so the groups are visually distinct
    x = np.random.randn(n) + i * 2
    y = np.random.randn(n) + i * 1.5
    plt.scatter(x, y, c=color, label=cat, alpha=0.7, s=60, edgecolors='white')
 
plt.xlabel('Feature 1')
plt.ylabel('Feature 2')
plt.title('Scatter Plot Colored by Category')
plt.legend()
plt.show()
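When the category of every point is already stored in one array, a single scatter() call can do the same job: encode the categories as integer codes, map them through a ListedColormap, and let PathCollection.legend_elements() build one legend handle per unique code. A sketch under those assumptions; the data layout (50 points per category) mirrors the loop version.

```python
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.colors import ListedColormap

np.random.seed(42)
categories = ['A', 'B', 'C']
colors = ['#e74c3c', '#3498db', '#2ecc71']

# One integer code per point: 0 = 'A', 1 = 'B', 2 = 'C'
codes = np.repeat(np.arange(len(categories)), 50)
x = np.random.randn(codes.size) + codes * 2
y = np.random.randn(codes.size) + codes * 1.5

fig, ax = plt.subplots(figsize=(8, 6))
points = ax.scatter(x, y, c=codes, cmap=ListedColormap(colors),
                    alpha=0.7, s=60, edgecolors='white')
# legend_elements() returns one proxy handle per unique value in `codes`,
# ordered 0, 1, 2, so the handles line up with `categories`
handles, _ = points.legend_elements()
ax.legend(handles, categories, title='Category')
ax.set_xlabel('Feature 1')
ax.set_ylabel('Feature 2')
```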

Color Mapping (Continuous Variable)

Use the c parameter with a numeric array and a colormap to encode a third variable as color:

import matplotlib.pyplot as plt
import numpy as np
 
np.random.seed(42)
x = np.random.randn(200)
y = np.random.randn(200)
values = np.hypot(x, y)  # Distance from origin: sqrt(x**2 + y**2)
 
plt.figure(figsize=(8, 6))
scatter = plt.scatter(x, y, c=values, cmap='viridis', s=50, alpha=0.8)
plt.colorbar(scatter, label='Distance from Origin')
plt.xlabel('X')
plt.ylabel('Y')
plt.title('Scatter Plot with Color Mapping')
plt.show()
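By default the colormap stretches from the minimum to the maximum of c. Passing vmin and vmax fixes that range instead, which keeps colors comparable across several plots; values above vmax are drawn in the top color by default. A sketch with an arbitrary range of 0 to 2:

```python
import matplotlib.pyplot as plt
import numpy as np

np.random.seed(42)
x = np.random.randn(200)
y = np.random.randn(200)
values = np.hypot(x, y)  # distance from origin

fig, ax = plt.subplots(figsize=(8, 6))
# Fix the color scale to [0, 2]; points farther out share the top color
points = ax.scatter(x, y, c=values, cmap='viridis', vmin=0, vmax=2, s=50, alpha=0.8)
# extend='max' draws an arrow on the colorbar to signal clipped values
fig.colorbar(points, ax=ax, label='Distance from Origin', extend='max')
```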

Popular Colormaps

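Colormap  | Type        | Best for
--------- | ----------- | -------------------------------------------------------
'viridis' | Sequential  | Default; perceptually uniform
'plasma'  | Sequential  | Perceptually uniform; purple-to-yellow alternative
'coolwarm'| Diverging   | Values above and below a midpoint (e.g. positive/negative)
'RdYlGn'  | Diverging   | Good/bad ranges (red = low, green = high); not colorblind-friendly
'Set1'    | Qualitative | Categorical data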
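A diverging colormap only reads correctly if its neutral middle color sits at the meaningful midpoint. With lopsided data (say, values from -2 to +5), the default scaling puts the midpoint at 1.5. A TwoSlopeNorm pins it to zero instead. A sketch; the `change` array is a hypothetical signed quantity.

```python
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.colors import TwoSlopeNorm

np.random.seed(42)
x = np.random.randn(200)
y = np.random.randn(200)
change = x - 0.5 * y  # hypothetical signed quantity (positive and negative)

# Map vmin -> 0.0, 0 -> 0.5 (neutral color), vmax -> 1.0, each side scaled separately
norm = TwoSlopeNorm(vmin=change.min(), vcenter=0, vmax=change.max())

fig, ax = plt.subplots(figsize=(8, 6))
points = ax.scatter(x, y, c=change, cmap='coolwarm', norm=norm, s=50)
fig.colorbar(points, ax=ax, label='Change')
```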

Size Encoding (Bubble Chart)

Encode a third variable as marker size to create a bubble chart:

import matplotlib.pyplot as plt
import numpy as np
 
# Approximate figures for illustration (GDP in trillion USD, population in millions, growth in %)
countries = ['US', 'China', 'India', 'Germany', 'Japan', 'UK', 'Brazil', 'France']
gdp = np.array([21.43, 14.34, 2.87, 3.86, 5.08, 2.83, 1.87, 2.72])
population = np.array([331, 1402, 1380, 83, 126, 67, 213, 67])
growth = np.array([2.3, 5.8, 6.5, 1.1, 0.8, 1.4, 1.2, 1.5])
 
plt.figure(figsize=(10, 7))
scatter = plt.scatter(gdp, growth,
    s=population * 2,   # Scale population for visible sizes
    c=range(len(countries)),
    cmap='tab10',
    alpha=0.6,
    edgecolors='black',
)
 
for i, country in enumerate(countries):
    plt.annotate(country, (gdp[i], growth[i]),
        textcoords="offset points", xytext=(10, 5), fontsize=9)
 
plt.xlabel('GDP (Trillion USD)')
plt.ylabel('GDP Growth Rate (%)')
plt.title('GDP vs Growth Rate (bubble size = population)')
plt.show()
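A bubble chart is hard to read without a key for the bubble sizes. One common approach is to add empty scatter() calls at a few reference sizes purely as legend entries. The sketch below assumes population is in millions and reuses the same population-to-area factor (2) as the chart above; the reference values 100, 500, and 1000 are arbitrary.

```python
import matplotlib.pyplot as plt
import numpy as np

countries = ['US', 'China', 'India', 'Germany', 'Japan', 'UK', 'Brazil', 'France']
gdp = np.array([21.43, 14.34, 2.87, 3.86, 5.08, 2.83, 1.87, 2.72])
population = np.array([331, 1402, 1380, 83, 126, 67, 213, 67])  # assumed millions
growth = np.array([2.3, 5.8, 6.5, 1.1, 0.8, 1.4, 1.2, 1.5])
scale = 2  # population -> marker area (points^2), as in the bubble chart

fig, ax = plt.subplots(figsize=(10, 7))
ax.scatter(gdp, growth, s=population * scale, c=range(len(countries)),
           cmap='tab10', alpha=0.6, edgecolors='black')

# Empty scatter calls act as legend proxies for a few reference sizes
for ref in [100, 500, 1000]:
    ax.scatter([], [], s=ref * scale, c='gray', alpha=0.6, edgecolors='black',
               label=f'{ref}M')
ax.legend(title='Population', labelspacing=2, borderpad=1.2)
ax.set_xlabel('GDP (Trillion USD)')
ax.set_ylabel('GDP Growth Rate (%)')
```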

Adding a Regression Line

import matplotlib.pyplot as plt
import numpy as np
 
np.random.seed(42)
x = np.random.randn(100) * 3
y = 1.5 * x + np.random.randn(100) * 2
 
# Fit linear regression
coefficients = np.polyfit(x, y, 1)
poly = np.poly1d(coefficients)
x_line = np.linspace(x.min(), x.max(), 100)
 
plt.figure(figsize=(8, 6))
plt.scatter(x, y, alpha=0.6, s=40, label='Data')
# The '+' format flag always prints the intercept's sign (e.g. 'x -0.12', not 'x + -0.12')
plt.plot(x_line, poly(x_line), 'r-', linewidth=2,
    label=f'y = {coefficients[0]:.2f}x {coefficients[1]:+.2f}')
 
plt.xlabel('X')
plt.ylabel('Y')
plt.title('Scatter Plot with Regression Line')
plt.legend()
plt.show()
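NumPy's documentation recommends the numpy.polynomial package over np.polyfit and np.poly1d for new code. Its coefficients come in the opposite order (lowest degree first), which is easy to trip over. The sketch below fits the same data both ways and also computes the Pearson correlation r with np.corrcoef as a simple measure of fit strength.

```python
import numpy as np
from numpy.polynomial import Polynomial

np.random.seed(42)
x = np.random.randn(100) * 3
y = 1.5 * x + np.random.randn(100) * 2

# Legacy API: coefficients ordered highest degree first -> [slope, intercept]
slope_old, intercept_old = np.polyfit(x, y, 1)

# Newer API: fit() works on a scaled domain; convert() maps back to plain x,
# and .coef is ordered lowest degree first -> [intercept, slope]
fit = Polynomial.fit(x, y, deg=1)
intercept, slope = fit.convert().coef

# Pearson correlation between x and y
r = np.corrcoef(x, y)[0, 1]
print(f'y = {slope:.2f}x {intercept:+.2f}, r = {r:.2f}')
```

The fitted object is callable, so fit(x_line) can replace poly(x_line) when plotting the line.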

Handling Overlapping Points

When points overlap heavily, use transparency, smaller markers, or density-based techniques:

import matplotlib.pyplot as plt
import numpy as np
 
np.random.seed(42)
x = np.random.randn(5000)
y = np.random.randn(5000)
 
fig, axes = plt.subplots(1, 3, figsize=(15, 5))
 
# Method 1: Transparency
axes[0].scatter(x, y, alpha=0.1, s=10)
axes[0].set_title('Alpha Transparency')
 
# Method 2: Small markers
axes[1].scatter(x, y, s=1, c='black')
axes[1].set_title('Small Markers')
 
# Method 3: 2D histogram (hexbin)
axes[2].hexbin(x, y, gridsize=30, cmap='YlOrRd')
axes[2].set_title('Hexbin Density')
 
plt.tight_layout()
plt.show()
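A square-binned 2D histogram is a close alternative to hexbin, and a colorbar makes the density scale readable. ax.hist2d() returns the count array along with the bin edges and the drawn image. A sketch with an arbitrary 40 × 40 grid:

```python
import matplotlib.pyplot as plt
import numpy as np

np.random.seed(42)
x = np.random.randn(5000)
y = np.random.randn(5000)

fig, ax = plt.subplots(figsize=(6, 5))
# counts[i, j] is the number of points in the bin between xedges[i:i+2] and yedges[j:j+2]
counts, xedges, yedges, image = ax.hist2d(x, y, bins=40, cmap='YlOrRd')
fig.colorbar(image, ax=ax, label='Points per bin')
ax.set_title('2D Histogram Density')
```

The same colorbar approach works for hexbin: pass the object returned by axes[2].hexbin(...) to fig.colorbar().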

Multiple Subplots

import matplotlib.pyplot as plt
import numpy as np
 
np.random.seed(42)
n = 100
 
fig, axes = plt.subplots(2, 2, figsize=(10, 10))
 
# Plot 1: Linear
x1 = np.random.randn(n)
axes[0, 0].scatter(x1, 2 * x1 + np.random.randn(n) * 0.5, c='steelblue', s=30)
axes[0, 0].set_title('Linear Relationship')
 
# Plot 2: Quadratic
x2 = np.linspace(-3, 3, n)
axes[0, 1].scatter(x2, x2**2 + np.random.randn(n) * 0.5, c='coral', s=30)
axes[0, 1].set_title('Quadratic Relationship')
 
# Plot 3: Clusters
for i, c in enumerate(['red', 'blue', 'green']):
    cx = np.random.randn(30) + i * 3
    cy = np.random.randn(30) + i * 2
    axes[1, 0].scatter(cx, cy, c=c, s=30, alpha=0.7)
axes[1, 0].set_title('Clustered Data')
 
# Plot 4: No correlation
axes[1, 1].scatter(np.random.randn(n), np.random.randn(n), c='purple', s=30, alpha=0.5)
axes[1, 1].set_title('No Correlation')
 
plt.tight_layout()
plt.show()
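The same subplot grid can be turned into a scatter matrix that compares every pair of variables: histograms on the diagonal, pairwise scatter plots everywhere else. A minimal matplotlib-only sketch; the three variables a, b, c are hypothetical. pandas users can get a similar grid from pandas.plotting.scatter_matrix().

```python
import matplotlib.pyplot as plt
import numpy as np

np.random.seed(42)
n = 100
a = np.random.randn(n)
data = {
    'a': a,
    'b': 2 * a + np.random.randn(n) * 0.5,  # correlated with a
    'c': np.random.randn(n),                # independent of a and b
}
names = list(data)
k = len(names)

fig, axes = plt.subplots(k, k, figsize=(9, 9))
for i, row_name in enumerate(names):
    for j, col_name in enumerate(names):
        ax = axes[i, j]
        if i == j:
            # Diagonal: distribution of a single variable
            ax.hist(data[row_name], bins=15, color='gray')
        else:
            # Off-diagonal: column variable on x, row variable on y
            ax.scatter(data[col_name], data[row_name], s=10, alpha=0.6)
        if i == k - 1:
            ax.set_xlabel(col_name)
        if j == 0:
            ax.set_ylabel(row_name)
fig.tight_layout()
```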

Interactive Scatter Plots with PyGWalker

For exploratory data analysis, static scatter plots are just the starting point. PyGWalker turns your pandas DataFrame into an interactive Tableau-style interface directly in Jupyter. You can drag columns onto axes, add color and size encodings, and filter data -- all without writing additional code:

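import numpy as np
import pandas as pd
import pygwalker as pyg
 
np.random.seed(42)
x = np.random.randn(200)
y = np.random.randn(200)
df = pd.DataFrame({'x': x, 'y': y, 'category': np.random.choice(['A', 'B', 'C'], len(x))})
walker = pyg.walk(df)  # renders the interactive UI in a Jupyter notebook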

plt.scatter() Parameter Reference

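Parameter  | Type            | Description
---------- | --------------- | ------------------------------------------------------
x, y       | array-like      | Data positions
s          | scalar or array | Marker size(s) in points² (area, not width)
c          | color or array  | Marker color(s); a numeric array is mapped through cmap
marker     | str             | Marker style ('o', 's', '^', etc.)
cmap       | str or Colormap | Colormap used when c is numeric
alpha      | float (0-1)     | Transparency (0 = invisible, 1 = opaque)
edgecolors | color or array  | Marker edge color(s); also accepts 'face' or 'none'
linewidths | float or array  | Marker edge width(s) in points
vmin, vmax | float           | Colormap range limits
label      | str             | Legend label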

FAQ

How do I create a scatter plot in Matplotlib?

Use plt.scatter(x, y) where x and y are arrays of the same length. Add plt.xlabel(), plt.ylabel(), and plt.title() for labels. Call plt.show() to display the plot.

How do I color scatter plot points by category?

Loop over categories and call plt.scatter() for each one with a different c parameter and a label. Then call plt.legend() to show the legend. Alternatively, pass a numeric array to c with a colormap for continuous coloring.

How do I add a trend line to a scatter plot?

Use np.polyfit(x, y, degree) to fit a polynomial, create a np.poly1d() from the coefficients, and plot it with plt.plot(). For degree=1, this gives a linear regression line.

What is the difference between plt.scatter() and plt.plot()?

plt.scatter() creates individual markers with per-point control over size and color. plt.plot() draws markers that all look the same; they are connected by a line unless you pass a marker-only format such as 'o'. Use scatter() when points need individual styling, and use plot() for line charts or uniform markers. plot() is also faster for large datasets whose markers don't vary in size or color.
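A side-by-side sketch of the two calls on the same data. Note that plot's markersize is a width in points while scatter's s is an area, so markersize=4 roughly matches s=16. Coloring by distance from the origin is an arbitrary choice to show per-point styling.

```python
import matplotlib.pyplot as plt
import numpy as np

np.random.seed(42)
x = np.random.randn(1000)
y = np.random.randn(1000)

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 4), sharex=True, sharey=True)

# plot() with a marker-only format string: unconnected, identical markers
(line,) = ax1.plot(x, y, 'o', markersize=4, alpha=0.5)
ax1.set_title("ax.plot(x, y, 'o')")

# scatter(): one color per point, computed from the data
points = ax2.scatter(x, y, c=np.hypot(x, y), s=16, cmap='viridis')
ax2.set_title('ax.scatter(x, y, c=...)')
```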

How do I handle overlapping points in a scatter plot?

Use alpha (transparency) to reveal density, reduce marker s (size), use plt.hexbin() for density heatmaps, or jitter the points slightly with small random offsets.
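Jittering is most useful for discrete data, where many points share exactly the same coordinates. A sketch; the integer ratings and scores are hypothetical data, and the `jitter` helper is illustrative. The noise width of 0.15 stays below half the spacing between integer values, so jittered points never cross into a neighboring value. Apply jitter only to the plotted positions and keep the original values for any statistics.

```python
import matplotlib.pyplot as plt
import numpy as np

rng = np.random.default_rng(42)
# Hypothetical discrete data: ratings 1-5 for 300 responses, with many exact ties
rating = rng.integers(1, 6, size=300)
score = rating + rng.integers(0, 3, size=300)


def jitter(values, width=0.15, rng=rng):
    """Illustrative helper: add uniform noise in [-width, width] to separate ties."""
    return values + rng.uniform(-width, width, size=len(values))


fig, ax = plt.subplots(figsize=(7, 5))
ax.scatter(jitter(rating), jitter(score), s=15, alpha=0.5)
ax.set_xlabel('Rating (jittered)')
ax.set_ylabel('Score (jittered)')
```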

Conclusion

Matplotlib's plt.scatter() is the standard tool for creating scatter plots in Python. For basic exploration, a simple plt.scatter(x, y) suffices. For publication-quality figures, use distinct colors for categories, a colormap for continuous values, marker size for a third variable, regression lines for trends, and transparency or density plots for dense data. Together, these techniques cover most two-variable relationships you will need to visualize.
