Matplotlib Streudiagramm: Vollständige Anleitung zu plt.scatter()
Updated on
Streudiagramme sind die bevorzugte Visualisierung zur Untersuchung von Beziehungen zwischen zwei numerischen Variablen. Aber effektive Streudiagramme zu erstellen -- solche, die Muster, Cluster und Ausreißer aufdecken, ohne zu einem unübersichtlichen Durcheinander zu werden -- erfordert mehr als einen einfachen plt.scatter()-Aufruf. Sie benötigen Farbzuordnung für Kategorien, Größenkodierung für eine dritte Variable, korrekte Achsenbeschriftungen und den Umgang mit überlappenden Punkten.
Matplotlibs plt.scatter() bewältigt all dies mit einem umfangreichen Parametersatz. Dieser Leitfaden deckt alles ab, von einfachen Streudiagrammen bis hin zu fortgeschrittenen Techniken wie Blasendiagrammen, Regressionslinien und Mehrfach-Streudiagramm-Matrizen.
Einfaches Streudiagramm
import matplotlib.pyplot as plt
import numpy as np
np.random.seed(42)
x = np.random.randn(100)
y = 2 * x + np.random.randn(100) * 0.5
plt.figure(figsize=(8, 6))
plt.scatter(x, y)
plt.xlabel('X-Werte')
plt.ylabel('Y-Werte')
plt.title('Einfaches Streudiagramm')
plt.show()Marker anpassen
import matplotlib.pyplot as plt
import numpy as np
np.random.seed(42)
x = np.random.randn(50)
y = np.random.randn(50)
plt.figure(figsize=(8, 6))
plt.scatter(x, y,
s=100, # Markergröße
c='steelblue', # Farbe
marker='o', # Markerform
alpha=0.7, # Transparenz
edgecolors='black', # Randfarbe
linewidths=0.5, # Randbreite
)
plt.xlabel('X')
plt.ylabel('Y')
plt.title('Angepasstes Streudiagramm')
plt.show()Häufige Markerformen
| Marker | Symbol | Beschreibung |
|---|---|---|
'o' | Kreis | Standard |
's' | Quadrat | |
'^' | Dreieck oben | |
'D' | Diamant | |
'*' | Stern | |
'+' | Plus | |
'x' | Kreuz | |
'.' | Punkt | Klein, für dichte Daten |
Farbe nach Kategorie
import matplotlib.pyplot as plt
import numpy as np
np.random.seed(42)
n = 50
# Drei Kategorien
categories = ['A', 'B', 'C']
colors = ['#e74c3c', '#3498db', '#2ecc71']
plt.figure(figsize=(8, 6))
for cat, color in zip(categories, colors):
x = np.random.randn(n) + (categories.index(cat) * 2)
y = np.random.randn(n) + (categories.index(cat) * 1.5)
plt.scatter(x, y, c=color, label=cat, alpha=0.7, s=60, edgecolors='white')
plt.xlabel('Merkmal 1')
plt.ylabel('Merkmal 2')
plt.title('Streudiagramm nach Kategorie eingefärbt')
plt.legend()
plt.show()Farbzuordnung (Kontinuierliche Variable)
Verwenden Sie den Parameter c mit einem numerischen Array und einer Farbskala, um eine dritte Variable als Farbe zu kodieren:
import matplotlib.pyplot as plt
import numpy as np
np.random.seed(42)
x = np.random.randn(200)
y = np.random.randn(200)
values = x ** 2 + y ** 2 # Abstand vom Ursprung
plt.figure(figsize=(8, 6))
scatter = plt.scatter(x, y, c=values, cmap='viridis', s=50, alpha=0.8)
plt.colorbar(scatter, label='Abstand vom Ursprung')
plt.xlabel('X')
plt.ylabel('Y')
plt.title('Streudiagramm mit Farbzuordnung')
plt.show()Beliebte Farbskalen
| Farbskala | Typ | Am besten für |
|---|---|---|
'viridis' | Sequenziell | Standard, wahrnehmungsmäßig gleichmäßig |
'plasma' | Sequenziell | Hoher Kontrast |
'coolwarm' | Divergierend | Positive/negative Werte |
'RdYlGn' | Divergierend | Gut/schlecht Bereiche |
'Set1' | Qualitativ | Kategoriale Daten |
Größenkodierung (Blasendiagramm)
Kodieren Sie eine dritte Variable als Markergröße, um ein Blasendiagramm zu erstellen:
import matplotlib.pyplot as plt
import numpy as np
np.random.seed(42)
countries = ['US', 'China', 'India', 'Germany', 'Japan', 'UK', 'Brazil', 'France']
gdp = np.array([21.43, 14.34, 2.87, 3.86, 5.08, 2.83, 1.87, 2.72])
population = np.array([331, 1402, 1380, 83, 126, 67, 213, 67])
growth = np.array([2.3, 5.8, 6.5, 1.1, 0.8, 1.4, 1.2, 1.5])
plt.figure(figsize=(10, 7))
scatter = plt.scatter(gdp, growth,
s=population * 2, # Bevölkerung für sichtbare Größen skalieren
c=range(len(countries)),
cmap='tab10',
alpha=0.6,
edgecolors='black',
)
for i, country in enumerate(countries):
plt.annotate(country, (gdp[i], growth[i]),
textcoords="offset points", xytext=(10, 5), fontsize=9)
plt.xlabel('BIP (Billionen USD)')
plt.ylabel('BIP-Wachstumsrate (%)')
plt.title('BIP vs Wachstumsrate (Blasengröße = Bevölkerung)')
plt.show()Regressionslinie hinzufügen
import matplotlib.pyplot as plt
import numpy as np
np.random.seed(42)
x = np.random.randn(100) * 3
y = 1.5 * x + np.random.randn(100) * 2
# Lineare Regression anpassen
coefficients = np.polyfit(x, y, 1)
poly = np.poly1d(coefficients)
x_line = np.linspace(x.min(), x.max(), 100)
plt.figure(figsize=(8, 6))
plt.scatter(x, y, alpha=0.6, s=40, label='Daten')
plt.plot(x_line, poly(x_line), 'r-', linewidth=2,
label=f'y = {coefficients[0]:.2f}x + {coefficients[1]:.2f}')
plt.xlabel('X')
plt.ylabel('Y')
plt.title('Streudiagramm mit Regressionslinie')
plt.legend()
plt.show()Überlappende Punkte behandeln
Wenn sich Punkte stark überlappen, verwenden Sie Transparenz, kleinere Marker oder dichtebasierte Techniken:
import matplotlib.pyplot as plt
import numpy as np
np.random.seed(42)
x = np.random.randn(5000)
y = np.random.randn(5000)
fig, axes = plt.subplots(1, 3, figsize=(15, 5))
# Methode 1: Transparenz
axes[0].scatter(x, y, alpha=0.1, s=10)
axes[0].set_title('Alpha-Transparenz')
# Methode 2: Kleine Marker
axes[1].scatter(x, y, s=1, c='black')
axes[1].set_title('Kleine Marker')
# Methode 3: 2D-Histogramm (Hexbin)
axes[2].hexbin(x, y, gridsize=30, cmap='YlOrRd')
axes[2].set_title('Hexbin-Dichte')
plt.tight_layout()
plt.show()Mehrere Unterdiagramme
import matplotlib.pyplot as plt
import numpy as np
np.random.seed(42)
n = 100
fig, axes = plt.subplots(2, 2, figsize=(10, 10))
# Diagramm 1: Linear
x1 = np.random.randn(n)
axes[0, 0].scatter(x1, 2 * x1 + np.random.randn(n) * 0.5, c='steelblue', s=30)
axes[0, 0].set_title('Lineare Beziehung')
# Diagramm 2: Quadratisch
x2 = np.linspace(-3, 3, n)
axes[0, 1].scatter(x2, x2**2 + np.random.randn(n) * 0.5, c='coral', s=30)
axes[0, 1].set_title('Quadratische Beziehung')
# Diagramm 3: Cluster
for i, c in enumerate(['red', 'blue', 'green']):
cx = np.random.randn(30) + i * 3
cy = np.random.randn(30) + i * 2
axes[1, 0].scatter(cx, cy, c=c, s=30, alpha=0.7)
axes[1, 0].set_title('Geclusterte Daten')
# Diagramm 4: Keine Korrelation
axes[1, 1].scatter(np.random.randn(n), np.random.randn(n), c='purple', s=30, alpha=0.5)
axes[1, 1].set_title('Keine Korrelation')
plt.tight_layout()
plt.show()Interaktive Streudiagramme mit PyGWalker
Für explorative Datenanalyse sind statische Streudiagramme nur der Anfang. PyGWalker (opens in a new tab) verwandelt Ihren pandas DataFrame in eine interaktive Tableau-ähnliche Oberfläche direkt in Jupyter. Sie können Spalten auf Achsen ziehen, Farb- und Größenkodierungen hinzufügen und Daten filtern -- alles ohne zusätzlichen Code:
import pandas as pd
import pygwalker as pyg
df = pd.DataFrame({'x': x, 'y': y, 'category': np.random.choice(['A', 'B', 'C'], len(x))})
walker = pyg.walk(df)plt.scatter() Parameterreferenz
| Parameter | Typ | Beschreibung |
|---|---|---|
x, y | array-ähnlich | Datenpositionen |
s | Skalar oder Array | Markergröße(n) in Punkte^2 |
c | Farbe oder Array | Markerfarbe(n). Array für Farbskala |
marker | str | Markerstil ('o', 's', '^', usw.) |
cmap | str oder Colormap | Farbskala wenn c numerisch ist |
alpha | float (0-1) | Transparenz |
edgecolors | Farbe | Marker-Randfarbe |
linewidths | float | Marker-Randbreite |
vmin, vmax | float | Farbskala-Bereichsgrenzen |
label | str | Legende-Beschriftung |
FAQ
Wie erstelle ich ein Streudiagramm in Matplotlib?
Verwenden Sie plt.scatter(x, y), wobei x und y Arrays gleicher Länge sind. Fügen Sie plt.xlabel(), plt.ylabel() und plt.title() für Beschriftungen hinzu. Rufen Sie plt.show() auf, um das Diagramm anzuzeigen.
Wie färbe ich Streudiagramm-Punkte nach Kategorie ein?
Iterieren Sie über Kategorien und rufen Sie plt.scatter() für jede mit einem anderen c-Parameter und einem label auf. Rufen Sie dann plt.legend() auf, um die Legende anzuzeigen. Alternativ übergeben Sie ein numerisches Array an c mit einer Farbskala für kontinuierliche Einfärbung.
Wie füge ich eine Trendlinie zu einem Streudiagramm hinzu?
Verwenden Sie np.polyfit(x, y, grad), um ein Polynom anzupassen, erstellen Sie ein np.poly1d() aus den Koeffizienten und zeichnen Sie es mit plt.plot(). Für grad=1 ergibt dies eine lineare Regressionslinie.
Was ist der Unterschied zwischen plt.scatter() und plt.plot()?
plt.scatter() erstellt einzelne Marker mit punkt-individueller Kontrolle über Größe, Farbe und Form. plt.plot() mit einem Markerstil erstellt verbundene Punkte mit einheitlichem Erscheinungsbild. Verwenden Sie scatter(), wenn Punkte individuelles Styling benötigen; verwenden Sie plot() für Liniendiagramme oder einheitliche Marker.
Wie gehe ich mit überlappenden Punkten in einem Streudiagramm um?
Verwenden Sie alpha (Transparenz), um die Dichte sichtbar zu machen, reduzieren Sie die Markergröße s, verwenden Sie plt.hexbin() für Dichte-Heatmaps oder verschieben Sie die Punkte leicht mit kleinen zufälligen Offsets.
Fazit
Matplotlibs plt.scatter() ist das Standardwerkzeug zum Erstellen von Streudiagrammen in Python. Für die grundlegende Exploration reicht ein einfaches plt.scatter(x, y). Für publikationsreife Abbildungen nutzen Sie Farbzuordnung für Kategorien, Größenkodierung für eine dritte Variable, Regressionslinien für Trends und Transparenz für dichte Daten. Beherrschen Sie diese Techniken und Sie können jede bivariate Beziehung effektiv visualisieren.