supervised · classification · regression · ensemble

Decision Trees & Random Forests

From single trees to powerful forests. Ensemble learning at its finest.

Written by Omansh
9 min read
From Scratch
Source Code

I've written a lot about linear models, where everything is about finding the right weights and drawing smooth boundaries. But not all problems can fit neatly into that framework.

Sometimes, patterns in your data can present themselves as rules where “if this feature is above some threshold AND that feature is below another threshold, then predict class A.”

Naturally, the next thing for us to explore is decision trees. They approach problems completely differently, thinking in terms of sequential conditions rather than weighted sums. And when you stack a bunch of them together into a random forest, you get one of the most reliable algorithms in ML.

Credit Where Credit is Due

Before we begin, I want to give a shoutout to Normalized Nerd for his video that helped me understand this topic best.

What is a Decision Tree?

A decision tree is exactly what it sounds like: a tree of decisions. At each node, you ask a yes/no question about a feature. Based on the answer, you go left or right. You keep following questions until you reach a leaf node that gives you a prediction.

[Interactive: Decision Tree Structure. Click "Build Tree" to animate an example tree: a root decision node asking "Age ≤ 30?" (n=100) branches into further decision nodes and leaf nodes.]

The beauty is interpretability. You can literally trace the path from root to leaf and explain exactly why the model made a specific prediction. Try doing that with a neural network. Decision trees work for both classification and regression, but we'll focus on classification since that's what our implementation handles.
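Before we dive into the algorithm, note that the code snippets throughout this post pass around Node objects. Here's a minimal sketch of what that container might look like, with the fields the later snippets reference (the exact class in the source code may differ):

node_sketch.py
class Node:
    """One node of the tree: either an internal decision node or a leaf."""
    def __init__(self, feature=None, threshold=None, left=None, right=None,
                 value=None, impurity=None):
        self.feature = feature      # index of the feature this node splits on
        self.threshold = threshold  # go left if feature value <= threshold
        self.left = left            # left child Node
        self.right = right          # right child Node
        self.value = value          # predicted class label (set only for leaves)
        self.impurity = impurity    # impurity of the samples that reached this node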

The Core Algorithm: Growing a Tree

Building a decision tree is a recursive process. Here's the high-level algorithm:

1. Start with all your training data at the root
2. Find the best question to ask (best feature and threshold to split on)
3. Split the data based on that question
4. Recursively repeat for each child node
5. Stop when you hit a stopping condition

The magic is in step 2: how do we find the “best” question? To answer this, we need a way to measure how “mixed” a set of labels is.

Measuring Impurity: Entropy and Gini Index

To find the best split, we need to measure how “mixed” a set of labels is. A node with all the same class is pure (good). A node with a 50-50 mix of classes is impure (bad).

Entropy

Entropy comes from information theory and measures the amount of disorder or uncertainty in a set of labels. If you imagine identifying a sample's class by asking a series of yes/no questions, entropy tells you the average number of questions you'd need.

H(S) = −Σ p log₂(p)

Sum of negative p times log₂(p) for each class proportion

Here, p is the proportion of samples belonging to each class. Let's understand this intuitively: if all samples are the same class (p = 1), entropy equals 0 (perfectly pure). If samples are evenly split between two classes (p = 0.5 each), entropy equals 1 (maximum impurity).

Gini Index

The Gini index is an alternative impurity measure. It measures the probability of incorrectly classifying a randomly chosen sample if you randomly label it according to the class distribution.

Gini(S) = 1 − Σ p²

One minus sum of squared proportions

For binary classification: if all samples are the same class, Gini = 0 (pure). If samples are evenly split, Gini = 0.5 (maximum impurity).
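As a quick sanity check on both formulas, here's a small standalone sketch of the two measures (the from-scratch implementation selects between them via its criterion setting, so the exact code will differ):

impurity_sketch.py
import numpy as np

def entropy(y):
    # H(S) = -sum(p * log2(p)) over the class proportions in y
    _, counts = np.unique(y, return_counts=True)
    p = counts / counts.sum()
    return -np.sum(p * np.log2(p))

def gini(y):
    # Gini(S) = 1 - sum(p^2) over the class proportions in y
    _, counts = np.unique(y, return_counts=True)
    p = counts / counts.sum()
    return 1.0 - np.sum(p ** 2)

y = np.array([0, 0, 1, 1])   # perfectly mixed 50/50 labels
print(entropy(y), gini(y))   # -> 1.0 0.5 (maximum impurity for each measure)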

[Interactive: Impurity Measures Comparison. Entropy and Gini plotted against the proportion of class 1: both measures are 0 when pure and reach their maximum at a 50/50 split (entropy = 1.0, Gini = 0.5).]

They're remarkably similar in practice. Entropy ranges from 0 to 1 (for binary), Gini from 0 to 0.5. Entropy is slightly more expensive to compute (logarithm vs square) and tends to favor balanced splits slightly more. Gini is faster and simpler.

Metric | Range | Computation | Best For
Entropy | 0 to 1 | Uses log₂ (slower) | Balanced splits
Gini Index | 0 to 0.5 | Uses squares (faster) | General use (default)

In Practice

The choice rarely matters. Most implementations default to Gini because it's faster, but entropy connects to information theory which some people find more intuitive. For our implementation, we support both depending on user configuration.

Information Gain: Finding the Best Split

Now we can measure how good a split is using information gain. The idea: we want splits that reduce impurity the most. Start with mixed data, split it, and measure how much purer the resulting groups are.

IG(S, split) = Impurity(parent) − Weighted Avg Impurity(children)

Parent impurity minus weighted average of children impurity

The weighted average accounts for split size. If one child has 90% of the data and the other has 10%, we care more about the larger child's impurity.
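Here's a minimal sketch of how that computation could look as a method, in the spirit of the _information_gain helper called by the split-search code below (the actual helper in the source code may differ):

information_gain_sketch.py
def _information_gain(self, y, X_column, threshold):
    # Impurity of the parent node before splitting
    parent_impurity = self._calculate_impurity(y)

    # Partition the labels by the candidate threshold
    left_mask = X_column <= threshold
    right_mask = ~left_mask
    n, n_left, n_right = len(y), left_mask.sum(), right_mask.sum()

    # A split that leaves one side empty gains nothing
    if n_left == 0 or n_right == 0:
        return 0.0

    # Weighted average of the children's impurity, weighted by their size
    child_impurity = (n_left / n) * self._calculate_impurity(y[left_mask]) \
                   + (n_right / n) * self._calculate_impurity(y[right_mask])

    return parent_impurity - child_impurity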

Comparing Two Splits

Split A produces children with class ratios 8:2 (left) and 3:7 (right), for IG ≈ 0.12. Split B produces 9:1 (left) and 1:9 (right), for IG ≈ 0.47. Choose the split with the higher information gain (more separation); here that's Split B.

To find the best split, we try every feature, treat every unique value of that feature as a candidate threshold, calculate the information gain for each combination, and pick the one with the highest gain.

best_split.py
def _best_split(self, X, y, feature_indices):
    best_gain = -1
    best_feature = None
    best_threshold = None

    for feature_idx in feature_indices:
        X_column = X[:, feature_idx]
        # Every unique value of the feature is a candidate threshold
        thresholds = np.unique(X_column)

        for threshold in thresholds:
            gain = self._information_gain(y, X_column, threshold)

            # Keep the feature/threshold pair with the highest gain so far
            if gain > best_gain:
                best_gain = gain
                best_feature = feature_idx
                best_threshold = threshold

    return best_feature, best_threshold

Computational Cost

This is computationally expensive! For n features and m unique values per feature, we're checking n × m splits. For large datasets with continuous features, this is why decision trees can be slow to train.

The Recursive Tree-Building Process

grow_tree.py
def _grow_tree(self, X, y, depth):
    n_samples, n_features = X.shape
    n_labels = len(np.unique(y))

    # Stopping conditions
    if (self.max_depth is not None and depth >= self.max_depth) or \
       n_labels == 1 or \
       n_samples < self.min_samples_split:
        leaf_value = self._most_common_label(y)
        return Node(value=leaf_value, impurity=self._calculate_impurity(y))

    # Find best split
    feature_indices = np.random.choice(n_features, self.max_features, replace=False)
    best_feature, best_threshold = self._best_split(X, y, feature_indices)

    if best_feature is None:
        leaf_value = self._most_common_label(y)
        return Node(value=leaf_value, impurity=self._calculate_impurity(y))

    # Split the data
    left_indices = X[:, best_feature] <= best_threshold
    right_indices = X[:, best_feature] > best_threshold

    # Check minimum leaf size
    if np.sum(left_indices) < self.min_samples_leaf or \
       np.sum(right_indices) < self.min_samples_leaf:
        leaf_value = self._most_common_label(y)
        return Node(value=leaf_value, impurity=self._calculate_impurity(y))

    # Recursively grow children
    left = self._grow_tree(X[left_indices], y[left_indices], depth + 1)
    right = self._grow_tree(X[right_indices], y[right_indices], depth + 1)

    return Node(feature=best_feature, threshold=best_threshold,
                left=left, right=right, impurity=self._calculate_impurity(y))

The stopping conditions are crucial forms of pre-pruning (stopping growth early) to prevent overfitting:

max_depth

Shallow trees underfit (too simple), deep trees overfit (too complex). Finding the sweet spot is key.

min_samples_split

If you require at least 20 samples to split, you prevent decisions based on tiny subsets that might just be noise.

min_samples_leaf

Ensures each prediction is based on a reasonable number of samples, not just one or two outliers.

Tuning Order

Generally, you want to tackle hyperparameter tuning in the order listed above, starting with max_depth, since it has the greatest effect on the resulting tree.
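For illustration, configuring a single tree with conservative pre-pruning settings might look like this. The parameter names mirror the constructor call used in the random forest's fit method later; the values and the exact criterion strings are just examples and may differ from the source code:

tree_config_example.py
tree = DecisionTreeClassifier(
    max_depth=5,            # tune this first: largest effect on the tree
    min_samples_split=20,   # don't split nodes with fewer than 20 samples
    min_samples_leaf=5,     # each leaf must keep at least 5 samples
    criterion="gini",       # impurity measure (string values assumed here)
    random_state=42,        # for reproducibility
)
tree.fit(X_train, y_train)  # X_train, y_train: your training data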

Making Predictions: Tree Traversal

Once the tree is built, making predictions is straightforward traversal. If we're at a leaf, we simply return its value. Otherwise, check the feature and recurse through the tree.

traverse_tree.py
def _traverse_tree(self, x, node):
    # Leaf node: return its stored prediction
    if node.value is not None:
        return node.value

    # Internal node: follow the branch the feature value falls into
    if x[node.feature] <= node.threshold:
        return self._traverse_tree(x, node.left)
    return self._traverse_tree(x, node.right)

This is why decision trees are so interpretable—you can trace the exact path and see which conditions led to the prediction.
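For completeness, here's a minimal sketch of the public predict method that would call this traversal once per sample, assuming the fitted tree keeps its root node in self.root (the feature-importance code later accesses tree.root):

predict_sketch.py
def predict(self, X):
    X = np.array(X)
    # Start each sample at the root and walk down to a leaf
    return np.array([self._traverse_tree(x, self.root) for x in X])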

Enter Random Forests: Ensemble Power

A single decision tree is interpretable but unstable. Change your training data slightly and the tree structure can change dramatically. The top split might use a completely different feature.

This high variance is a problem, but we can mitigate it by building many trees and averaging their predictions. This is the core idea behind random forests: an ensemble of decision trees, each trained on slightly different data and features, that predicts by majority vote.

How to be Random?

The first key ingredient is bootstrapping, the "b" in bagging (bootstrap aggregating). For each tree, we create a bootstrap sample of our data by sampling with replacement: we draw n samples from our dataset of size n, so each original sample can be selected multiple times. On average, a bootstrap sample contains about 63% of the unique original samples; the rest of the draws are duplicates.

bootstrap.py
if self.bootstrap:
    indices = np.random.choice(len(X), len(X), replace=True)
    X_sample = X[indices]
    y_sample = y[indices]
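The ~63% figure comes from the probability that a given sample is drawn at least once: 1 − (1 − 1/n)ⁿ ≈ 1 − 1/e ≈ 0.632. A quick standalone check:

bootstrap_check.py
import numpy as np

n = 10_000
rng = np.random.default_rng(0)
indices = rng.choice(n, n, replace=True)   # one bootstrap sample
unique_fraction = len(np.unique(indices)) / n
print(unique_fraction)                     # ≈ 0.632, i.e. about 63% unique samples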

[Interactive: Bootstrap Sampling. An original dataset of 10 samples (D1-D10) is resampled with replacement to generate bootstrap samples for 3 trees. Notice how some samples appear multiple times while others may not appear at all.]

This creates diversity. Each tree sees a slightly different view of the data, so they make different mistakes. When we average their predictions, the mistakes cancel out.

Bagging helps, but we can do better. The second key ingredient is feature randomness. At each split, instead of considering all features, we randomly select a subset.

feature_selection.py
feature_indices = np.random.choice(n_features, self.max_features, replace=False)
best_feature, best_threshold = self._best_split(X, y, feature_indices)
max_features | Value | Use Case
sqrt | √n features | Default for classification
log2 | log₂(n) features | More randomness
Integer | Exactly that many | Custom control
None | All features | Just bagging, no feature randomness
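One way to turn those options into an integer count before the split search is a small helper like this hypothetical _resolve_max_features (the source code may handle the conversion differently):

max_features_sketch.py
def _resolve_max_features(self, n_features):
    # Translate the max_features setting into a number of features to sample
    if self.max_features is None:
        return n_features                        # all features: bagging only
    if self.max_features == "sqrt":
        return max(1, int(np.sqrt(n_features)))  # default for classification
    if self.max_features == "log2":
        return max(1, int(np.log2(n_features)))  # more randomness
    return min(self.max_features, n_features)    # integer: exactly that many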

Why Feature Randomness Helps

Without it, if one feature is very strong, all trees will use it for the top split. The trees become correlated. Their predictions are similar, so averaging doesn't help much. With feature randomness, trees are forced to find alternative splits, making their collective wisdom more valuable.

Training the Forest

Training a random forest is embarrassingly parallel—each tree is independent:

random_forest_fit.py
def fit(self, X, y, verbose=False):
    X = np.array(X)
    y = np.array(y).flatten()

    self.n_classes_ = len(np.unique(y))
    self.n_features_ = X.shape[1]
    self.trees = []

    if self.random_state is not None:
        np.random.seed(self.random_state)

    for i in range(self.n_estimators):
        # Each tree gets a different random seed
        tree_seed = None if self.random_state is None else self.random_state + i
        tree = DecisionTreeClassifier(
            max_depth=self.max_depth,
            min_samples_split=self.min_samples_split,
            min_samples_leaf=self.min_samples_leaf,
            max_features=self.max_features,
            criterion=self.criterion,
            random_state=tree_seed
        )

        # Bootstrap sample
        if self.bootstrap:
            indices = np.random.choice(len(X), len(X), replace=True)
            X_sample = X[indices]
            y_sample = y[indices]
        else:
            X_sample = X
            y_sample = y

        tree.fit(X_sample, y_sample)
        self.trees.append(tree)

We build n_estimators trees (typically 100-500), each on a bootstrap sample, each with feature randomness.

Making Predictions Through Majority Voting

For classification, prediction is simple majority voting:

predict.py
from collections import Counter

def predict(self, X):
    X = np.array(X)

    # Collect every tree's predictions: shape (n_trees, n_samples)
    tree_predictions = np.array([tree.predict(X) for tree in self.trees])

    # For each sample, find the most common prediction
    predictions = []
    for i in range(X.shape[0]):
        votes = tree_predictions[:, i]
        most_common = Counter(votes).most_common(1)[0][0]
        predictions.append(most_common)

    return np.array(predictions)

If you have 100 trees and 60 predict class 1 while 40 predict class 0, the forest predicts class 1. You can adjust this decision rule in your code as you see fit, and we can also return probability estimates instead of a single class (here, 0.6 probability of class 1).
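A hypothetical predict_proba could average the votes instead of taking the strict majority, assuming the class labels are the integers 0 to n_classes_ − 1 (the source code may expose this differently):

predict_proba_sketch.py
def predict_proba(self, X):
    X = np.array(X)
    # Collect every tree's predictions: shape (n_trees, n_samples)
    tree_predictions = np.array([tree.predict(X) for tree in self.trees])

    proba = np.zeros((X.shape[0], self.n_classes_))
    for i in range(X.shape[0]):
        votes = tree_predictions[:, i].astype(int)
        counts = np.bincount(votes, minlength=self.n_classes_)
        proba[i] = counts / len(self.trees)  # fraction of trees voting for each class
    return proba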

Feature Importance

One downside of moving from one tree to a forest is losing interpretability. We can't trace a single path anymore. But we can measure which features matter most:

feature_importance.py
def _calculate_feature_importances(self):
    importances = np.zeros(self.n_features_)

    for tree in self.trees:
        self._traverse_for_importance(tree.root, importances)

    if np.sum(importances) > 0:
        importances = importances / np.sum(importances)

    self.feature_importances_ = importances

def _traverse_for_importance(self, node, importances):
    if node is None or node.value is not None:
        return

    importances[node.feature] += 1
    self._traverse_for_importance(node.left, importances)
    self._traverse_for_importance(node.right, importances)

Our simple implementation counts how often each feature is used for splits across all trees. More sophisticated versions weight by the information gain of each split. This gives you a sense of which features the forest relies on most, even if you can't trace individual predictions.

Beyond the Basics

Now that we've built decision trees and random forests from scratch, here are a few practical notes and ways to extend and optimize them:

1. Scaling not required

Unlike linear models, trees don't care about feature scales. A feature ranging 0-1 and one ranging 0-1000 are treated equally.

2. Feature interactions

Trees naturally capture feature interactions. The split "if income < 50k AND age > 30" happens automatically through tree structure.

3. Parallel training

Trees are independent and can be trained in parallel. Real implementations (like sklearn) do this for massive speedups.

4. Memory usage

Random forests use a lot of memory (storing many trees). For very large datasets, this can be a constraint.

5. Prediction speed

Forests are slower to predict than single trees (need to traverse many trees). For latency-critical applications, this matters.

6. Gradient Boosted Trees

Instead of training trees independently, train them sequentially where each tree corrects the previous trees' mistakes. XGBoost, LightGBM, and CatBoost are incredibly powerful implementations of this idea.

7. Post-Pruning

Instead of stopping growth early (pre-pruning), grow a full tree and then prune branches that don't improve validation performance. This can find better trees than pre-pruning alone.

8. Class Weights

Handle imbalanced datasets by giving more weight to minority class errors. When you have 95% negative examples, weighting minority class errors higher forces the model to care about getting those rare cases right.

Happy coding, and don't get lost in the forest.