Sklearn train_test_split: Split Data for Machine Learning in Python
Training a machine learning model on the same data you use to evaluate it gives misleadingly high accuracy. The model memorizes the training data instead of learning generalizable patterns -- a problem called overfitting. You need a separate test set that the model never sees during training to get honest performance metrics.
Scikit-learn's train_test_split() is the standard way to divide datasets into training and test portions. It handles arrays, DataFrames, and sparse matrices, with options for stratification, reproducibility, and custom split ratios.
Basic Usage
from sklearn.model_selection import train_test_split
import numpy as np
# Sample data: 100 samples, 5 features
X = np.random.randn(100, 5)
y = np.random.randint(0, 2, 100) # Binary labels
# Split: 80% train, 20% test
X_train, X_test, y_train, y_test = train_test_split(
X, y, test_size=0.2, random_state=42
)
print(f"Training set: {X_train.shape[0]} samples")
print(f"Test set: {X_test.shape[0]} samples")
# Training set: 80 samples
# Test set: 20 samples
Key Parameters
| Parameter | Default | Description |
|---|---|---|
| test_size | 0.25 | Fraction (0.0-1.0) or absolute number of test samples |
| train_size | None | Fraction or number of training samples (complement of test_size) |
| random_state | None | Seed for reproducible splits |
| shuffle | True | Whether to shuffle data before splitting |
| stratify | None | Array (usually y) to use for stratified splitting |
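Two of these parameters are worth a quick demonstration: test_size also accepts an absolute sample count, and shuffle=False keeps rows in their original order, which matters for time-ordered data where shuffling would leak future information into the training set. A minimal sketch:

```python
from sklearn.model_selection import train_test_split
import numpy as np

X = np.arange(20).reshape(10, 2)
y = np.arange(10)

# test_size as an absolute count: exactly 3 test samples
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=3, random_state=0)
print(len(X_test))  # 3

# shuffle=False: no shuffling, the test set is simply the last rows
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, shuffle=False)
print(y_test)  # [8 9] -- the last 20% of rows, in original order
```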
With Pandas DataFrames
from sklearn.model_selection import train_test_split
import pandas as pd
df = pd.DataFrame({
'age': [25, 30, 35, 40, 45, 50, 55, 60, 28, 33],
'income': [40, 50, 60, 70, 80, 90, 100, 110, 45, 55],
'purchased': [0, 0, 1, 1, 1, 1, 1, 1, 0, 0]
})
X = df[['age', 'income']]
y = df['purchased']
X_train, X_test, y_train, y_test = train_test_split(
X, y, test_size=0.3, random_state=42
)
print(f"Train shape: {X_train.shape}") # (7, 2)
print(f"Test shape: {X_test.shape}") # (3, 2)
print(type(X_train)) # <class 'pandas.core.frame.DataFrame'>
DataFrames stay as DataFrames after splitting -- column names and indices are preserved.
random_state: Reproducible Splits
Without random_state, you get a different split each time:
from sklearn.model_selection import train_test_split
import numpy as np
X = np.arange(10).reshape(5, 2)
y = np.array([0, 0, 1, 1, 1])
# Without random_state: different split each run
_, X_test1, _, _ = train_test_split(X, y, test_size=0.4)
_, X_test2, _, _ = train_test_split(X, y, test_size=0.4)
print(np.array_equal(X_test1, X_test2)) # Likely False
# With random_state: same split every time
_, X_test3, _, _ = train_test_split(X, y, test_size=0.4, random_state=42)
_, X_test4, _, _ = train_test_split(X, y, test_size=0.4, random_state=42)
print(np.array_equal(X_test3, X_test4)) # True
Always set random_state for reproducibility. Use any integer -- 42 is conventional, but the specific number does not matter.
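One detail the examples above don't show: train_test_split accepts any number of same-length arrays and splits them all with the same shuffled indices. This is handy when extra per-sample data, such as weights or IDs, must stay aligned with X and y. A small sketch (the weights array here is invented for illustration):

```python
from sklearn.model_selection import train_test_split
import numpy as np

np.random.seed(0)
X = np.random.randn(100, 3)
y = np.random.randint(0, 2, 100)
weights = np.random.rand(100)  # per-sample weights to carry along

# All three arrays are split with the same shuffled indices,
# so row i of X_train still matches y_train[i] and w_train[i]
(X_train, X_test,
 y_train, y_test,
 w_train, w_test) = train_test_split(X, y, weights, test_size=0.2, random_state=42)

print(X_train.shape, y_train.shape, w_train.shape)  # (80, 3) (80,) (80,)
```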
Stratified Splitting
For imbalanced datasets, a random split might put most minority samples in one set. Stratification ensures both sets have the same class proportions:
from sklearn.model_selection import train_test_split
import numpy as np
from collections import Counter
# Imbalanced dataset: 90% class 0, 10% class 1
np.random.seed(42)
X = np.random.randn(200, 4)
y = np.array([0] * 180 + [1] * 20)
# Without stratification
_, _, y_train_bad, y_test_bad = train_test_split(
X, y, test_size=0.2, random_state=42
)
print("Without stratify:")
print(f" Train: {Counter(y_train_bad)}")
print(f" Test: {Counter(y_test_bad)}")
# With stratification
_, _, y_train_good, y_test_good = train_test_split(
X, y, test_size=0.2, random_state=42, stratify=y
)
print("\nWith stratify=y:")
print(f" Train: {Counter(y_train_good)}")
print(f" Test: {Counter(y_test_good)}")
# Train: Counter({0: 144, 1: 16}) -- 10% class 1
# Test: Counter({0: 36, 1: 4}) -- 10% class 1
When to Use Stratification
| Scenario | Use Stratify? |
|---|---|
| Balanced classes (50/50) | Optional |
| Imbalanced classes (90/10) | Yes |
| Multi-class classification | Yes |
| Regression (continuous target) | No (not supported) |
| Small datasets (< 100 samples) | Yes (prevents empty classes) |
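As the table notes, stratify does not accept a continuous target directly. A common workaround -- not a built-in sklearn feature -- is to bin the target (for example with pd.qcut) and stratify on the bins, so each split covers the full range of target values. A sketch:

```python
from collections import Counter

from sklearn.model_selection import train_test_split
import numpy as np
import pandas as pd

np.random.seed(0)
X = np.random.randn(200, 4)
y = np.random.randn(200) * 10 + 50  # continuous target

# Bin the continuous target into quartiles (integer codes 0-3),
# then stratify on the bin labels rather than on y itself
bins = pd.qcut(y, q=4, labels=False)
X_train, X_test, y_train, y_test, b_train, b_test = train_test_split(
    X, y, bins, test_size=0.25, random_state=0, stratify=bins
)

# Each quartile contributes roughly equally to the test set
print(Counter(b_test))
```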
Train/Validation/Test Split
For hyperparameter tuning, you need three sets: train, validation, and test. Apply train_test_split twice:
from sklearn.model_selection import train_test_split
import numpy as np
X = np.random.randn(1000, 10)
y = np.random.randint(0, 3, 1000)
# First split: 80% train+val, 20% test
X_temp, X_test, y_temp, y_test = train_test_split(
X, y, test_size=0.2, random_state=42, stratify=y
)
# Second split: 75% train, 25% val (of the 80% = 60/20/20 overall)
X_train, X_val, y_train, y_val = train_test_split(
X_temp, y_temp, test_size=0.25, random_state=42, stratify=y_temp
)
print(f"Train: {X_train.shape[0]} samples (60%)")
print(f"Validation: {X_val.shape[0]} samples (20%)")
print(f"Test: {X_test.shape[0]} samples (20%)")
Complete ML Pipeline Example
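The two-step recipe above can be wrapped in a small helper that takes the overall fractions directly and computes the rescaled second split. The train_val_test_split function here is a hypothetical convenience wrapper, not part of sklearn:

```python
from sklearn.model_selection import train_test_split
import numpy as np

def train_val_test_split(X, y, val_size=0.2, test_size=0.2,
                         random_state=None, stratify=None):
    """Split into train/val/test in one call (hypothetical helper)."""
    # First split: carve off the test set
    X_temp, X_test, y_temp, y_test = train_test_split(
        X, y, test_size=test_size, random_state=random_state, stratify=stratify
    )
    # Rescale val_size relative to what remains after removing the test set
    rel_val = val_size / (1.0 - test_size)
    strat_temp = y_temp if stratify is not None else None
    X_train, X_val, y_train, y_val = train_test_split(
        X_temp, y_temp, test_size=rel_val,
        random_state=random_state, stratify=strat_temp
    )
    return X_train, X_val, X_test, y_train, y_val, y_test

X = np.random.randn(1000, 10)
y = np.random.randint(0, 3, 1000)
X_train, X_val, X_test, y_train, y_val, y_test = train_val_test_split(
    X, y, val_size=0.2, test_size=0.2, random_state=42, stratify=y
)
print(len(X_train), len(X_val), len(X_test))  # 600 200 200
```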
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score, classification_report
import numpy as np
# Generate sample data
np.random.seed(42)
X = np.random.randn(500, 8)
y = (X[:, 0] + X[:, 1] > 0).astype(int)
# Split data
X_train, X_test, y_train, y_test = train_test_split(
X, y, test_size=0.2, random_state=42, stratify=y
)
# Scale features (fit on train only!)
scaler = StandardScaler()
X_train_scaled = scaler.fit_transform(X_train)
X_test_scaled = scaler.transform(X_test) # Use train statistics
# Train model
model = RandomForestClassifier(n_estimators=100, random_state=42)
model.fit(X_train_scaled, y_train)
# Evaluate
y_pred = model.predict(X_test_scaled)
print(f"Accuracy: {accuracy_score(y_test, y_pred):.4f}")
print(classification_report(y_test, y_pred))
Critical rule: Fit the scaler (and any preprocessing) on the training set only. Apply the same transformation to the test set. Fitting on the full dataset causes data leakage.
Common Split Ratios
| Split | Train | Test | When to Use |
|---|---|---|---|
| 80/20 | 80% | 20% | Default choice, most datasets |
| 70/30 | 70% | 30% | Small datasets, need larger test set |
| 90/10 | 90% | 10% | Large datasets (10k+ samples) |
| 60/20/20 | 60% | 20% val + 20% test | When tuning hyperparameters |
Exploring Model Results
After splitting and training your model, PyGWalker lets you interactively explore predictions vs. actuals, feature distributions across train/test sets, and error patterns in Jupyter:
import pandas as pd
import pygwalker as pyg
results = pd.DataFrame({
'actual': y_test,
'predicted': y_pred,
'correct': y_test == y_pred
})
walker = pyg.walk(results)
FAQ
What does train_test_split do in sklearn?
train_test_split() randomly divides arrays or DataFrames into two subsets: one for training and one for testing. It ensures that model evaluation uses data the model hasn't seen during training, giving honest performance estimates.
What is the best train/test split ratio?
80/20 is the standard default. Use 70/30 for smaller datasets where you need a reliable test set, and 90/10 for large datasets (10k+ samples) where 10% is still substantial. For hyperparameter tuning, use 60/20/20 (train/val/test).
What does random_state do in train_test_split?
random_state sets the random seed for the shuffling that happens before splitting. Using the same random_state value produces the same split every time, making your results reproducible. Any integer works.
When should I use stratify in train_test_split?
Use stratify=y when your target variable is imbalanced (e.g., 95% negative, 5% positive) or when you have a small dataset. Stratification ensures both train and test sets have the same proportion of each class.
How do I split data into train, validation, and test sets?
Call train_test_split twice. First split into train+val and test (e.g., 80/20). Then split train+val into train and val (e.g., 75/25 of the 80%, giving 60/20/20 overall). Alternatively, use sklearn.model_selection.KFold for cross-validation.
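The cross-validation alternative mentioned in this answer is easiest to use through cross_val_score, which fits and scores the model once per fold so every sample serves as test data exactly once. A minimal sketch on a made-up dataset:

```python
from sklearn.model_selection import cross_val_score
from sklearn.linear_model import LogisticRegression
import numpy as np

np.random.seed(42)
X = np.random.randn(300, 5)
y = (X[:, 0] > 0).astype(int)  # simple separable labels for illustration

# 5-fold cross-validation: the data is split into 5 folds,
# and each fold serves as the held-out test set once
model = LogisticRegression()
scores = cross_val_score(model, X, y, cv=5)
print(scores.shape)  # (5,) -- one accuracy score per fold
print(f"Mean accuracy: {scores.mean():.3f}")
```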
Conclusion
train_test_split() is the foundation of every machine learning workflow in Python. Always use it before training -- never evaluate on training data. Set random_state for reproducibility, use stratify=y for imbalanced classes, and remember to fit preprocessing steps on the training set only. For model selection, split three ways (train/validation/test) or use cross-validation for more robust estimates.