Implementing Decision Trees from Scratch
In this tutorial, you'll build a decision tree classifier from the ground up using only NumPy. By the end, you'll understand how decision trees actually work. No library magic, just the math and the logic.
Prerequisites: Basic Python and a little NumPy (creating arrays, indexing).
What we'll cover:
- What decision trees are and how they make predictions
- Measuring how "mixed" a node is
- Splitting data on features
- Choosing the best split using information gain
- Building the tree recursively
- Training on a real dataset and evaluating accuracy
Let's get started.
What is a decision tree?
A decision tree makes predictions by asking a series of yes/no questions about the input. Each internal node tests one feature; each branch is a possible answer; each leaf is a final prediction.
For example, imagine deciding whether to bring an umbrella:
- Is it raining? → If yes, bring umbrella. If no, check the next question.
- Are dark clouds present? → If yes, bring umbrella. If no, leave it.
Each question splits the data into smaller groups. The tree keeps splitting until each group is pure enough to make a confident prediction.
The challenge: how do we decide which question to ask at each node? That's what most of this tutorial is about.
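As a rough illustration, the umbrella example above can be written as a small nested dictionary and walked with a loop. The key names (`question`, `yes`, `no`, `answer`) and the function `decide` are made up for this sketch; the tree we build later in this tutorial uses a different layout.

```python
# Hypothetical encoding of the umbrella example; all names here are illustrative.
umbrella_tree = {
    "question": "raining",
    "yes": {"answer": "bring umbrella"},
    "no": {
        "question": "dark_clouds",
        "yes": {"answer": "bring umbrella"},
        "no": {"answer": "leave it"},
    },
}


def decide(tree, observation):
    """Follow yes/no branches until a leaf (a node holding an 'answer') is reached."""
    node = tree
    while "answer" not in node:
        branch = "yes" if observation[node["question"]] else "no"
        node = node[branch]
    return node["answer"]


print(decide(umbrella_tree, {"raining": False, "dark_clouds": True}))
print(decide(umbrella_tree, {"raining": False, "dark_clouds": False}))
```

Here the questions were chosen by hand. The rest of the tutorial is about choosing them automatically from data.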
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as patches

Measuring impurity
To decide where to split, we need a way to measure how "mixed" the labels in a group are. A pure group has all one class; an impure group has labels mixed together.
We'll use two common impurity measures:
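Entropy: from information theory. Measures the unpredictability of the labels.
$$H(p) = -p \log_2 p - (1 - p) \log_2 (1 - p)$$
Where $p$ is the fraction of positive examples. Entropy is 0 when the group is pure and 1 when classes are exactly split 50/50.
Gini index: measures the probability of misclassifying a randomly chosen example if it is labeled randomly according to the class distribution in the group.
$$G(p) = 1 - p^2 - (1 - p)^2$$
The Gini index is 0 when the group is pure and 0.5 when classes are exactly split 50/50.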
Both behave similarly. Gini is faster to compute (no log), while entropy has a cleaner information-theoretic meaning. We'll implement both, and the tree-building code accepts either one.
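To get a feel for the two measures, here is a small sketch that evaluates both formulas at a few values of p. The helper names `entropy_from_p` and `gini_from_p` are made up for this illustration and take the fraction p directly rather than a label array.

```python
import numpy as np


def entropy_from_p(p):
    """Binary entropy in bits; taken to be 0 at p = 0 or p = 1."""
    if p == 0 or p == 1:
        return 0.0
    return float(-p * np.log2(p) - (1 - p) * np.log2(1 - p))


def gini_from_p(p):
    """Binary Gini index for a fraction p of positive examples."""
    return 1 - p**2 - (1 - p)**2


for p in [0.0, 0.1, 0.25, 0.5, 0.75, 1.0]:
    print(f"p={p:.2f}  entropy={entropy_from_p(p):.4f}  gini={gini_from_p(p):.4f}")
```

Both measures are symmetric around p = 0.5, where they peak, and both fall to 0 as the group becomes pure.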
def compute_entropy(Y):
    """
    Compute the entropy of a binary label array.

    Entropy measures the impurity/disorder of a node. It's 0 when all labels
    are the same (pure) and 1 when classes are perfectly balanced (max disorder).

    Formula: H(p) = -p log2(p) - (1-p) log2(1-p)

    Args:
        Y (np.ndarray): Array of binary labels (0s and 1s)

    Returns:
        float: Entropy value in [0, 1]
    """
    # Empty node has no impurity
    if len(Y) == 0:
        return 0.0
    # Fraction of positive examples
    p = np.mean(Y)
    # If all examples are one class, entropy is 0 (pure node)
    # We check this to avoid log(0) which is undefined
    if p == 0 or p == 1:
        return 0.0
    return -p * np.log2(p) - (1 - p) * np.log2(1 - p)


def compute_gini_index(Y):
    """
    Compute the Gini index of a binary label array.

    Gini measures the probability of misclassifying a randomly chosen example
    if labeled randomly according to the distribution. It's 0 for pure nodes
    and 0.5 for perfectly balanced binary classification.

    Formula: G(p) = 1 - p^2 - (1-p)^2

    Args:
        Y (np.ndarray): Array of binary labels (0s and 1s)

    Returns:
        float: Gini index in [0, 0.5] for binary classification
    """
    if len(Y) == 0:
        return 0.0
    # Fraction of positive examples
    p = np.mean(Y)
    return 1 - p**2 - (1 - p)**2

Splitting the dataset
Given a feature, we split the data into two groups: examples where the feature is 1 (left), and examples where it's 0 (right).
We track examples by their indices in the dataset rather than copying data around. This is more efficient and lets us refer back to the original data anytime.
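Here is a minimal sketch of the index-tracking idea on a made-up 5x2 matrix. The names `X_demo`, `node_indices_demo`, `left_demo`, and `right_demo` are illustrative only.

```python
import numpy as np

# Made-up feature matrix for illustration only.
X_demo = np.array([
    [1, 0],
    [0, 1],
    [1, 1],
    [0, 0],
    [1, 0],
])

# Suppose the current node holds only rows 0, 2, and 3 of the original data.
node_indices_demo = [0, 2, 3]

# Split on feature 0; the results are still indices into X_demo.
left_demo = [i for i in node_indices_demo if X_demo[i, 0] == 1]
right_demo = [i for i in node_indices_demo if X_demo[i, 0] == 0]

print(left_demo)           # indices of node rows where feature 0 is 1
print(right_demo)          # indices of node rows where feature 0 is 0
print(X_demo[left_demo])   # the original rows, recovered by indexing
```

Because the children are described by indices, `X_demo` itself is never copied or modified, no matter how deep the splitting goes.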
def split_dataset(X, node_indices, feature_idx):
    """
    Split the dataset on a specific feature.

    Examples where the feature value is 1 go to the left child;
    examples where it's 0 go to the right child.

    We track examples by their original indices instead of copying data.
    This is more memory-efficient and lets us refer back to the original
    dataset throughout the tree.

    Args:
        X (np.ndarray): The full feature matrix, shape (n_samples, n_features)
        node_indices (list): Indices of examples in the current node
        feature_idx (int): Which feature to split on

    Returns:
        tuple: (left_indices, right_indices), two lists of example indices
    """
    left_indices = []
    right_indices = []
    # Walk through each example at this node and assign it to left or right
    for i in node_indices:
        if X[i][feature_idx] == 1:
            left_indices.append(i)
        else:
            right_indices.append(i)
    return left_indices, right_indices

Information gain
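When we split a node, impurity should decrease. The amount of decrease is the information gain:
$$IG = H(\text{parent}) - \left[ w_{\text{left}} H(\text{left}) + w_{\text{right}} H(\text{right}) \right]$$
Where $w_{\text{left}}$ and $w_{\text{right}}$ are the fractions of examples going to each child. We pick the split that gives the highest information gain.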
The function below works with either impurity measure (entropy or Gini) by accepting it as a parameter.
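Before the full function, here is a small worked example of the formula using entropy. The labels are made up, and `entropy_of_labels` is a stand-alone helper written only for this sketch.

```python
import numpy as np


def entropy_of_labels(labels):
    """Binary entropy (in bits) of a 0/1 label array; 0 for empty or pure arrays."""
    labels = np.asarray(labels)
    if labels.size == 0:
        return 0.0
    p = labels.mean()
    if p == 0 or p == 1:
        return 0.0
    return float(-p * np.log2(p) - (1 - p) * np.log2(1 - p))


# Made-up labels: a balanced parent split into a mostly-positive
# child and a mostly-negative child.
y_left = np.array([1, 1, 1, 1, 0])
y_right = np.array([1, 0, 0, 0, 0])
y_parent = np.concatenate([y_left, y_right])

w_left = len(y_left) / len(y_parent)
w_right = len(y_right) / len(y_parent)

gain = entropy_of_labels(y_parent) - (
    w_left * entropy_of_labels(y_left) + w_right * entropy_of_labels(y_right)
)

print(f"parent entropy   = {entropy_of_labels(y_parent):.4f}")
print(f"information gain = {gain:.4f}")
```

The parent is a 50/50 mix, so its entropy is 1. Each child is an 80/20 mix with entropy of about 0.72, so the split gains roughly 0.28 bits.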
def compute_information_gain(X, Y, node_indices, feature_idx, split_criterion):
    """
    Compute the information gain from splitting on a feature.

    Information gain measures how much impurity decreases after splitting.
    Higher gain means the split produces purer child nodes, which is what
    we want. The feature with the highest gain is the best one to split on.

    Formula:
        IG = H(parent) - [w_left * H(left) + w_right * H(right)]
    Where w_left and w_right are the fractions of examples going to each side.

    Args:
        X (np.ndarray): Full feature matrix
        Y (np.ndarray): Full label array
        node_indices (list): Indices at the current node
        feature_idx (int): Feature to evaluate
        split_criterion (callable): Either compute_entropy or compute_gini_index

    Returns:
        float: Information gain (always >= 0)
    """
    # Split the examples on this feature
    left_indices, right_indices = split_dataset(X, node_indices, feature_idx)

    # Extract labels for the parent and each child
    Y_node = Y[node_indices]
    Y_left = Y[left_indices]
    Y_right = Y[right_indices]

    # Number of examples at this node (used for weighting)
    n = len(node_indices)

    # Impurity before splitting
    h_node = split_criterion(Y_node)

    # Weighted impurity after splitting (each side weighted by its size)
    w_left = len(left_indices) / n
    w_right = len(right_indices) / n
    weighted_impurity = w_left * split_criterion(Y_left) + w_right * split_criterion(Y_right)

    # The reduction in impurity is the information gain
    return h_node - weighted_impurity

Choosing the best split
At each node, we evaluate every feature and pick the one with the highest information gain. We also skip features that don't actually split the data (all examples go to one side).
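As a quick sanity check of the idea, here is a self-contained sketch that scores both features of a made-up four-row dataset and picks the better one. `binary_entropy` and `gain_for_feature` are hypothetical helpers for this illustration, and `np.argmax` stands in for the explicit loop used in the tutorial's own function.

```python
import numpy as np


def binary_entropy(labels):
    """Binary entropy (in bits) of a 0/1 label array; 0 for empty or pure arrays."""
    if len(labels) == 0:
        return 0.0
    p = labels.mean()
    if p == 0 or p == 1:
        return 0.0
    return float(-p * np.log2(p) - (1 - p) * np.log2(1 - p))


def gain_for_feature(X, y, feature_idx):
    """Entropy-based information gain of splitting all rows on one binary feature."""
    mask = X[:, feature_idx] == 1
    n = len(y)
    weighted = (mask.sum() / n) * binary_entropy(y[mask]) \
        + ((~mask).sum() / n) * binary_entropy(y[~mask])
    return binary_entropy(y) - weighted


# Made-up data: feature 1 matches the label exactly, feature 0 is uninformative.
X_small = np.array([[0, 1], [1, 1], [0, 0], [1, 0]])
y_small = np.array([1, 1, 0, 0])

gains = [gain_for_feature(X_small, y_small, j) for j in range(X_small.shape[1])]
best_feature = int(np.argmax(gains))

print("gains per feature:", gains)
print("best feature:", best_feature)
```

Splitting on feature 0 leaves both children as mixed as the parent, so its gain is 0. Splitting on feature 1 produces two pure children, so its gain equals the full parent entropy.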
def get_best_split(X, Y, node_indices, feature_names, split_criterion):
    """
    Find the feature that gives the highest information gain.

    For each feature, we:
    1. Split the data on it
    2. Skip if the split doesn't actually divide the data (all to one side)
    3. Compute information gain
    4. Track the best

    Args:
        X (np.ndarray): Feature matrix
        Y (np.ndarray): Labels
        node_indices (list): Indices at the current node
        feature_names (list): Names of features (its length gives the number of features)
        split_criterion (callable): Impurity function to use

    Returns:
        tuple: (best_feature_idx, max_info_gain)
        Returns (-1, 0) if no useful split exists
    """
    best_idx = -1
    max_info_gain = 0

    # Try splitting on every feature
    for feature_idx in range(len(feature_names)):
        left_indices, right_indices = split_dataset(X, node_indices, feature_idx)

        # Skip features where all examples go to one side; that's not a real split
        if len(left_indices) == 0 or len(right_indices) == 0:
            continue

        # Compute how good this split is
        info_gain = compute_information_gain(X, Y, node_indices, feature_idx, split_criterion)

        # Track the best one
        if info_gain > max_info_gain:
            max_info_gain = info_gain
            best_idx = feature_idx

    return best_idx, max_info_gain

Building the tree recursively
Now we put it together. Starting from the root with all examples, we:
- Find the best feature to split on
- Split the data into left and right groups
- Recursively build a tree for each side
We stop splitting when:
- The node is pure (all one class)
- No useful split exists (information gain ≤ 0)
- We've reached the maximum depth
This recursive structure is what gives decision trees their characteristic tree shape.
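To make the recursive shape concrete, here is a hand-written tree in a nested-dictionary layout using the keys `'type'`, `'prediction'`, `'left'`, and `'right'`, the same ones used for tree nodes in this tutorial. The helpers `count_leaves` and `tree_depth` are hypothetical additions for this sketch. Each is a few lines long because a tree is either a leaf or a split with two smaller trees.

```python
# Hand-written example tree; the structure mirrors the nested dictionaries
# used for nodes in this tutorial, but the contents are made up.
example_tree = {
    'type': 'split',
    'feature_idx': 0,
    'left': {'type': 'leaf', 'prediction': 1},
    'right': {
        'type': 'split',
        'feature_idx': 1,
        'left': {'type': 'leaf', 'prediction': 1},
        'right': {'type': 'leaf', 'prediction': 0},
    },
}


def count_leaves(node):
    """A leaf counts as one; a split contributes the leaves of both children."""
    if node['type'] == 'leaf':
        return 1
    return count_leaves(node['left']) + count_leaves(node['right'])


def tree_depth(node):
    """Number of splits on the longest path from this node down to a leaf."""
    if node['type'] == 'leaf':
        return 0
    return 1 + max(tree_depth(node['left']), tree_depth(node['right']))


print("leaves:", count_leaves(example_tree))
print("depth:", tree_depth(example_tree))
```

Building a tree follows the same pattern in reverse: handle the base case (create a leaf) or split and recurse on each side.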
def build_tree(X, Y, node_indices, feature_names, split_criterion, max_depth=32, current_depth=0):
    """
    Recursively build a decision tree.

    Returns a nested dictionary where each node is either:
    - A leaf with a prediction
    - A split with 'left' and 'right' child trees

    Args:
        X (np.ndarray): Feature matrix
        Y (np.ndarray): Labels
        node_indices (list): Indices of examples at this node
        feature_names (list): Names of features (for readability)
        split_criterion (callable): compute_entropy or compute_gini_index
        max_depth (int): Maximum tree depth
        current_depth (int): Current depth (internal use)

    Returns:
        dict: The tree structure
    """
    Y_node = Y[node_indices]

    # Stopping conditions → create a leaf
    if (current_depth == max_depth or
            len(Y_node) == 0 or
            len(np.unique(Y_node)) == 1):
        prediction = int(round(np.mean(Y_node))) if len(Y_node) > 0 else 0
        return {
            'type': 'leaf',
            'prediction': prediction,
            'samples': len(Y_node)
        }

    # Find the best feature to split on
    best_idx, info_gain = get_best_split(X, Y, node_indices, feature_names, split_criterion)

    # No useful split → leaf
    if info_gain <= 0:
        prediction = int(round(np.mean(Y_node)))
        return {
            'type': 'leaf',
            'prediction': prediction,
            'samples': len(Y_node)
        }

    # Split the data
    left_indices, right_indices = split_dataset(X, node_indices, best_idx)

    # Empty split → leaf
    if len(left_indices) == 0 or len(right_indices) == 0:
        prediction = int(round(np.mean(Y_node)))
        return {
            'type': 'leaf',
            'prediction': prediction,
            'samples': len(Y_node)
        }

    # Recurse to build subtrees
    return {
        'type': 'split',
        'feature_idx': best_idx,
        'feature_name': feature_names[best_idx],
        # Stored so that print_tree can display the gain at each split
        'info_gain': info_gain,
        'left': build_tree(X, Y, left_indices, feature_names, split_criterion, max_depth, current_depth + 1),
        'right': build_tree(X, Y, right_indices, feature_names, split_criterion, max_depth, current_depth + 1)
    }

Print tree
Let's write a function that prints a minimal text visualization of the tree:
def print_tree(tree, indent=0, prefix=""):
    """Pretty-print a decision tree as text."""
    pad = " " * indent
    if tree['type'] == 'leaf':
        marker = "✓" if tree['prediction'] == 1 else "✗"
        print(f"{pad}{prefix}{marker} Predict {tree['prediction']} (n={tree['samples']})")
    else:
        # Show the gain only if the node stores it
        gain_str = f" (gain={tree['info_gain']:.4f})" if 'info_gain' in tree else ""
        print(f"{pad}{prefix}[{tree['feature_name']}]{gain_str}")
        print_tree(tree['left'], indent + 1, "├─ if 1: ")
        print_tree(tree['right'], indent + 1, "└─ if 0: ")

Try it on a small example
Let's start with a toy dataset: classifying animals as cats (1) or not-cats (0) based on three binary features.
X_toy = np.array([
    [1, 1, 1],
    [0, 0, 1],
    [0, 1, 0],
    [1, 0, 1],
    [1, 1, 1],
    [1, 1, 0],
    [0, 0, 0],
    [1, 1, 0],
    [0, 1, 0],
    [0, 1, 0]
])
Y_toy = np.array([1, 1, 0, 0, 1, 1, 0, 1, 0, 0])
feature_names = ['Ear Shape', 'Face Shape', 'Whiskers']
root_indices = list(range(len(X_toy)))

print("Building tree with entropy:")
tree_toy = build_tree(X_toy, Y_toy, root_indices, feature_names, compute_entropy, max_depth=2)
print_tree(tree_toy)

Training on a real dataset
The toy example is useful for understanding, but let's see how this works on a more realistic dataset. We'll generate a binary classification dataset, split it into training and testing sets, train the tree, and measure accuracy.
For our tree to work with this data, all features need to be binary (0 or 1). We'll create synthetic data that's already in this form.
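If you start from continuous features instead, one simple option is to threshold each column, for example at its median. This is only a sketch of one possible preprocessing step and is not used in the rest of this tutorial; the data and the names `X_continuous`, `thresholds`, and `X_binary` are made up.

```python
import numpy as np

rng = np.random.default_rng(0)

# Made-up continuous data: 8 examples, 3 features.
X_continuous = rng.normal(size=(8, 3))

# One threshold per column; values above it become 1, the rest become 0.
thresholds = np.median(X_continuous, axis=0)
X_binary = (X_continuous > thresholds).astype(int)

print(X_binary)
```

A fixed median threshold throws away information that a learned threshold could use, which is why the tutorial lists continuous splits as an extension rather than relying on this trick.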
np.random.seed(42)
# Generate a synthetic binary classification dataset
n_samples = 500
n_features = 6
# Random binary features
X = np.random.randint(0, 2, size=(n_samples, n_features))
# Create a target based on a rule the tree should be able to learn
# y = 1 if (feature 0 AND feature 1) OR (feature 2 AND NOT feature 3)
# (1 - X[:, 3]) is "NOT feature 3" for 0/1 integer features
Y = ((X[:, 0] & X[:, 1]) | (X[:, 2] & (1 - X[:, 3]))).astype(int)
# Add some noise (flip 5% of labels)
noise_mask = np.random.rand(n_samples) < 0.05
Y[noise_mask] = 1 - Y[noise_mask]
print(f"Dataset shape: {X.shape}")
print(f"Class balance: {Y.mean():.2%} positive")

# Shuffle and split into 80% train, 20% test
indices = np.random.permutation(n_samples)
split_point = int(0.8 * n_samples)
train_idx = indices[:split_point]
test_idx = indices[split_point:]
X_train, Y_train = X[train_idx], Y[train_idx]
X_test, Y_test = X[test_idx], Y[test_idx]
print(f"Training set: {len(X_train)} examples")
print(f"Test set: {len(X_test)} examples")

feature_names = [f'feature_{i}' for i in range(n_features)]
train_root = list(range(len(X_train)))
print("Building tree on training data:\n")
tree = build_tree(X_train, Y_train, train_root, feature_names, compute_entropy, max_depth=5)
print_tree(tree)

Making predictions
Building the tree gave us a nested dictionary of splits and leaves. To make a prediction on a new example, we walk through the tree from the root, following the appropriate branch at each split, until we reach a leaf.
def predict_one(tree, x):
    """
    Predict the class of a single example by walking the tree.

    Starts at the root and follows branches based on feature values until
    reaching a leaf, where the leaf's stored prediction is returned.

    Args:
        tree (dict): A tree structure built by build_tree
        x (np.ndarray): A single example's feature vector

    Returns:
        int: The predicted class (0 or 1)
    """
    # Base case: leaf node, return its prediction
    if tree['type'] == 'leaf':
        return tree['prediction']

    # Internal node: recurse into the appropriate child based on the feature value
    if x[tree['feature_idx']] == 1:
        return predict_one(tree['left'], x)
    else:
        return predict_one(tree['right'], x)


def predict(tree, X):
    """
    Predict classes for an array of examples.

    Just applies predict_one to each example.

    Args:
        tree (dict): A tree structure from build_tree
        X (np.ndarray): Feature matrix, shape (n_samples, n_features)

    Returns:
        np.ndarray: Predicted classes
    """
    return np.array([predict_one(tree, x) for x in X])

Evaluating the tree
Now we train on the training set, predict on the test set, and measure accuracy.
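Accuracy is just the fraction of predictions that match the true labels. The sketch below computes it on made-up labels and, as an optional extra that the tutorial itself does not use, also counts the four kinds of outcome (true/false positives and negatives). The arrays `y_true` and `y_pred` are illustrative only.

```python
import numpy as np

# Made-up labels and predictions for illustration only.
y_true = np.array([1, 0, 1, 1, 0, 0, 1, 0])
y_pred = np.array([1, 0, 0, 1, 0, 1, 1, 0])

# Fraction of positions where prediction and label agree.
accuracy = (y_pred == y_true).mean()

# Breakdown by outcome type.
tp = int(np.sum((y_pred == 1) & (y_true == 1)))  # predicted 1, actually 1
tn = int(np.sum((y_pred == 0) & (y_true == 0)))  # predicted 0, actually 0
fp = int(np.sum((y_pred == 1) & (y_true == 0)))  # predicted 1, actually 0
fn = int(np.sum((y_pred == 0) & (y_true == 1)))  # predicted 0, actually 1

print(f"accuracy = {accuracy:.2f}")
print(f"TP={tp} TN={tn} FP={fp} FN={fn}")
```

The breakdown can be useful when classes are imbalanced, since a high accuracy can hide a model that rarely predicts the minority class.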
train_root = list(range(len(X_train)))
tree_model = build_tree(X_train, Y_train, train_root, feature_names, compute_entropy, max_depth=5)
# Predictions
train_preds = predict(tree_model, X_train)
test_preds = predict(tree_model, X_test)
# Accuracy
train_acc = (train_preds == Y_train).mean()
test_acc = (test_preds == Y_test).mean()
print(f"Training accuracy: {train_acc:.4f}")
print(f"Test accuracy: {test_acc:.4f}")

# Compare the two impurity measures on the same train/test split
tree_entropy = build_tree(X_train, Y_train, train_root, feature_names, compute_entropy, max_depth=5)
tree_gini = build_tree(X_train, Y_train, train_root, feature_names, compute_gini_index, max_depth=5)
acc_entropy = (predict(tree_entropy, X_test) == Y_test).mean()
acc_gini = (predict(tree_gini, X_test) == Y_test).mean()
print(f"Test accuracy with entropy: {acc_entropy:.4f}")
print(f"Test accuracy with Gini: {acc_gini:.4f}")

How does max depth affect performance?
A deeper tree fits the training data better but may overfit. Let's plot training and test accuracy for different depths.
depths = range(1, 11)
train_accs = []
test_accs = []
for d in depths:
    tree_d = build_tree(X_train, Y_train, train_root, feature_names, compute_entropy, max_depth=d)
    train_accs.append((predict(tree_d, X_train) == Y_train).mean())
    test_accs.append((predict(tree_d, X_test) == Y_test).mean())
plt.plot(depths, train_accs, marker='o', label='Training accuracy')
plt.plot(depths, test_accs, marker='s', label='Test accuracy')
plt.xlabel('Max depth')
plt.ylabel('Accuracy')
plt.title('Tree depth vs accuracy')
plt.legend()
plt.grid(alpha=0.3)
plt.show()

What we built
In this tutorial, we implemented a decision tree classifier from scratch:
- Impurity measures (entropy, Gini): to quantify how mixed labels are at each node
- Splitting logic: to partition data based on feature values
- Information gain: to score and rank candidate splits
- Recursive tree construction: to build the tree top-down
- Prediction: to traverse the tree for new examples
- Evaluation: to measure accuracy on a held-out test set
The same principles power production tools like sklearn's DecisionTreeClassifier, but with optimizations for speed and additional features like continuous-valued splits and pruning.
Limitations of this implementation:
- Only handles binary features (extending to continuous requires trying threshold splits)
- No pruning (we rely on max_depth as our only regularization)
- Trees are sensitive to the training data: small changes can produce different trees
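To sketch what "trying threshold splits" could look like for one continuous feature: a common choice is to use the midpoints between consecutive distinct values as candidate thresholds, then score each resulting binary split as before. The function `candidate_thresholds` and the sample values are hypothetical and are not part of this tutorial's implementation.

```python
import numpy as np


def candidate_thresholds(values):
    """Midpoints between consecutive distinct values of one feature."""
    unique_values = np.unique(values)  # sorted, duplicates removed
    return (unique_values[:-1] + unique_values[1:]) / 2


# Made-up values for a single continuous feature.
feature_values = np.array([2.0, 3.5, 3.5, 1.0, 5.0])
thresholds = candidate_thresholds(feature_values)
print("candidate thresholds:", thresholds)

# Each threshold turns the continuous feature into a binary one,
# which could then be scored with information gain.
for t in thresholds:
    binary_feature = (feature_values > t).astype(int)
    print(f"threshold {t}: {binary_feature}")
```

With this approach the search at each node runs over (feature, threshold) pairs rather than over features alone, which is the main extra cost.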
Next steps to explore:
- Add support for continuous features by trying threshold splits
- Try a real dataset (UCI Mushroom, breast cancer, etc.)
- Implement a random forest by training many trees on bootstrap samples
- Add tree visualization
- Compare with sklearn.tree.DecisionTreeClassifier on the same data
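For the random forest idea, the core ingredient is the bootstrap sample: drawing as many row indices as there are rows, with replacement. Here is a minimal sketch with made-up sizes; the names `bootstrap_idx` and `out_of_bag_idx` are illustrative. In a forest, each tree would be trained on its own bootstrap sample, and predictions would be combined by majority vote.

```python
import numpy as np

rng = np.random.default_rng(0)
n_rows = 10  # made-up dataset size

# Sample row indices with replacement; some rows repeat, some are left out.
bootstrap_idx = rng.integers(0, n_rows, size=n_rows)

# Rows never drawn ("out-of-bag") can serve as a built-in validation set.
out_of_bag_idx = np.setdiff1d(np.arange(n_rows), bootstrap_idx)

print("bootstrap sample:", bootstrap_idx)
print("out-of-bag rows: ", out_of_bag_idx)
```

Since the tree code in this tutorial already works with index lists, a bootstrap sample could be passed in as the root indices without copying the data.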
Decision trees alone are a starting point. The real power comes from tree ensembles like random forests and gradient boosting (for example, XGBoost), which combine many trees and often improve accuracy substantially over a single tree.