Skip to content

Matplotlib Streudiagramm: Vollständige Anleitung zu plt.scatter()

Updated on

Streudiagramme sind die bevorzugte Visualisierung zur Untersuchung von Beziehungen zwischen zwei numerischen Variablen. Aber effektive Streudiagramme zu erstellen -- solche, die Muster, Cluster und Ausreißer aufdecken, ohne zu einem unübersichtlichen Durcheinander zu werden -- erfordert mehr als einen einfachen plt.scatter()-Aufruf. Sie benötigen Farbzuordnung für Kategorien, Größenkodierung für eine dritte Variable, korrekte Achsenbeschriftungen und den Umgang mit überlappenden Punkten.

Matplotlibs plt.scatter() bewältigt all dies mit einem umfangreichen Parametersatz. Dieser Leitfaden deckt alles ab, von einfachen Streudiagrammen bis hin zu fortgeschrittenen Techniken wie Blasendiagrammen, Regressionslinien und Mehrfach-Streudiagramm-Matrizen.

📚

Einfaches Streudiagramm

import matplotlib.pyplot as plt
import numpy as np
 
np.random.seed(42)
x = np.random.randn(100)
y = 2 * x + np.random.randn(100) * 0.5
 
plt.figure(figsize=(8, 6))
plt.scatter(x, y)
plt.xlabel('X-Werte')
plt.ylabel('Y-Werte')
plt.title('Einfaches Streudiagramm')
plt.show()

Marker anpassen

import matplotlib.pyplot as plt
import numpy as np
 
np.random.seed(42)
x = np.random.randn(50)
y = np.random.randn(50)
 
plt.figure(figsize=(8, 6))
plt.scatter(x, y,
    s=100,              # Markergröße
    c='steelblue',      # Farbe
    marker='o',         # Markerform
    alpha=0.7,          # Transparenz
    edgecolors='black', # Randfarbe
    linewidths=0.5,     # Randbreite
)
plt.xlabel('X')
plt.ylabel('Y')
plt.title('Angepasstes Streudiagramm')
plt.show()

Häufige Markerformen

MarkerSymbolBeschreibung
'o'KreisStandard
's'Quadrat
'^'Dreieck oben
'D'Diamant
'*'Stern
'+'Plus
'x'Kreuz
'.'PunktKlein, für dichte Daten

Farbe nach Kategorie

import matplotlib.pyplot as plt
import numpy as np
 
np.random.seed(42)
n = 50
 
# Drei Kategorien
categories = ['A', 'B', 'C']
colors = ['#e74c3c', '#3498db', '#2ecc71']
 
plt.figure(figsize=(8, 6))
for cat, color in zip(categories, colors):
    x = np.random.randn(n) + (categories.index(cat) * 2)
    y = np.random.randn(n) + (categories.index(cat) * 1.5)
    plt.scatter(x, y, c=color, label=cat, alpha=0.7, s=60, edgecolors='white')
 
plt.xlabel('Merkmal 1')
plt.ylabel('Merkmal 2')
plt.title('Streudiagramm nach Kategorie eingefärbt')
plt.legend()
plt.show()

Farbzuordnung (Kontinuierliche Variable)

Verwenden Sie den Parameter c mit einem numerischen Array und einer Farbskala, um eine dritte Variable als Farbe zu kodieren:

import matplotlib.pyplot as plt
import numpy as np
 
np.random.seed(42)
x = np.random.randn(200)
y = np.random.randn(200)
values = x ** 2 + y ** 2  # Abstand vom Ursprung
 
plt.figure(figsize=(8, 6))
scatter = plt.scatter(x, y, c=values, cmap='viridis', s=50, alpha=0.8)
plt.colorbar(scatter, label='Abstand vom Ursprung')
plt.xlabel('X')
plt.ylabel('Y')
plt.title('Streudiagramm mit Farbzuordnung')
plt.show()

Beliebte Farbskalen

FarbskalaTypAm besten für
'viridis'SequenziellStandard, wahrnehmungsmäßig gleichmäßig
'plasma'SequenziellHoher Kontrast
'coolwarm'DivergierendPositive/negative Werte
'RdYlGn'DivergierendGut/schlecht Bereiche
'Set1'QualitativKategoriale Daten

Größenkodierung (Blasendiagramm)

Kodieren Sie eine dritte Variable als Markergröße, um ein Blasendiagramm zu erstellen:

import matplotlib.pyplot as plt
import numpy as np
 
np.random.seed(42)
countries = ['US', 'China', 'India', 'Germany', 'Japan', 'UK', 'Brazil', 'France']
gdp = np.array([21.43, 14.34, 2.87, 3.86, 5.08, 2.83, 1.87, 2.72])
population = np.array([331, 1402, 1380, 83, 126, 67, 213, 67])
growth = np.array([2.3, 5.8, 6.5, 1.1, 0.8, 1.4, 1.2, 1.5])
 
plt.figure(figsize=(10, 7))
scatter = plt.scatter(gdp, growth,
    s=population * 2,   # Bevölkerung für sichtbare Größen skalieren
    c=range(len(countries)),
    cmap='tab10',
    alpha=0.6,
    edgecolors='black',
)
 
for i, country in enumerate(countries):
    plt.annotate(country, (gdp[i], growth[i]),
        textcoords="offset points", xytext=(10, 5), fontsize=9)
 
plt.xlabel('BIP (Billionen USD)')
plt.ylabel('BIP-Wachstumsrate (%)')
plt.title('BIP vs Wachstumsrate (Blasengröße = Bevölkerung)')
plt.show()

Regressionslinie hinzufügen

import matplotlib.pyplot as plt
import numpy as np
 
np.random.seed(42)
x = np.random.randn(100) * 3
y = 1.5 * x + np.random.randn(100) * 2
 
# Lineare Regression anpassen
coefficients = np.polyfit(x, y, 1)
poly = np.poly1d(coefficients)
x_line = np.linspace(x.min(), x.max(), 100)
 
plt.figure(figsize=(8, 6))
plt.scatter(x, y, alpha=0.6, s=40, label='Daten')
plt.plot(x_line, poly(x_line), 'r-', linewidth=2,
    label=f'y = {coefficients[0]:.2f}x + {coefficients[1]:.2f}')
 
plt.xlabel('X')
plt.ylabel('Y')
plt.title('Streudiagramm mit Regressionslinie')
plt.legend()
plt.show()

Überlappende Punkte behandeln

Wenn sich Punkte stark überlappen, verwenden Sie Transparenz, kleinere Marker oder dichtebasierte Techniken:

import matplotlib.pyplot as plt
import numpy as np
 
np.random.seed(42)
x = np.random.randn(5000)
y = np.random.randn(5000)
 
fig, axes = plt.subplots(1, 3, figsize=(15, 5))
 
# Methode 1: Transparenz
axes[0].scatter(x, y, alpha=0.1, s=10)
axes[0].set_title('Alpha-Transparenz')
 
# Methode 2: Kleine Marker
axes[1].scatter(x, y, s=1, c='black')
axes[1].set_title('Kleine Marker')
 
# Methode 3: 2D-Histogramm (Hexbin)
axes[2].hexbin(x, y, gridsize=30, cmap='YlOrRd')
axes[2].set_title('Hexbin-Dichte')
 
plt.tight_layout()
plt.show()

Mehrere Unterdiagramme

import matplotlib.pyplot as plt
import numpy as np
 
np.random.seed(42)
n = 100
 
fig, axes = plt.subplots(2, 2, figsize=(10, 10))
 
# Diagramm 1: Linear
x1 = np.random.randn(n)
axes[0, 0].scatter(x1, 2 * x1 + np.random.randn(n) * 0.5, c='steelblue', s=30)
axes[0, 0].set_title('Lineare Beziehung')
 
# Diagramm 2: Quadratisch
x2 = np.linspace(-3, 3, n)
axes[0, 1].scatter(x2, x2**2 + np.random.randn(n) * 0.5, c='coral', s=30)
axes[0, 1].set_title('Quadratische Beziehung')
 
# Diagramm 3: Cluster
for i, c in enumerate(['red', 'blue', 'green']):
    cx = np.random.randn(30) + i * 3
    cy = np.random.randn(30) + i * 2
    axes[1, 0].scatter(cx, cy, c=c, s=30, alpha=0.7)
axes[1, 0].set_title('Geclusterte Daten')
 
# Diagramm 4: Keine Korrelation
axes[1, 1].scatter(np.random.randn(n), np.random.randn(n), c='purple', s=30, alpha=0.5)
axes[1, 1].set_title('Keine Korrelation')
 
plt.tight_layout()
plt.show()

Interaktive Streudiagramme mit PyGWalker

Für explorative Datenanalyse sind statische Streudiagramme nur der Anfang. PyGWalker (opens in a new tab) verwandelt Ihren pandas DataFrame in eine interaktive Tableau-ähnliche Oberfläche direkt in Jupyter. Sie können Spalten auf Achsen ziehen, Farb- und Größenkodierungen hinzufügen und Daten filtern -- alles ohne zusätzlichen Code:

import pandas as pd
import pygwalker as pyg
 
df = pd.DataFrame({'x': x, 'y': y, 'category': np.random.choice(['A', 'B', 'C'], len(x))})
walker = pyg.walk(df)

plt.scatter() Parameterreferenz

ParameterTypBeschreibung
x, yarray-ähnlichDatenpositionen
sSkalar oder ArrayMarkergröße(n) in Punkte^2
cFarbe oder ArrayMarkerfarbe(n). Array für Farbskala
markerstrMarkerstil ('o', 's', '^', usw.)
cmapstr oder ColormapFarbskala wenn c numerisch ist
alphafloat (0-1)Transparenz
edgecolorsFarbeMarker-Randfarbe
linewidthsfloatMarker-Randbreite
vmin, vmaxfloatFarbskala-Bereichsgrenzen
labelstrLegende-Beschriftung

FAQ

Wie erstelle ich ein Streudiagramm in Matplotlib?

Verwenden Sie plt.scatter(x, y), wobei x und y Arrays gleicher Länge sind. Fügen Sie plt.xlabel(), plt.ylabel() und plt.title() für Beschriftungen hinzu. Rufen Sie plt.show() auf, um das Diagramm anzuzeigen.

Wie färbe ich Streudiagramm-Punkte nach Kategorie ein?

Iterieren Sie über Kategorien und rufen Sie plt.scatter() für jede mit einem anderen c-Parameter und einem label auf. Rufen Sie dann plt.legend() auf, um die Legende anzuzeigen. Alternativ übergeben Sie ein numerisches Array an c mit einer Farbskala für kontinuierliche Einfärbung.

Wie füge ich eine Trendlinie zu einem Streudiagramm hinzu?

Verwenden Sie np.polyfit(x, y, grad), um ein Polynom anzupassen, erstellen Sie ein np.poly1d() aus den Koeffizienten und zeichnen Sie es mit plt.plot(). Für grad=1 ergibt dies eine lineare Regressionslinie.

Was ist der Unterschied zwischen plt.scatter() und plt.plot()?

plt.scatter() erstellt einzelne Marker mit punkt-individueller Kontrolle über Größe, Farbe und Form. plt.plot() mit einem Markerstil erstellt verbundene Punkte mit einheitlichem Erscheinungsbild. Verwenden Sie scatter(), wenn Punkte individuelles Styling benötigen; verwenden Sie plot() für Liniendiagramme oder einheitliche Marker.

Wie gehe ich mit überlappenden Punkten in einem Streudiagramm um?

Verwenden Sie alpha (Transparenz), um die Dichte sichtbar zu machen, reduzieren Sie die Markergröße s, verwenden Sie plt.hexbin() für Dichte-Heatmaps oder verschieben Sie die Punkte leicht mit kleinen zufälligen Offsets.

Fazit

Matplotlibs plt.scatter() ist das Standardwerkzeug zum Erstellen von Streudiagrammen in Python. Für die grundlegende Exploration reicht ein einfaches plt.scatter(x, y). Für publikationsreife Abbildungen nutzen Sie Farbzuordnung für Kategorien, Größenkodierung für eine dritte Variable, Regressionslinien für Trends und Transparenz für dichte Daten. Beherrschen Sie diese Techniken und Sie können jede bivariate Beziehung effektiv visualisieren.

📚