Sklearn train_test_split: Split Data for Machine Learning in Python

Training a machine learning model on the same data you use to evaluate it gives misleadingly high accuracy. The model memorizes the training data instead of learning generalizable patterns -- a problem called overfitting. You need a separate test set that the model never sees during training to get honest performance metrics.

Scikit-learn's train_test_split() is the standard way to divide datasets into training and test portions. It handles arrays, DataFrames, and sparse matrices, with options for stratification, reproducibility, and custom split ratios.

Basic Usage

from sklearn.model_selection import train_test_split
import numpy as np
 
# Sample data: 100 samples, 5 features
X = np.random.randn(100, 5)
y = np.random.randint(0, 2, 100)  # Binary labels
 
# Split: 80% train, 20% test
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42
)
 
print(f"Training set: {X_train.shape[0]} samples")
print(f"Test set:     {X_test.shape[0]} samples")
# Training set: 80 samples
# Test set:     20 samples

Key Parameters

Parameter      Default  Description
test_size      None     Fraction (0.0-1.0) or absolute number of test samples; defaults to 0.25 when train_size is also None
train_size     None     Fraction or number of training samples; when None, the complement of test_size
random_state   None     Seed for reproducible splits
shuffle        True     Whether to shuffle data before splitting
stratify       None     Array of labels to use for stratified splitting
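Two of these options are worth seeing in action. As a quick sketch: test_size also accepts an absolute sample count, and shuffle=False keeps the original row order (useful for time-ordered data, though it cannot be combined with stratify):

```python
from sklearn.model_selection import train_test_split
import numpy as np

X = np.arange(20).reshape(10, 2)
y = np.arange(10)

# test_size as an absolute count: exactly 3 test samples
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=3, random_state=0
)
print(X_test.shape)  # (3, 2)

# shuffle=False keeps the original order -- the test set is the tail
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=3, shuffle=False
)
print(y_test)  # [7 8 9]
```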

With Pandas DataFrames

from sklearn.model_selection import train_test_split
import pandas as pd
 
df = pd.DataFrame({
    'age': [25, 30, 35, 40, 45, 50, 55, 60, 28, 33],
    'income': [40, 50, 60, 70, 80, 90, 100, 110, 45, 55],
    'purchased': [0, 0, 1, 1, 1, 1, 1, 1, 0, 0]
})
 
X = df[['age', 'income']]
y = df['purchased']
 
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.3, random_state=42
)
 
print(f"Train shape: {X_train.shape}")  # (7, 2)
print(f"Test shape:  {X_test.shape}")   # (3, 2)
print(type(X_train))  # <class 'pandas.core.frame.DataFrame'>

DataFrames stay as DataFrames after splitting -- column names and indices are preserved.
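A small sketch of what index preservation means in practice: the row labels of the split features and labels still line up, so you can join them back together later:

```python
from sklearn.model_selection import train_test_split
import pandas as pd

df = pd.DataFrame({'age': [25, 30, 35, 40], 'income': [40, 50, 60, 70]})
y = pd.Series([0, 0, 1, 1], name='purchased')

X_train, X_test, y_train, y_test = train_test_split(
    df, y, test_size=0.5, random_state=0
)

# Features and labels carry matching row labels after the split
print(X_train.index.equals(y_train.index))  # True
```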

random_state: Reproducible Splits

Without random_state, you get a different split each time:

from sklearn.model_selection import train_test_split
import numpy as np
 
X = np.arange(10).reshape(5, 2)
y = np.array([0, 0, 1, 1, 1])
 
# Without random_state: different split each run
_, X_test1, _, _ = train_test_split(X, y, test_size=0.4)
_, X_test2, _, _ = train_test_split(X, y, test_size=0.4)
print(np.array_equal(X_test1, X_test2))  # Likely False
 
# With random_state: same split every time
_, X_test3, _, _ = train_test_split(X, y, test_size=0.4, random_state=42)
_, X_test4, _, _ = train_test_split(X, y, test_size=0.4, random_state=42)
print(np.array_equal(X_test3, X_test4))  # True

Always set random_state for reproducibility. Use any integer -- 42 is conventional, but the specific number does not matter.

Stratified Splitting

For imbalanced datasets, a random split might put most minority samples in one set. Stratification ensures both sets have the same class proportions:

from sklearn.model_selection import train_test_split
import numpy as np
from collections import Counter
 
# Imbalanced dataset: 90% class 0, 10% class 1
np.random.seed(42)
X = np.random.randn(200, 4)
y = np.array([0] * 180 + [1] * 20)
 
# Without stratification
_, _, y_train_bad, y_test_bad = train_test_split(
    X, y, test_size=0.2, random_state=42
)
print("Without stratify:")
print(f"  Train: {Counter(y_train_bad)}")
print(f"  Test:  {Counter(y_test_bad)}")
 
# With stratification
_, _, y_train_good, y_test_good = train_test_split(
    X, y, test_size=0.2, random_state=42, stratify=y
)
print("\nWith stratify=y:")
print(f"  Train: {Counter(y_train_good)}")
print(f"  Test:  {Counter(y_test_good)}")
# Train: Counter({0: 144, 1: 16})  -- 10% class 1
# Test:  Counter({0: 36, 1: 4})    -- 10% class 1

When to Use Stratification

Scenario                         Use Stratify?
Balanced classes (50/50)         Optional
Imbalanced classes (90/10)       Yes
Multi-class classification       Yes
Regression (continuous target)   No (not supported directly)
Small datasets (< 100 samples)   Yes (keeps every class present in both sets)
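Stratification requires discrete labels, but a common workaround for regression is to bin the continuous target (for example with pandas' qcut) and stratify on the bins. A sketch:

```python
from sklearn.model_selection import train_test_split
import numpy as np
import pandas as pd

np.random.seed(42)
X = np.random.randn(200, 3)
y = np.random.randn(200) * 10 + 50  # continuous target

# Bin the target into quartiles, then stratify on the bin labels
y_bins = pd.qcut(y, q=4, labels=False)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.25, random_state=42, stratify=y_bins
)

# Each quartile of y contributes proportionally to train and test
print(len(y_train), len(y_test))  # 150 50
```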

Train/Validation/Test Split

For hyperparameter tuning, you need three sets: train, validation, and test. Apply train_test_split twice:

from sklearn.model_selection import train_test_split
import numpy as np
 
X = np.random.randn(1000, 10)
y = np.random.randint(0, 3, 1000)
 
# First split: 80% train+val, 20% test
X_temp, X_test, y_temp, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42, stratify=y
)
 
# Second split: 75% train, 25% val (of the 80% = 60/20/20 overall)
X_train, X_val, y_train, y_val = train_test_split(
    X_temp, y_temp, test_size=0.25, random_state=42, stratify=y_temp
)
 
print(f"Train:      {X_train.shape[0]} samples (60%)")
print(f"Validation: {X_val.shape[0]} samples (20%)")
print(f"Test:       {X_test.shape[0]} samples (20%)")

Complete ML Pipeline Example

from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score, classification_report
import numpy as np
 
# Generate sample data
np.random.seed(42)
X = np.random.randn(500, 8)
y = (X[:, 0] + X[:, 1] > 0).astype(int)
 
# Split data
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42, stratify=y
)
 
# Scale features (fit on train only!)
scaler = StandardScaler()
X_train_scaled = scaler.fit_transform(X_train)
X_test_scaled = scaler.transform(X_test)  # Use train statistics
 
# Train model
model = RandomForestClassifier(n_estimators=100, random_state=42)
model.fit(X_train_scaled, y_train)
 
# Evaluate
y_pred = model.predict(X_test_scaled)
print(f"Accuracy: {accuracy_score(y_test, y_pred):.4f}")
print(classification_report(y_test, y_pred))

Critical rule: Fit the scaler (and any preprocessing) on the training set only. Apply the same transformation to the test set. Fitting on the full dataset causes data leakage.
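One way to make this rule hard to violate is to wrap preprocessing and model in a scikit-learn Pipeline, which only ever fits its steps on the data passed to fit(). A minimal sketch (LogisticRegression here is purely for illustration):

```python
from sklearn.model_selection import train_test_split
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.linear_model import LogisticRegression
import numpy as np

np.random.seed(42)
X = np.random.randn(500, 8)
y = (X[:, 0] + X[:, 1] > 0).astype(int)

X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42, stratify=y
)

# The pipeline fits the scaler on the training data only, then reuses
# those statistics at predict time -- leakage is prevented by construction
pipe = make_pipeline(StandardScaler(), LogisticRegression())
pipe.fit(X_train, y_train)
print(f"Test accuracy: {pipe.score(X_test, y_test):.4f}")
```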

Common Split Ratios

Split      Train  Test                When to Use
80/20      80%    20%                 Default choice for most datasets
70/30      70%    30%                 Small datasets that need a larger test set
90/10      90%    10%                 Large datasets (10k+ samples)
60/20/20   60%    20% val + 20% test  Hyperparameter tuning

Exploring Model Results

After splitting and training your model, PyGWalker lets you interactively explore predictions vs. actuals, feature distributions across train/test sets, and error patterns in Jupyter:

import pandas as pd
import pygwalker as pyg
 
results = pd.DataFrame({
    'actual': y_test,
    'predicted': y_pred,
    'correct': y_test == y_pred
})
walker = pyg.walk(results)

FAQ

What does train_test_split do in sklearn?

train_test_split() randomly divides arrays or DataFrames into two subsets: one for training and one for testing. It ensures that model evaluation uses data the model hasn't seen during training, giving honest performance estimates.

What is the best train/test split ratio?

80/20 is the standard default. Use 70/30 for smaller datasets where you need a reliable test set, and 90/10 for large datasets (10k+ samples) where 10% is still substantial. For hyperparameter tuning, use 60/20/20 (train/val/test).

What does random_state do in train_test_split?

random_state sets the random seed for the shuffling that happens before splitting. Using the same random_state value produces the same split every time, making your results reproducible. Any integer works.

When should I use stratify in train_test_split?

Use stratify=y when your target variable is imbalanced (e.g., 95% negative, 5% positive) or when you have a small dataset. Stratification ensures both train and test sets have the same proportion of each class.

How do I split data into train, validation, and test sets?

Call train_test_split twice. First split into train+val and test (e.g., 80/20). Then split train+val into train and val (e.g., 75/25 of the 80%, giving 60/20/20 overall). Alternatively, use sklearn.model_selection.KFold for cross-validation.
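As a sketch of the cross-validation alternative, cross_val_score with a KFold splitter scores the model on each of k held-out folds, so every sample is used for validation exactly once:

```python
from sklearn.model_selection import cross_val_score, KFold
from sklearn.linear_model import LogisticRegression
import numpy as np

np.random.seed(0)
X = np.random.randn(100, 4)
y = (X[:, 0] > 0).astype(int)

# 5-fold cross-validation: 5 train/validate rounds, one score per fold
cv = KFold(n_splits=5, shuffle=True, random_state=0)
scores = cross_val_score(LogisticRegression(), X, y, cv=cv)
print(f"Mean accuracy: {scores.mean():.3f} +/- {scores.std():.3f}")
```

For classification targets, StratifiedKFold preserves class proportions within each fold, analogous to stratify=y in train_test_split.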

Conclusion

train_test_split() is the foundation of every machine learning workflow in Python. Always use it before training -- never evaluate on training data. Set random_state for reproducibility, use stratify=y for imbalanced classes, and remember to fit preprocessing steps on the training set only. For model selection, split three ways (train/validation/test) or use cross-validation for more robust estimates.
