Contents
I've written a lot about linear models, where everything is about finding the right weights and drawing smooth boundaries. But not every problem fits neatly into that framework.
Sometimes the patterns in your data present themselves as rules: “if this feature is above some threshold AND that feature is below another threshold, then predict class A.”
Naturally, the next thing for us to explore is decision trees. They approach problems completely differently, thinking in terms of sequential conditions rather than weighted sums. And when you combine a bunch of them into a random forest, you get one of the most reliable algorithms in ML.
Credit Where Credit is Due
What is a Decision Tree?
A decision tree is exactly what it sounds like: a tree of decisions. At each node, you ask a yes/no question about a feature. Based on the answer, you go left or right. You keep answering questions until you reach a leaf node that gives you a prediction.
Interactive: Decision Tree Structure
Click "Build Tree" to animate
The beauty is interpretability. You can literally trace the path from root to leaf and explain exactly why the model made a specific prediction. Try doing that with a neural network. Decision trees work for both classification and regression, but we'll focus on classification since that's what our implementation handles.
The Core Algorithm: Growing a Tree
Building a decision tree is a recursive process. Here's the high-level algorithm:
1. Start with all your training data at the root
2. Find the best question to ask (best feature and threshold to split on)
3. Split the data based on that question
4. Recursively repeat for each child node
5. Stop when you hit a stopping condition
The magic is in step 2: how do we find the “best” question? To answer this, we need a way to measure how “mixed” a set of labels is.
Measuring Impurity: Entropy and Gini Index
A node where every sample belongs to the same class is pure (good). A node with a 50-50 mix of classes is maximally impure (bad).
Entropy
Entropy comes from information theory and measures the amount of disorder or uncertainty in a set of labels. If you had to identify a sample's class by asking yes/no questions, entropy tells you the average number of questions you'd need.
$$H = -\sum_{i} p_i \log_2(p_i)$$
Here, $p_i$ is the proportion of samples belonging to class $i$. Let's understand this intuitively: if all samples are the same class ($p = 1$), entropy equals 0 (perfectly pure). If samples are evenly split between two classes ($p = 0.5$ each), entropy equals 1 (maximum impurity).
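In the implementation, impurity is computed inside a `_calculate_impurity` method (you'll see it called in `_grow_tree` later). As a standalone illustration (the function name here is mine, not the post's), entropy is just a few lines of NumPy:

```python
import numpy as np

def entropy(y):
    """Entropy of a label array: -sum(p * log2(p)) over the class proportions."""
    _, counts = np.unique(y, return_counts=True)
    p = counts / counts.sum()
    return -np.sum(p * np.log2(p))

entropy(np.array([1, 1, 1, 1]))   # == 0.0  (pure node)
entropy(np.array([0, 0, 1, 1]))   # == 1.0  (50/50 split: maximum for two classes)
```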
Gini Index
The Gini index is an alternative impurity measure. It measures the probability of incorrectly classifying a randomly chosen sample if you randomly label it according to the class distribution.
$$G = 1 - \sum_{i} p_i^2$$
For binary classification: if all samples are the same class, Gini = 0 (pure). If samples are evenly split, Gini = 0.5 (maximum impurity).
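The same kind of standalone sketch works for the Gini index (again, the function name is mine, not from the post's implementation):

```python
import numpy as np

def gini(y):
    """Gini index of a label array: 1 - sum(p^2) over the class proportions."""
    _, counts = np.unique(y, return_counts=True)
    p = counts / counts.sum()
    return 1.0 - np.sum(p ** 2)

gini(np.array([1, 1, 1, 1]))   # == 0.0  (pure node)
gini(np.array([0, 0, 1, 1]))   # == 0.5  (50/50 split: maximum for two classes)
```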
Interactive: Impurity Measures Comparison
Both measures are 0 when pure and maximum when evenly split
They're remarkably similar in practice. Entropy ranges from 0 to 1 (for binary), Gini from 0 to 0.5. Entropy is slightly more expensive to compute (logarithm vs square) and tends to favor balanced splits slightly more. Gini is faster and simpler.
| Metric | Range | Computation | Best For |
|---|---|---|---|
| Entropy | 0 to 1 | Uses log₂ (slower) | Balanced splits |
| Gini Index | 0 to 0.5 | Uses squares (faster) | General use (default) |
In Practice
In practice, the two criteria almost always produce similar trees; Gini is the usual default simply because it's cheaper to compute. The implementation lets you choose between them through the criterion parameter you'll see passed around in the forest's fit method later.
Information Gain: Finding the Best Split
Now we can measure how good a split is using information gain. The idea: we want splits that reduce impurity the most. Start with mixed data, split it, and measure how much purer the resulting groups are.
$$IG = I(\text{parent}) - \frac{n_{\text{left}}}{n}\, I(\text{left}) - \frac{n_{\text{right}}}{n}\, I(\text{right})$$
The weighted average accounts for split size. If one child has 90% of the data and the other has 10%, we care more about the larger child's impurity.
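The `_best_split` code below calls a `self._information_gain` helper that isn't shown in the post. Here's a minimal sketch of what it might compute, written as a standalone function with entropy as the impurity measure (the names and defaults are my assumptions):

```python
import numpy as np

def entropy(y):  # same helper as sketched above
    _, counts = np.unique(y, return_counts=True)
    p = counts / counts.sum()
    return -np.sum(p * np.log2(p))

def information_gain(y, X_column, threshold, impurity=entropy):
    """Parent impurity minus the size-weighted impurity of the two children."""
    left = X_column <= threshold
    right = ~left

    if left.sum() == 0 or right.sum() == 0:
        return 0.0  # everything lands on one side: the split tells us nothing

    n = len(y)
    child_impurity = (left.sum() / n) * impurity(y[left]) \
                   + (right.sum() / n) * impurity(y[right])
    return impurity(y) - child_impurity
```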
Comparing Two Splits
Choose the split with higher information gain (more separation)
To find the best split, we try every feature, try every unique value of that feature as a threshold, calculate the information gain for each combination, and pick the combination with the highest gain.
```python
def _best_split(self, X, y, feature_indices):
    best_gain = -1
    best_feature = None
    best_threshold = None

    for feature_idx in feature_indices:
        X_column = X[:, feature_idx]
        thresholds = np.unique(X_column)

        for threshold in thresholds:
            gain = self._information_gain(y, X_column, threshold)

            if gain > best_gain:
                best_gain = gain
                best_feature = feature_idx
                best_threshold = threshold

    return best_feature, best_threshold
```
Computational Cost
This exhaustive search scans every candidate feature and every unique value of that feature at every node, so a single call to _best_split costs roughly O(features × unique values × samples). Optimized libraries typically cut this down by sorting each feature once and sweeping thresholds while updating class counts incrementally.
The Recursive Tree-Building Process
```python
def _grow_tree(self, X, y, depth):
    n_samples, n_features = X.shape
    n_labels = len(np.unique(y))

    # Stopping conditions
    if (self.max_depth is not None and depth >= self.max_depth) or \
       n_labels == 1 or \
       n_samples < self.min_samples_split:
        leaf_value = self._most_common_label(y)
        return Node(value=leaf_value, impurity=self._calculate_impurity(y))

    # Find best split
    feature_indices = np.random.choice(n_features, self.max_features, replace=False)
    best_feature, best_threshold = self._best_split(X, y, feature_indices)

    if best_feature is None:
        leaf_value = self._most_common_label(y)
        return Node(value=leaf_value, impurity=self._calculate_impurity(y))

    # Split the data
    left_indices = X[:, best_feature] <= best_threshold
    right_indices = X[:, best_feature] > best_threshold

    # Check minimum leaf size
    if np.sum(left_indices) < self.min_samples_leaf or \
       np.sum(right_indices) < self.min_samples_leaf:
        leaf_value = self._most_common_label(y)
        return Node(value=leaf_value, impurity=self._calculate_impurity(y))

    # Recursively grow children
    left = self._grow_tree(X[left_indices], y[left_indices], depth + 1)
    right = self._grow_tree(X[right_indices], y[right_indices], depth + 1)

    return Node(feature=best_feature, threshold=best_threshold,
                left=left, right=right, impurity=self._calculate_impurity(y))
```
The stopping conditions are crucial forms of pre-pruning (stopping growth early) to prevent overfitting:
max_depth: Shallow trees underfit (too simple), deep trees overfit (too complex). Finding the sweet spot is key.
min_samples_split: If you require at least 20 samples to split, you prevent decisions based on tiny subsets that might just be noise.
min_samples_leaf: Ensures each prediction is based on a reasonable number of samples, not just one or two outliers.
Tuning Order
Of these, max_depth usually has the biggest effect, so it's the natural one to tune first; tighten min_samples_split and min_samples_leaf afterwards if the tree still overfits.
Making Predictions: Tree Traversal
Once the tree is built, making predictions is straightforward traversal. If we're at a leaf, we simply return its value. Otherwise, we compare the sample's feature value against the node's threshold and recurse into the left or right child.
```python
def _traverse_tree(self, x, node):
    if node.value is not None:
        return node.value

    if x[node.feature] <= node.threshold:
        return self._traverse_tree(x, node.left)
    return self._traverse_tree(x, node.right)
```
This is why decision trees are so interpretable—you can trace the exact path and see which conditions led to the prediction.
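Both `_grow_tree` and `_traverse_tree` rely on a `Node` container that the snippets don't define. A minimal sketch of what it might look like, with the fields inferred from how it's used above:

```python
class Node:
    """A single tree node: either an internal split or a leaf.

    Inferred sketch, not the post's exact definition. Internal nodes carry
    (feature, threshold, left, right); leaves carry value.
    """
    def __init__(self, feature=None, threshold=None, left=None, right=None,
                 value=None, impurity=None):
        self.feature = feature        # index of the feature to split on
        self.threshold = threshold    # go left if x[feature] <= threshold
        self.left = left              # left child Node
        self.right = right            # right child Node
        self.value = value            # predicted class (leaf nodes only)
        self.impurity = impurity      # impurity of the samples at this node
```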
Enter Random Forests: Ensemble Power
A single decision tree is interpretable but unstable. Change your training data slightly and the tree structure can change dramatically. The top split might use a completely different feature.
This high variance is a problem, but we can mitigate it by building many trees and averaging their predictions. This is the core idea behind random forests—an ensemble of decision trees, each trained on slightly different data and features, making predictions by majority vote.
How to be Random?
The first key ingredient is bootstrapping, the “b” in bagging (bootstrap aggregating). For each tree, we create a bootstrap sample of our data by sampling with replacement: we draw n samples from our dataset of size n, and any given sample can be selected multiple times. On average, each bootstrap sample contains about 63% of the unique original samples.
```python
if self.bootstrap:
    indices = np.random.choice(len(X), len(X), replace=True)
    X_sample = X[indices]
    y_sample = y[indices]
```
Interactive: Bootstrap Sampling
Notice how some samples appear multiple times while others may not appear at all
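A quick sanity check on that ~63% figure: the probability that any particular sample is never drawn in $n$ draws with replacement is $(1 - 1/n)^n \approx 1/e \approx 0.37$, so roughly 63% of the original samples appear in each bootstrap. A tiny simulation (mine, not from the post) confirms it:

```python
import numpy as np

rng = np.random.default_rng(0)
n = 100_000
indices = rng.choice(n, size=n, replace=True)   # one bootstrap sample
print(len(np.unique(indices)) / n)              # ~0.632, i.e. 1 - 1/e
```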
This creates diversity. Each tree sees a slightly different view of the data, so the trees make different mistakes. When we average their predictions, many of those mistakes cancel out.
Bagging helps, but we can do better. The second key ingredient is feature randomness. At each split, instead of considering all features, we randomly select a subset.
```python
feature_indices = np.random.choice(n_features, self.max_features, replace=False)
best_feature, best_threshold = self._best_split(X, y, feature_indices)
```
| max_features | Features per split | Use Case |
|---|---|---|
| sqrt | √n features | Default for classification |
| log2 | log₂(n) features | More randomness |
| Integer | Exactly that many | Custom control |
| None | All features | Just bagging, no feature randomness |
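The snippets don't show how the string settings get turned into an actual number of features; a plausible helper (entirely my assumption about how the implementation might do it) could look like this:

```python
import numpy as np

def resolve_max_features(max_features, n_features):
    """Map a max_features setting to how many features each split considers.
    Illustrative helper; the post's classes may handle this differently."""
    if max_features is None:
        return n_features                          # all features: plain bagging
    if max_features == "sqrt":
        return max(1, int(np.sqrt(n_features)))
    if max_features == "log2":
        return max(1, int(np.log2(n_features)))
    return int(max_features)                       # explicit integer count

resolve_max_features("sqrt", 16)   # == 4
resolve_max_features("log2", 16)   # == 4
resolve_max_features(None, 16)     # == 16
```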
Why Feature Randomness Helps
If one feature is very predictive, every tree would tend to split on it near the root, and the trees would end up highly correlated, making the same mistakes. Restricting each split to a random subset of features decorrelates the trees, which is what makes averaging them effective.
Training the Forest
Training a random forest is embarrassingly parallel—each tree is independent:
```python
def fit(self, X, y, verbose=False):
    X = np.array(X)
    y = np.array(y).flatten()

    self.n_classes_ = len(np.unique(y))
    self.n_features_ = X.shape[1]
    self.trees = []

    if self.random_state is not None:
        np.random.seed(self.random_state)

    for i in range(self.n_estimators):
        # Each tree gets a different random seed
        tree_seed = None if self.random_state is None else self.random_state + i
        tree = DecisionTreeClassifier(
            max_depth=self.max_depth,
            min_samples_split=self.min_samples_split,
            min_samples_leaf=self.min_samples_leaf,
            max_features=self.max_features,
            criterion=self.criterion,
            random_state=tree_seed
        )

        # Bootstrap sample
        if self.bootstrap:
            indices = np.random.choice(len(X), len(X), replace=True)
            X_sample = X[indices]
            y_sample = y[indices]
        else:
            X_sample = X
            y_sample = y

        tree.fit(X_sample, y_sample)
        self.trees.append(tree)
```
We build n_estimators trees (typically 100-500), each on a bootstrap sample, each with feature randomness.
Making Predictions Through Majority Voting
For classification, prediction is simple majority voting:
```python
from collections import Counter

def predict(self, X):
    X = np.array(X)

    tree_predictions = np.array([tree.predict(X) for tree in self.trees])

    # For each sample, find the most common prediction
    predictions = []
    for i in range(X.shape[0]):
        votes = tree_predictions[:, i]
        most_common = Counter(votes).most_common(1)[0][0]
        predictions.append(most_common)

    return np.array(predictions)
```
If you have 100 trees and 60 predict class 1 while 40 predict class 0, the forest predicts class 1. You can of course adjust that decision threshold if your application calls for it. We could also return probability estimates instead of a single class (here, 0.6 for class 1).
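The post mentions probability estimates but doesn't show that code. Written as a standalone function over the per-tree votes (the name and shape conventions are my own), it could look like:

```python
import numpy as np

def vote_probabilities(tree_predictions, n_classes):
    """Convert per-tree class votes into per-class vote fractions.

    tree_predictions: int array of shape (n_trees, n_samples).
    Returns an array of shape (n_samples, n_classes).
    """
    n_trees, n_samples = tree_predictions.shape
    proba = np.zeros((n_samples, n_classes))
    for votes in tree_predictions:                 # one tree's predictions
        proba[np.arange(n_samples), votes] += 1
    return proba / n_trees

votes = np.array([[1, 0, 1],
                  [1, 1, 0],
                  [0, 1, 1],
                  [1, 0, 1]])                      # 4 trees, 3 samples
vote_probabilities(votes, n_classes=2)
# array([[0.25, 0.75],
#        [0.5 , 0.5 ],
#        [0.25, 0.75]])
```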
Feature Importance
One downside of moving from one tree to a forest is losing interpretability. We can't trace a single path anymore. But we can measure which features matter most:
```python
def _calculate_feature_importances(self):
    importances = np.zeros(self.n_features_)

    for tree in self.trees:
        self._traverse_for_importance(tree.root, importances)

    if np.sum(importances) > 0:
        importances = importances / np.sum(importances)

    self.feature_importances_ = importances


def _traverse_for_importance(self, node, importances):
    if node is None or node.value is not None:
        return

    importances[node.feature] += 1
    self._traverse_for_importance(node.left, importances)
    self._traverse_for_importance(node.right, importances)
```
Our simple implementation counts how often each feature is used for splits across all trees. More sophisticated versions weight by the information gain of each split. This gives you a sense of which features the forest relies on most, even if you can't trace individual predictions.
Beyond the Basics
We've built decision trees and random forests from scratch, but there are plenty of ways to extend and optimize them:
Scaling not required
Unlike linear models, trees don't care about feature scales. A feature ranging 0-1 and one ranging 0-1000 are treated equally.
Feature interactions
Trees naturally capture feature interactions. The split "if income < 50k AND age > 30" happens automatically through tree structure.
Parallel training
Trees are independent and can be trained in parallel. Real implementations (like sklearn) do this for massive speedups.
Memory usage
Random forests use a lot of memory (storing many trees). For very large datasets, this can be a constraint.
Prediction speed
Forests are slower to predict than single trees (need to traverse many trees). For latency-critical applications, this matters.
Gradient Boosted Trees
Instead of training trees independently, train them sequentially where each tree corrects the previous trees' mistakes. XGBoost, LightGBM, and CatBoost are incredibly powerful implementations of this idea.
Post-Pruning
Instead of stopping growth early (pre-pruning), grow a full tree and then prune branches that don't improve validation performance. This can find better trees than pre-pruning alone.
Class Weights
Handle imbalanced datasets by giving more weight to minority class errors. When you have 95% negative examples, weighting minority class errors higher forces the model to care about getting those rare cases right.
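One common way to implement this is to scale the class counts by their weights inside the impurity calculation, so splits that misplace minority samples look worse. A small sketch of the idea (not part of the implementation above):

```python
import numpy as np

def weighted_gini(y, class_weight):
    """Gini index with per-class weights applied to the class proportions.
    Illustrative sketch of class weighting, not the post's code."""
    classes, counts = np.unique(y, return_counts=True)
    w = np.array([class_weight.get(c, 1.0) for c in classes])
    p = (w * counts) / np.sum(w * counts)
    return 1.0 - np.sum(p ** 2)

y = np.array([0] * 95 + [1] * 5)
weighted_gini(y, {0: 1.0, 1: 1.0})    # ~0.095: the node already looks nearly pure
weighted_gini(y, {0: 1.0, 1: 10.0})   # ~0.45: upweighting the rare class makes it look impure
```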
Happy coding, and don't get lost in the forest.