Binary search tree

1. Problem statement

Implement a binary search tree (BST) (Sedgewick & Wayne, 2011, p. 396).

2. Insights

2.1. Sorted binary trees

Binary search trees are basically binary trees with one major difference: instead of manually controlling exactly where nodes should go in the structure of the tree (during insertion), we give up this control in exchange for predictable (and typically very fast) performance.

How do we give up control? We assign a key to every value we want to store in the BST, and let the BST decide where to store the key/value pair. The key's type must be comparable (any two keys can be ordered as less than, equal to, or greater than each other). Whenever we want to add a key (and its associated value) to the BST, we start at the root node and compare the new key with the key stored there. If the new key is smaller, we go down the left subtree. If it is larger, we go down the right subtree. If it is equal, we replace the value stored at that node with the new value and stop. Otherwise we repeat the comparison at each node we visit until we step past a leaf into an empty child position (no existing key is left to compare against), and the new node is attached there.

If you perform an in-order traversal of a BST, you get back all the keys in sorted order.
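The insertion rule and the sorted in-order property described above can be sketched with a tiny stand-alone example. The names `MiniNode`, `put`, and `inorder_keys` are illustrative assumptions for this sketch only; they are not the classes developed in the Solution section.

```python
from typing import Any, List, Optional


class MiniNode:
    """Illustrative node (an assumption for this sketch, not the real Node)."""

    def __init__(self, key: int, val: Any = None):
        self.key = key
        self.val = val
        self.left: Optional["MiniNode"] = None
        self.right: Optional["MiniNode"] = None


def put(x: Optional[MiniNode], key: int, val: Any = None) -> MiniNode:
    # We stepped into an empty position: the new pair belongs here.
    if x is None:
        return MiniNode(key, val)
    if key < x.key:
        x.left = put(x.left, key, val)
    elif key > x.key:
        x.right = put(x.right, key, val)
    else:
        x.val = val  # Same key: overwrite the stored value.
    return x


def inorder_keys(x: Optional[MiniNode]) -> List[int]:
    # Left subtree first, then this node, then the right subtree.
    if x is None:
        return []
    return inorder_keys(x.left) + [x.key] + inorder_keys(x.right)


root = None
for k in [3, 4, 1, 5, 2]:
    root = put(root, k, f"val={k}")

print(inorder_keys(root))  # [1, 2, 3, 4, 5]
```

The keys come back sorted even though they were inserted out of order, because every node's left subtree holds only smaller keys and its right subtree only larger ones.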

In a large BST with many nodes, each time we go down one level we eliminate about half of the remaining search space, provided the tree is roughly balanced. As a result, we need only about \(\log_2N\) comparisons to find a key in a BST with \(N\) nodes. This is a powerful property, and it mirrors how binary search works on sorted arrays. Indeed, the BST is always sorted: the insertion procedure described above ensures that the BST maintains its sorted nature.
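As a rough worked example of where the logarithm comes from, assume an idealized tree in which every level is completely full (real BSTs only approximate this). A tree with \(h\) levels then holds:

```latex
\[
N = 1 + 2 + 4 + \dots + 2^{h-1} = 2^{h} - 1
\quad\Longrightarrow\quad
h = \log_2(N + 1)
\]
\[
N = 1023 \Rightarrow h = 10,
\qquad
N = 1{,}048{,}575 \Rightarrow h = 20
\]
```

So even with about a million keys, a lookup in such a tree visits at most 20 nodes.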

2.2. Insertion order determines structure

Let's consider the keys [3, 4, 1, 5, 2]. The shape of the final BST will depend on the insertion order.

t = BinarySearchTree()
t.insert(3) # Empty tree, so 3 becomes root node.
t.insert(4) # 4 is greater than 3, so insert to right child of 3.
t.insert(1) # 1 becomes 3's left child.
t.insert(5) # 5 becomes 4's right child.
t.insert(2) # 2 becomes 1's right child.

Pictorially it will look like this:

binary_search_tree_random_order_insertion.svg

If we insert them in sorted order, we only end up using the right child at each level of the tree:

t = BinarySearchTree()
t.insert(1) # Empty tree, so 1 becomes root node.
t.insert(2) # 2 is greater than 1, so insert to right child of 1.
t.insert(3) # 3 becomes 2's right child (3 is greater than both 1 and 2).
t.insert(4) # 4 becomes 3's right child (4 is greater than 1, 2, and 3).
t.insert(5) # 5 becomes 4's right child (5 is greater than 1, 2, 3, and 4).

binary_search_tree_sorted_order_insertion.svg

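This means that, if we want to take advantage of the \(\log_2N\) lookup speed, we have to make sure that the keys we insert into the BST are not in sorted order (ascending or descending). In the degenerate tree above, finding a key can take up to \(N\) comparisons, the same as scanning a linked list.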
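The effect of insertion order on the shape of the tree can be measured directly. The following is a minimal sketch under stated assumptions: `MiniNode`, `put`, `height`, and `build` are hypothetical helpers written for this illustration, and insertion is done with a loop (rather than recursion) so that the deep, list-like tree does not depend on Python's recursion limit.

```python
import math
import random
from typing import Iterable, Optional


class MiniNode:
    """Illustrative node (an assumption for this sketch, not the real Node)."""

    def __init__(self, key: int):
        self.key = key
        self.left: Optional["MiniNode"] = None
        self.right: Optional["MiniNode"] = None


def put(root: Optional[MiniNode], key: int) -> MiniNode:
    # Iterative insertion: walk down until an empty position is found.
    if root is None:
        return MiniNode(key)
    x = root
    while True:
        if key < x.key:
            if x.left is None:
                x.left = MiniNode(key)
                return root
            x = x.left
        elif key > x.key:
            if x.right is None:
                x.right = MiniNode(key)
                return root
            x = x.right
        else:
            return root  # Duplicate key: nothing to add.


def height(root: Optional[MiniNode]) -> int:
    # Count levels by walking the tree one depth at a time.
    levels = 0
    current = [root] if root is not None else []
    while current:
        levels += 1
        current = [child
                   for node in current
                   for child in (node.left, node.right)
                   if child is not None]
    return levels


def build(keys: Iterable[int]) -> Optional[MiniNode]:
    root = None
    for k in keys:
        root = put(root, k)
    return root


n = 500
sorted_keys = list(range(n))
shuffled_keys = sorted_keys[:]
random.Random(0).shuffle(shuffled_keys)  # Fixed seed for repeatability.

print("sorted insertion height:  ", height(build(sorted_keys)))   # 500
print("shuffled insertion height:", height(build(shuffled_keys)))
print("best possible height:     ", math.ceil(math.log2(n + 1)))  # 9
```

With sorted input every node hangs off the right link of the previous one, so the height equals the number of keys. The shuffled input lands somewhere between the best possible height and that worst case.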

If the input is known to be random, then this pathological case can be ignored. If the input can be non-random (and we don't want to bother with changing how the keys are fed into the BST), then self-balancing binary search trees may be more appropriate. They are resistant to sorted input because they move nodes around (if necessary) during each insertion, which ensures that we don't end up with a structure resembling a linked list.

3. Solution

We use two classes: one for the nodes (Node) and another for the overall API around binary search trees (BinarySearchTree). Splitting things up this way makes the (recursive) insertion method easy.

The algorithms presented here are taken largely from Algorithms (Sedgewick & Wayne, 2011, p. 396).

from collections import deque

class Node:
    node_class_methods

class BinarySearchTree:
    binary_search_tree_class_methods

3.1. Initialization

3.1.1. Node initialization

The constructor purposely does not accept the left and right subtrees as parameters. This way, we force users to go through our API, which guarantees that the tree is constructed correctly.

def __init__(self, key: int, val: Any=None):
    self.key = key
    self.val = val

Note that earlier we said that the key is comparable and is associated with a value. Like a hashmap, our BST cannot hold duplicate keys: when we store a new value under a key that already exists, we overwrite the existing value.

We set the left and right links to None. These point to other Node instances. Later when we add more nodes into this tree, these links will become populated.

    self.left: Optional[Node] = None
    self.right: Optional[Node] = None

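Set the count to just 1. The count is the number of nodes in the subtree rooted at this node (including the node itself), and we update it whenever nodes are inserted into or deleted from that subtree.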

    self.count = 1

3.1.2. BST initialization

A new BST is empty, so it doesn't require much. The only (optional) input is a root node, in case the caller has already constructed one.

def __init__(self, root: Optional[Node]=None):
    self.root = root

3.2. Insertion

Insertion requires us to compare the given key with what we already have in the tree. If the key does not exist, we have to add it to the correct spot. If it exists, we overwrite the existing entry.

The "correct spot" is to put the key/value into the left subtree if it's smaller than the current (root) key. It goes into the right subtree if greater. If the key is the same as the current key, replace it. This selection process is done recursively as many times as necessary until we end up replacing an existing node or adding a new one (left or right subtree is empty).

For simplicity, we use integers for the keys.

def insert(self, key: int, val: Any=None):
    self.root = self._insert(self.root, key, val)

def _insert(self, x: Optional[Node], key: int, val: Any=None):
    if x is None:
        return Node(key, val)
    if key < x.key:
        x.left = self._insert(x.left, key, val)
    elif key > x.key:
        x.right = self._insert(x.right, key, val)
    else:
        x.val = val
    x.count = self._size(x.left) + self._size(x.right) + 1
    return x

Note that we have two methods, insert() and _insert(). The second (private) one is recursive and calls itself as many times as necessary to check for the existence of the given key. The first (public) one simply starts off this recursive search using the current root node of the tree.

3.2.1. Convenient integer-only insertion

While our default insert() implementation above accounts for having distinct keys and values, for the purpose of demonstrating how the various BST algorithms behave we don't really care what the values look like (everything is based on the keys). To that end, we define an insert_int() method that takes any number of integer keys and inserts them in the order given; it sets each value to a string representation of the key with a val= prefix, to drive home the point that the values are distinct from the keys.

def insert_int(self, *args: int):
    # Insert each key in the order given, with a derived "val=<key>" value.
    for arg in args:
        self.insert(arg, f"val={arg}")

3.3. Size

Just like we did for binary trees, getting the size is just a field lookup. During insertion and deletion we make sure that the count field is updated.

def size(self):
    return self._size(self.root)

def _size(self, x: Optional[Node]):
    if x is None:
        return 0
    return x.count
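Because the size is read from a cached field, it is only correct if every update keeps the field in sync. One way to gain confidence in that is to recount the nodes the slow way and compare. The following is a minimal sketch; `CountedNode`, `put`, and `recount` are hypothetical names used only for this illustration.

```python
from typing import Optional


class CountedNode:
    """Illustrative stand-in for Node, keeping only the fields needed here."""

    def __init__(self, key: int):
        self.key = key
        self.left: Optional["CountedNode"] = None
        self.right: Optional["CountedNode"] = None
        self.count = 1


def size(x: Optional[CountedNode]) -> int:
    # Constant-time size: read the cached field (0 for an empty subtree).
    return 0 if x is None else x.count


def put(x: Optional[CountedNode], key: int) -> CountedNode:
    if x is None:
        return CountedNode(key)
    if key < x.key:
        x.left = put(x.left, key)
    elif key > x.key:
        x.right = put(x.right, key)
    # Refresh the cached size on the way back up the recursion.
    x.count = size(x.left) + size(x.right) + 1
    return x


def recount(x: Optional[CountedNode]) -> int:
    # Visit every node and check its cached count against a fresh tally.
    if x is None:
        return 0
    actual = recount(x.left) + recount(x.right) + 1
    assert x.count == actual, f"stale count at key {x.key}"
    return actual


root = None
for k in [5, 1, 9, 4, 7, 2, 10, 0, 4]:  # The second 4 is a duplicate.
    root = put(root, k)

print(size(root), recount(root))  # 8 8
```

The cached lookup and the full recount agree, and the duplicate key does not inflate the size.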

3.4. Lookup

Lookup is almost identical to insertion — we recursively check for the existence of the given key. And just like for insertion, we have to start off the recursive search with the root of the tree.

def lookup(self, key: int):
    return self._lookup(self.root, key)

def _lookup(self, x: Optional[Node], key: int):
    if x is None:
        return None
    if key < x.key:
        return self._lookup(x.left, key)
    elif key > x.key:
        return self._lookup(x.right, key)
    else:
        return x.val
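Since each recursive call in the lookup is the last thing the function does, the same search can also be written as a loop. The following is a minimal sketch of that alternative, not a replacement for the method above; `MiniNode` and `lookup_iterative` are hypothetical names, and the tree is built by hand to keep the example short.

```python
from dataclasses import dataclass
from typing import Any, Optional


@dataclass
class MiniNode:
    """Illustrative node (an assumption for this sketch, not the real Node)."""
    key: int
    val: Any = None
    left: Optional["MiniNode"] = None
    right: Optional["MiniNode"] = None


def lookup_iterative(x: Optional[MiniNode], key: int) -> Any:
    # Same comparisons as the recursive version, but walking down with a loop.
    while x is not None:
        if key < x.key:
            x = x.left
        elif key > x.key:
            x = x.right
        else:
            return x.val
    return None  # Ran off the tree: the key is absent.


# Hand-built tree for the keys [3, 4, 1, 5, 2] inserted in that order.
root = MiniNode(3, "val=3",
                left=MiniNode(1, "val=1", right=MiniNode(2, "val=2")),
                right=MiniNode(4, "val=4", right=MiniNode(5, "val=5")))

print(lookup_iterative(root, 2))   # val=2
print(lookup_iterative(root, 42))  # None
```

Note that, in both versions, a result of None cannot distinguish a missing key from a key whose stored value is None.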

3.5. Deletion

Deletion is tricky, for the same reasons that we had for binary trees.

When we want to delete a node, we have to first search for the node we want to delete. Then once we find the node we want to delete, we again have 3 cases:

  1. the node to delete has 0 children (leaf node),
  2. the node to delete has 1 child, or
  3. the node to delete has 2 children.

If the node we want to delete is a leaf node, there's no extra work. If it has a single child, we just make that child the replacement. If it has 2 children, we have to pick one of its descendants to be the replacement.

The "successor" s of a key k in a BST is the smallest key that is still greater than k. In other words, if we were to print all the keys in order, s would immediately follow k.

Similarly, the "predecessor" p is the largest key that is still smaller than k, and would be printed just before k.
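These two definitions can be illustrated on the sorted list of keys that an in-order traversal would produce. This is a minimal sketch using the standard `bisect` module; the function names `successor` and `predecessor` are assumptions for this example and do not exist in the BST class.

```python
from bisect import bisect_left, bisect_right
from typing import List, Optional


def successor(sorted_keys: List[int], k: int) -> Optional[int]:
    # Index of the first key strictly greater than k.
    i = bisect_right(sorted_keys, k)
    return sorted_keys[i] if i < len(sorted_keys) else None


def predecessor(sorted_keys: List[int], k: int) -> Optional[int]:
    # Index of the first key greater than or equal to k; the key just
    # before that index is the largest key strictly smaller than k.
    i = bisect_left(sorted_keys, k)
    return sorted_keys[i - 1] if i > 0 else None


keys = [1, 2, 3, 4, 5]
print(successor(keys, 3), predecessor(keys, 3))  # 4 2
print(successor(keys, 5), predecessor(keys, 1))  # None None
```

The largest key has no successor and the smallest key has no predecessor, which is why both helpers can return None.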

For plain binary trees, we decided to choose the leftmost leaf node as the replacement. For BSTs, we choose the node with the smallest key in the right subtree. This node holds the "successor", because its key is guaranteed to be the next-largest key in the BST after the key of the to-be-deleted node. Due to the way BSTs are already ordered, choosing this node as the replacement preserves the order in the BST (such that after the deletion, the tree remains a proper BST). This is because the smallest key in the right subtree is larger than every key in the left subtree and smaller than every other key in the right subtree.
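To see why the choice of replacement matters, the following sketch checks the BST ordering on two hand-built trees. Both represent the tree built from the keys [5, 1, 9, 4, 7, 2, 10, 0] after removing the root 5: one uses the successor (7) as the new root, the other uses the leftmost leaf (0), as a plain binary tree might. `MiniNode` and `is_bst` are hypothetical names for this illustration.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class MiniNode:
    """Illustrative node (an assumption for this sketch, not the real Node)."""
    key: int
    left: Optional["MiniNode"] = None
    right: Optional["MiniNode"] = None


def is_bst(x: Optional[MiniNode],
           lo: Optional[int] = None,
           hi: Optional[int] = None) -> bool:
    # Every key must lie strictly between the bounds inherited from its
    # ancestors: greater than lo and smaller than hi (when they are set).
    if x is None:
        return True
    if lo is not None and x.key <= lo:
        return False
    if hi is not None and x.key >= hi:
        return False
    return is_bst(x.left, lo, x.key) and is_bst(x.right, x.key, hi)


# Successor (7) promoted to the root; it is removed from the right subtree.
with_successor = MiniNode(
    7,
    left=MiniNode(1, left=MiniNode(0), right=MiniNode(4, left=MiniNode(2))),
    right=MiniNode(9, right=MiniNode(10)))

# Leftmost leaf (0) promoted to the root instead.
with_leftmost_leaf = MiniNode(
    0,
    left=MiniNode(1, right=MiniNode(4, left=MiniNode(2))),
    right=MiniNode(9, left=MiniNode(7), right=MiniNode(10)))

print(is_bst(with_successor))      # True
print(is_bst(with_leftmost_leaf))  # False
```

The second tree fails because the new root 0 is smaller than the keys in its own left subtree, so later lookups for those keys would search in the wrong direction.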

When we are deleting a node with 2 children, Sedgewick & Wayne (2011, p. 410) advise choosing randomly between the successor and predecessor as the replacement (it's more "balanced" and thus leads to better tree structure over time). For our purposes we just stick to choosing the successor because it is predictable, which makes unit testing easier.

def delete(self, key: int):
    self.root = self._delete(self.root, key)

def _delete(self, x: Optional[Node], key: int):
    # If there is no node to delete (because we could not find the node we
    # wanted to delete), then deletion is a NOP.
    if x is None:
        return None
    # If the key doesn't match for the current node, keep searching.
    if key < x.key:
        x.left = self._delete(x.left, key)
    # Ditto.
    elif key > x.key:
        x.right = self._delete(x.right, key)
    # Found the node we want to delete.
    else:
        # If either the left or right child is None, return the other child. If
        # the other child is None, then this means that this node had no
        # children. If the other child is not None, this means this node only
        # had 1 child (which means we're done).
        if x.left is None:
            return x.right
        if x.right is None:
            return x.left

        # We have two children. We need to pick a replacement (smallest key
        # greater than the current key).

        # Save a link to the node to be deleted, because we want links to the
        # original children of x.
        y = x

        # Get the successor node with the lowest key greater than x.key.
        x = self._min(x.right)

        # We can't just do "x.right = y.right" because y.right contains x
        # (self._min(x.right) is a read operation). So delete the successor out
        # of the right subtree, and assign this pruned subtree to be the
        # successor's right child.
        x.right = self._delete_min(y.right)

        # The original left child of (the now-deleted) x is now the left child
        # of x's successor. We have to do this operation last because otherwise
        # this left child interferes with the algorithm in _delete_min().
        x.left = y.left

    x.count = self._size(x.left) + self._size(x.right) + 1

    return x

def min(self) -> Optional[int]:
    # An empty tree has no minimum key.
    if self.root is None:
        return None
    # _min() always returns a node when given a non-empty tree.
    return self._min(self.root).key

# Find the node with the smallest key in the given tree, rooted at node x.
def _min(self, x: Node) -> Node:
    # Keep following left links; the node with no left child is the minimum.
    while x.left is not None:
        x = x.left
    return x

def delete_min(self):
    self.root = self._delete_min(self.root)

# Return a tree rooted at node x, but with the node containing the smallest key
# in it removed from this tree.
def _delete_min(self, x: Optional[Node]) -> Optional[Node]:
    if x is None:
        return None
    if x.left is None:
        return x.right
    x.left = self._delete_min(x.left)
    x.count = self._size(x.left) + self._size(x.right) + 1
    return x

3.6. Traversal

The code here is identical to the code for binary trees. See the discussion around traversal there.

def traverse_preorder(self, func: Callable[[Node], None]):
    return self._traverse_preorder(self.root, func)

def _traverse_preorder(self, x: Optional[Node], func: Callable[[Node], None]):
    if x is None:
        return
    func(x)
    self._traverse_preorder(x.left, func)
    self._traverse_preorder(x.right, func)

def traverse_inorder(self, func: Callable[[Node], None]):
    return self._traverse_inorder(self.root, func)

def _traverse_inorder(self, x: Optional[Node], func: Callable[[Node], None]):
    if x is None:
        return
    self._traverse_inorder(x.left, func)
    func(x)
    self._traverse_inorder(x.right, func)

def traverse_postorder(self, func: Callable[[Node], None]):
    return self._traverse_postorder(self.root, func)

def _traverse_postorder(self, x: Optional[Node], func: Callable[[Node], None]):
    if x is None:
        return
    self._traverse_postorder(x.left, func)
    self._traverse_postorder(x.right, func)
    func(x)

3.6.1. BFS

def bfs(self, func: Callable[[Node], None]):
    return self._bfs(self.root, func)

def _bfs(self, x: Optional[Node], func: Callable[[Node], None]):
    if x is None:
        return

    nodes_at_current_depth = [x]
    while nodes_at_current_depth:
        # Process all nodes at current depth.
        for node in nodes_at_current_depth:
            func(node)

        # Now add all nodes at the next depth.
        children = []
        for node in nodes_at_current_depth:
            if not node:
                continue
            if node.left:
                children.append(node.left)
            if node.right:
                children.append(node.right)

        # Repeat the loop at the next depth.
        nodes_at_current_depth = children

def bfs_single_pass(self, func: Callable[[Node], None]):
    return self._bfs_single_pass(self.root, func)

def _bfs_single_pass(self, x: Optional[Node], func: Callable[[Node], None]):
    if x is None:
        return

    q = deque([x])
    while q:
        # Process the head of the queue. As we process each one, just before we
        # discard it we check if it has children, and if so, add them to the end
        # of the queue.

        node = q.popleft()
        if not node:
            continue

        func(node)

        # Add this node's children, if any.
        if node.left:
            q.append(node.left)
        if node.right:
            q.append(node.right)

4. Tests

from __future__ import annotations
from hypothesis import given, strategies as st
import unittest

from .binary_search_tree import BinarySearchTree, Node

class Test(unittest.TestCase):
    test_cases

if __name__ == "__main__":
    unittest.main(exit=False)

4.1. Initialization

def test_init_empty(self):
    t = BinarySearchTree()
    self.assertEqual(0, t.size())

def test_init_nonempty(self):
    root = Node(50, "a")
    t = BinarySearchTree(root)
    self.assertEqual(1, t.size())

4.2. Insertion

def test_insertion(self):
    t = BinarySearchTree()
    self.assertEqual(0, t.size())
    t.insert(50, "foo")
    self.assertEqual(1, t.size())
    # Inserting the same key (different value) results in the same size, because
    # the old node with the same key is replaced.
    t.insert(50, "bar")
    self.assertEqual(1, t.size())
    # Inserting a different key grows the tree by 1 node.
    t.insert_int(51)
    self.assertEqual(2, t.size())
    t.insert_int(52)
    self.assertEqual(3, t.size())

4.3. Lookup

def test_lookup(self):
    t = BinarySearchTree()
    self.assertEqual(t.lookup(5), None)
    t.insert_int(100)
    self.assertEqual(t.lookup(100), "val=100")
    # Manual insertion (bypassing insert_int()) allows us to set the value
    # directly.
    t.insert(100, "b")
    self.assertEqual(t.lookup(100), "b")

4.4. Deletion

def test_deletion(self):
    # Deletion on an empty tree is a NOP.
    t = BinarySearchTree()
    self.assertEqual(t.size(), 0)
    t.delete(5)
    self.assertEqual(t.size(), 0)

    # Delete the root node (no child).
    t = BinarySearchTree()
    t.insert_int(3)
    t.delete(3)
    self.assertEqual(t.size(), 0)

    # Delete the root node (its only child is on the left).
    t = BinarySearchTree()
    t.insert_int(3, 1)
    self.assertEqual(t.size(), 2)
    t.delete(3)
    self.assertEqual(t.size(), 1)
    self.assertEqual(t.root.val, "val=1")

    # Delete the root node (its only child is on the right).
    t = BinarySearchTree()
    t.insert_int(3, 4)
    self.assertEqual(t.size(), 2)
    t.delete(3)
    self.assertEqual(t.size(), 1)
    self.assertEqual(t.root.val, "val=4")

    # Delete the root node (has 2 children; the replacement is the smallest
    # key in the right subtree, which here is the right child itself because it
    # has no left child).
    t = BinarySearchTree()
    t.insert_int(3, 4, 1, 5, 2)
    self.assertEqual(t.size(), 5)
    self.assertEqual(t.root.val, "val=3")
    self.assertEqual(t.root.left.val, "val=1")
    self.assertEqual(t.root.left.right.val, "val=2")
    self.assertEqual(t.root.right.val, "val=4")
    self.assertEqual(t.root.right.right.val, "val=5")
    t.delete(3)
    self.assertEqual(t.size(), 4)
    self.assertEqual(t.root.val, "val=4")
    self.assertEqual(t.root.left.val, "val=1")
    self.assertEqual(t.root.left.right.val, "val=2")
    self.assertEqual(t.root.right.val, "val=5")

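    # Delete a non-root node with 2 children. Its successor (6) is its right
    # child, which has no left child.
    t = BinarySearchTree()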
    t.insert_int(3, 5, 1, 4, 2, 6)
    self.assertEqual(t.size(), 6)
    self.assertEqual(t.root.val, "val=3")
    self.assertEqual(t.root.left.val, "val=1")
    self.assertEqual(t.root.left.right.val, "val=2")
    self.assertEqual(t.root.right.val, "val=5")
    self.assertEqual(t.root.right.left.val, "val=4")
    self.assertEqual(t.root.right.right.val, "val=6")
    t.delete(5)
    self.assertEqual(t.size(), 5)
    self.assertEqual(t.root.val, "val=3")
    self.assertEqual(t.root.left.val, "val=1")
    self.assertEqual(t.root.left.left, None)
    self.assertEqual(t.root.left.right.val, "val=2")
    self.assertEqual(t.root.right.val, "val=6")
    self.assertEqual(t.root.right.left.val, "val=4")
    self.assertEqual(t.root.right.right, None)

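    # Delete a non-root node with 2 children. Its successor (6) sits deeper in
    # the right subtree (as the left child of 7).
    t = BinarySearchTree()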
    t.insert_int(2, 5, 7, 4, 3, 9, 1, 6)
    self.assertEqual(t.size(), 8)
    self.assertEqual(t.root.val, "val=2")
    self.assertEqual(t.root.left.val, "val=1")
    self.assertEqual(t.root.right.val, "val=5")
    self.assertEqual(t.root.right.left.val, "val=4")
    self.assertEqual(t.root.right.left.left.val, "val=3")
    self.assertEqual(t.root.right.right.val, "val=7")
    self.assertEqual(t.root.right.right.left.val, "val=6")
    self.assertEqual(t.root.right.right.right.val, "val=9")
    t.delete(5)
    self.assertEqual(t.size(), 7)
    self.assertEqual(t.root.val, "val=2")
    self.assertEqual(t.root.left.val, "val=1")
    self.assertEqual(t.root.right.val, "val=6")
    self.assertEqual(t.root.right.left.val, "val=4")
    self.assertEqual(t.root.right.left.left.val, "val=3")
    self.assertEqual(t.root.right.right.val, "val=7")
    self.assertEqual(t.root.right.right.right.val, "val=9")

4.4.1. Property based tests

Randomly generate keys to insert, and randomly delete some of them.

@given(st.lists(st.integers(min_value=0, max_value=64),
                min_size=0,
                max_size=64),
       st.lists(st.integers(min_value=0, max_value=32),
                min_size=0,
                max_size=32))
def test_delete_random_at_random(self, keys: list[int], to_delete: list[int]):
    keys_unique = list(dict.fromkeys(keys))
    to_delete_unique = list(dict.fromkeys(to_delete))
    starting_size = len(keys_unique)
    t = BinarySearchTree()
    t.insert_int(*keys)
    self.assertEqual(t.size(), starting_size)

    # Repeatedly delete from a random node in the tree.
    deleted_count = 0
    for delete_me in to_delete_unique:
        size_before = t.size()
        t.delete(delete_me)
        size_after = t.size()
        # Only count the deletion if we actually deleted a node. The presumption
        # here is that t.delete() only deletes a single node if it does delete
        # it (that it doesn't grow the tree, for example).
        if size_before != size_after:
            deleted_count += 1
    self.assertEqual(t.size(), starting_size - deleted_count)

If we delete every node, then the tree should have size 0.

@given(st.lists(st.integers(min_value=0, max_value=32),
                min_size=0,
                max_size=32))
def test_delete_random_tree_drain_all_keys(self, keys: list[int]):
    keys_unique = list(dict.fromkeys(keys))
    t = BinarySearchTree()
    t.insert_int(*keys)
    self.assertEqual(t.size(), len(keys_unique))

    for key in keys:
        t.delete(key)
    self.assertEqual(t.size(), 0)

4.5. Traversal

For these traversals, we construct the following BST (the keys are integers and the values are just string representations of the keys, so these redundant values are omitted from the illustration).

binary_search_tree.svg

def test_traversal(self):
    traversal_history = []
    def record_traversal_history(x: Node):
        traversal_history.append(x.key)

    t = BinarySearchTree()
    t.insert_int(5, 1, 9, 4, 7, 2, 10, 0)

    t.traverse_preorder(record_traversal_history)
    self.assertEqual(traversal_history, [5, 1, 0, 4, 2, 9, 7, 10])

    traversal_history = []
    t.traverse_inorder(record_traversal_history)
    self.assertEqual(traversal_history, [0, 1, 2, 4, 5, 7, 9, 10])

    traversal_history = []
    t.traverse_postorder(record_traversal_history)
    self.assertEqual(traversal_history, [0, 2, 4, 1, 7, 10, 9, 5])

    traversal_history = []
    t.bfs(record_traversal_history)
    self.assertEqual(traversal_history, [5, 1, 9, 0, 4, 7, 10, 2])

    traversal_history = []
    t.bfs_single_pass(record_traversal_history)
    self.assertEqual(traversal_history, [5, 1, 9, 0, 4, 7, 10, 2])

4.5.1. Property based tests

When we do an in-order traversal, we should always get the keys back in sorted order, regardless of which keys were inserted and in what order.

@given(st.lists(st.integers(min_value=0, max_value=32),
                min_size=0,
                max_size=32))
def test_inorder_keys_are_always_sorted(self, keys: list[int]):
    t = BinarySearchTree()
    t.insert_int(*keys)

    traversal_history = []
    def record_traversal_history(x: Node):
        traversal_history.append(x.key)
    t.traverse_inorder(record_traversal_history)
    sorted_traversal_history = list(sorted(traversal_history))
    self.assertEqual(traversal_history, sorted_traversal_history)

5. Export

from __future__ import annotations
from typing import Any, Callable, Optional
code

6. References

Sedgewick, R., & Wayne, K. D. (2011). Algorithms (4th ed.). Addison-Wesley.