Binary trees

1. Problem statement

Implement a binary tree (Aziz et al., 2018, p. 123).

2. Insights

2.1. Similarity to linked lists

A binary tree is eerily similar to a linked list; whereas a node in the linked list can only point to a single "next" node, a node in a binary tree points to two other nodes, called the "left" and "right" children (or subtrees). Indeed, if every node in a binary tree has only a left child (or only a right child), then the structure degenerates into a linked list.

As we will see, even though the two structures are "similar" in the above sense, binary trees have wildly different performance properties and require much more complicated algorithms than linked lists.
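
To make the degenerate case concrete, here is a minimal sketch. The TreeNode class and the left_chain() and walk_left() helpers are hypothetical names used only for this illustration; they are not part of the implementation in the Solution section.

```python
from typing import Any, List, Optional


class TreeNode:
    """Hypothetical minimal node, used only for this illustration."""

    def __init__(self, val: Any):
        self.val = val
        self.left: Optional["TreeNode"] = None
        self.right: Optional["TreeNode"] = None


def left_chain(values: List[Any]) -> Optional[TreeNode]:
    """Build a tree in which every node has only a left child."""
    head: Optional[TreeNode] = None
    # Build from the last value backwards, so that each new node points (via
    # its left link) at the chain built so far.
    for val in reversed(values):
        node = TreeNode(val)
        node.left = head
        head = node
    return head


def walk_left(node: Optional[TreeNode]) -> List[Any]:
    """Follow left links exactly as one would follow "next" in a linked list."""
    seen = []
    while node is not None:
        seen.append(node.val)
        node = node.left
    return seen


chain = left_chain([1, 2, 3])
print(walk_left(chain))  # [1, 2, 3]
```

Because the right links are never used, walking this tree is indistinguishable from walking a singly linked list.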

2.2. Relation to binary search trees (BSTs)

Binary search trees, or BSTs, are binary trees with an additional ordering guarantee: for every node, all values in its left subtree are less than or equal to the node's value, and all values in its right subtree are greater. Whereas BSTs are all about maintaining this ordering guarantee, plain binary trees are concerned more with the actual structure (or position) of the nodes in relation to each other. This latter point is important: a BST can take several different shapes for the same data (depending on the order in which the values are inserted) and any of them is acceptable, whereas with plain binary trees the structure itself matters.
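
As an illustration of the last point, the sketch below inserts the same three values into a BST in two different orders. It assumes a throwaway encoding (a tree is either None or a (left, value, right) tuple), and bst_insert(), build(), and inorder() are hypothetical helpers, not part of the implementation that follows.

```python
from typing import Any, List


def bst_insert(tree: Any, val: int) -> Any:
    """Insert val into a BST encoded as None or (left, value, right)."""
    if tree is None:
        return (None, val, None)
    left, root, right = tree
    if val <= root:
        return (bst_insert(left, val), root, right)
    return (left, root, bst_insert(right, val))


def build(values: List[int]) -> Any:
    """Insert the values one by one, in the given order."""
    tree = None
    for val in values:
        tree = bst_insert(tree, val)
    return tree


def inorder(tree: Any) -> List[int]:
    """Values of the tree in left, node, right order."""
    if tree is None:
        return []
    left, root, right = tree
    return inorder(left) + [root] + inorder(right)


a = build([2, 1, 3])  # 2 at the root, with 1 and 3 as its children
b = build([1, 2, 3])  # a chain of right children: 1, then 2, then 3

print(a == b)                  # False
print(inorder(a), inorder(b))  # [1, 2, 3] [1, 2, 3]
```

Both shapes are valid BSTs for the same data. In a plain binary tree, by contrast, a and b would be two different trees.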

3. Solution

We could define Node as a nested class inside the BinaryTree class, but for simplicity we choose not to. In a "production" implementation, we would have to concern ourselves with API boundaries and encapsulation, but those concerns are beyond the scope of our discussion, so we ignore them here.

We use a two-class scheme. The Node class stores the links to the children (and additional metadata about the node and its children). The BinaryTree class has all of the interesting methods that act on the overall tree, because it's expected that users will never have to deal with the Node class directly.

from collections import deque

class Node:
    node_class_methods
class BinaryTree:
    binary_tree_class_methods

3.1. Initialization

3.1.1. Node initialization

We set the left and right links to None. These point to other Node instances, if any. Later when we add more nodes into this tree, these links will become populated.

We also have the val field to actually store the value we want inside this node, as well as the count field to tell us how large the binary tree is (if we treat the current node as the root).

def __init__(self, val: Any=None):
    self.val = val
    self.left: Optional[Node] = None
    self.right: Optional[Node] = None
    self.count = 1

3.1.2. Tree initialization

For creating the tree, we want to just make sure that there is either a single node, or nothing (an empty tree).

def __init__(self, val: Any=None):
    if val is None:
        self.root = None
    else:
        self.root = Node(val)

3.2. Insertion

For each value we want to insert, take a list of directions. These directions are a list of 0's or 1's and tell us to go either left or right (left is 0, right is 1), starting from the root node. When we run out of directions, insert the value at that navigated-to node.

If there are any nodes missing on the way to the final destination, create nodes along the way. This way, we don't force our callers to only create new nodes exactly at a particular leaf node (they could choose to, for example, create the leaf node first, followed by filling up the parent nodes).

The main workhorse is _insert(), a private method (insert() is meant for external use).

def insert(self, val: Any, directions: List[int]):
    self.root = self._insert(self.root, val, directions)

We navigate down the tree if there are directions. As we do so, we make note of all the nodes we've seen, adding them to hist.

def _insert(self, x: Optional[Node], val: Any, directions: List[int]):
    if x is None:
        x = Node()
    start = x
    hist = deque([(x, False)])
    for direction in directions:
        new_node = False
        if direction:
            if x.right is None:
                x.right = Node()
                new_node = True
            x = x.right
        else:
            if x.left is None:
                x.left = Node()
                new_node = True
            x = x.left
        hist.append((x, new_node))

    # Update counts back up the tree by processing hist from the deepest node
    # to the root. Each node's count grows by the number of nodes newly created
    # strictly below it, so we add delta before counting the node itself (a new
    # node already has a count of 1 by default).
    delta = 0
    while hist:
        node, new_node = hist.pop()
        node.count += delta
        if new_node:
            delta += 1

    # Finally, set the value on the node.
    x.val = val
    return start

Note that it is important that we return the initial node (start) we started from. This way we can assign it back into self.root in insert().
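
The effect of creating missing nodes along the way can be checked with a small self-contained sketch. Note that this is a recursive variant that recomputes each count on the way back up, not the hist-based loop above; SketchNode, size(), and insert() are hypothetical stand-ins written only for this example.

```python
from typing import Any, List, Optional


class SketchNode:
    """Hypothetical stand-in for Node: a value, two links, a subtree count."""

    def __init__(self, val: Any = None):
        self.val = val
        self.left: Optional["SketchNode"] = None
        self.right: Optional["SketchNode"] = None
        self.count = 1


def size(x: Optional[SketchNode]) -> int:
    return 0 if x is None else x.count


def insert(x: Optional[SketchNode], val: Any, directions: List[int]) -> SketchNode:
    """Create missing nodes on the way down; fix up counts on the way up."""
    if x is None:
        x = SketchNode()
    if not directions:
        x.val = val
        return x
    if directions[0]:
        x.right = insert(x.right, val, directions[1:])
    else:
        x.left = insert(x.left, val, directions[1:])
    x.count = size(x.left) + size(x.right) + 1
    return x


# Insert a grandchild into an empty tree. The root and its left child are
# created as placeholder nodes whose val is None.
root = insert(None, "grandchild", [0, 0])
print(root.count, root.left.count, root.left.left.count)  # 3 2 1
print(root.val, root.left.val)  # None None

# Filling in a placeholder later does not change any count.
root = insert(root, "child", [0])
print(root.count, root.left.val)  # 3 child
```

Whichever way the counts are maintained, the invariant to preserve is that every node's count equals 1 plus the counts of its children.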

3.3. Size

The size of the tree is the number of nodes in it. Each node caches the size of the subtree rooted at it (the node itself plus all of its descendants) in its count field. We update the count field every time we insert new nodes, including the counts for all parents. As we shall see later, we also update the count for all relevant nodes whenever we perform a deletion. This is why we don't have to do any extra work when trying to figure out the size of the tree: it's just a direct lookup of the x.count field itself.

def size(self):
    return self._size(self.root)
def _size(self, x: Optional[Node]):
    if x is None:
        return 0
    return x.count

3.4. Lookup

Looking up a node is similar to an insertion operation, because we have to navigate to the node by following the provided directions. However, if there is an error at any point in navigating to the desired node, we raise an exception instead. From the caller's point of view, a value is only returned as a successful lookup if we were able to navigate to the desired node by following the provided directions.

In some sense this is a lot like a pointer dereference operation: we are merely checking that we aren't "dereferencing" (looking up) an invalid "address" (directions to get to the node).

def lookup(self, directions: List[int]):
    return self._lookup(self.root, directions)
def _lookup(self, x: Optional[Node], directions: List[int]):
    if x is None:
        raise ValueError("no root node to start navigation from")
    for direction in directions:
        if x is None:
            raise ValueError("could not navigate to node")
        if direction:
            x = x.right
        else:
            x = x.left
    if x is None:
        raise ValueError("navigated to a None type, not a node")
    return x.val
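
The two failure modes (walking past a missing node, and ending on a missing node) can be seen in isolation with the following sketch. It assumes a throwaway tuple encoding of the tree, None or (left, val, right), and a free function lookup() that mirrors the logic of _lookup(); neither is part of the actual classes.

```python
from typing import Any, List

# Root 1 with left child 2 and right child 3.
TREE = ((None, 2, None), 1, (None, 3, None))


def lookup(tree: Any, directions: List[int]) -> Any:
    """Follow directions (0 = left, 1 = right) and return the value found."""
    for direction in directions:
        if tree is None:
            raise ValueError("could not navigate to node")
        tree = tree[2] if direction else tree[0]
    if tree is None:
        raise ValueError("navigated to a None type, not a node")
    return tree[1]


print(lookup(TREE, []))   # 1
print(lookup(TREE, [1]))  # 3

for bad_path in ([1, 0], [1, 0, 0]):
    try:
        lookup(TREE, bad_path)
    except ValueError as err:
        print(bad_path, "is invalid:", err)
```

The path [1, 0] ends on an empty slot, while [1, 0, 0] tries to continue past one; both are reported as ValueError.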

3.5. Deletion

Deletion is tricky. Not only do we have to update the counts, but we could be moving nodes around in order to preserve the overall structure of the tree. Let's think about preserving the structure first (and see why this would even be a concern).

If we want to delete a node in the tree, we must consider 3 possible cases:

  1. the to-be-deleted node has 0 children (leaf node),
  2. the to-be-deleted node has 1 child, or
  3. the to-be-deleted node has 2 children.

The first case is easy — from its parent's point of view, we just set the corresponding left or right link (pointer) to it to None.

The second case requires us to promote the only child of the node to be deleted to be the replacement. It's just like deleting a node in a linked list (where the "next" node becomes the current node).

The third case is tricky. We have two child nodes (and they may also have children of their own), but we can only assign one child to be the replacement. Obviously we cannot make both children the replacement because then it won't be a binary tree any more.

Here we just do the simplest thing possible: starting from the left child, we descend until we are at a leaf node (following the left link where there is one, and the right link otherwise), and promote this leaf to be the replacement. This way we don't have to affect the structure of the other children that much. We could also descend the right subtree to get the rightmost leaf node, or perhaps even pick a leaf node at random, but choosing the leftmost leaf seems like the simplest thing and is what we do in _delete().

def _delete(self, x: Optional[Node], directions: List[int]) -> Optional[Node]:
    # If the starting node is not a node, we cannot do any navigation on it to
    # find the node we want to delete. Abort.
    if x is None:
        return None

    # Find the node to delete. We recurse down here, but the main point is that
    # only the last call to _delete() will do the actual deletion. All calls to
    # _delete() leading up to the last one just perform navigation and
    # assignment back up the hierarchy.
    if directions:
        if directions[0]:
            x.right = self._delete(x.right, directions[1:])
        else:
            x.left = self._delete(x.left, directions[1:])
        x.count = self._size(x.left) + self._size(x.right) + 1
        return x

    # We've arrived at the node we want to delete. We deal with the 3 possible
    # cases.

    # If there are two children, we have to pick a replacement: the leftmost
    # leaf node of the left subtree. Picking a replacement is a bit tricky
    # because the current parent of the replacement needs to no longer point to
    # it (otherwise we'll end up just duplicating the replacement).
    if x.left is not None and x.right is not None:
        y = x
        # Overwrite x, "deleting" it by replacing it with its leftmost leaf
        # node. But make this leaf node a non-leaf node by making it point to
        # x's children (saved in y).
        x = self._pop_leftmost_leaf_node(x.left)
        # Only assign the previous left child as the replacement's left child if
        # the replacement is not itself the left child. I.e., avoid a possible
        # cycle.
        if x is not y.left:
            x.left = y.left
        x.right = y.right
        # We have 1 less node than before. We can't use x.count because x is the
        # leaf node (and its count is 1).
        x.count = y.count - 1
    # If there is just one child, then make that one the replacement.
    elif x.left is not None:
        x = x.left
    elif x.right is not None:
        x = x.right
    # No children. Easy case: return None so that the parent's link to this
    # node (or self.root, if this node is the root) is cleared.
    else:
        return None

    return x

Now it may very well be the case that getting the leftmost leaf node results in getting the node that we started searching with. That is, if the node to delete has 2 children and we start the "leftmost leaf node" search on the left child, it may be that that child is already a leaf node. In that case the while loop would break immediately and we would return the argument we got (x) as is.

def _pop_leftmost_leaf_node(self, x: Node) -> Node:
    hist = []
    went_left = False
    while True:
        if x.left is not None:
            hist.append(x)
            x = x.left
            went_left = True
            continue
        if x.right is not None:
            hist.append(x)
            x = x.right
            went_left = False
            continue
        break

    # If we had any children (descended down the tree), we have to modify it so
    # that the parent of the leaf node is no longer pointing to this leaf node.
    # Otherwise, the leaf node would still stay in the tree and would not be
    # "popped" out of it.
    if hist:
        if went_left:
            hist[-1].left = None
        else:
            hist[-1].right = None
    for h in hist:
        h.count -= 1

    return x
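
The three deletion cases can also be sketched on a small immutable encoding, which makes the before and after shapes easy to compare. This is an illustration under stated assumptions, not the method used above: a tree is None or a (left, val, right) tuple, the hypothetical helpers pop_leftmost_leaf() and delete_root() return new tuples instead of relinking nodes, and counts are left out.

```python
from typing import Any, Tuple


def pop_leftmost_leaf(tree: Any) -> Tuple[Any, Any]:
    """Return (leaf value, tree without that leaf).

    Descend left where possible, otherwise right, until a leaf is reached.
    """
    left, val, right = tree
    if left is not None:
        leaf, rest = pop_leftmost_leaf(left)
        return leaf, (rest, val, right)
    if right is not None:
        leaf, rest = pop_leftmost_leaf(right)
        return leaf, (left, val, rest)
    return val, None


def delete_root(tree: Any) -> Any:
    """Delete the root of a non-empty tree, handling all three cases."""
    left, _, right = tree
    if left is not None and right is not None:
        # Case 3: promote the leftmost leaf found under the left child.
        leaf, rest = pop_leftmost_leaf(left)
        return (rest, leaf, right)
    # Cases 1 and 2: promote the only child, or None if there is no child.
    return left if left is not None else right


print(delete_root((None, 1, None)))             # None
print(delete_root(((None, 2, None), 1, None)))  # (None, 2, None)

two_children = (((None, 4, None), 2, (None, 5, None)), 1, (None, 3, None))
print(delete_root(two_children))
# ((None, 2, (None, 5, None)), 4, (None, 3, None))
```

In the last call, 4 is the leftmost leaf under the root's left child, so it takes the root's place while 2, 5, and 3 keep their relative positions.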

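As for the public method delete(), it first does a lookup to check whether the node to be deleted can be navigated to. That is, it checks whether the node we want to delete even exists. If it does not exist, then we abort and do nothing, returning None. Otherwise we delete the node and return the value it held. (One consequence is that deleting a node whose value is None cannot be told apart, from the return value alone, from a deletion that did nothing.)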

def delete(self, directions: Optional[List[int]]=None):
    if directions is None:
        directions = []
    to_delete = None

    try:
        to_delete = self.lookup(directions)
    except ValueError:
        # The directions are either bogus, or at best point us to a None (not a
        # node). In both cases there is nothing more to do.
        return None
    else:
        self.root = self._delete(self.root, directions)
        # Return the value of the node we wanted to delete. Because the
        # operation above completed successfully, the node is guaranteed to
        # have been deleted from the tree (no link in the tree points to it
        # any more).
        return to_delete

3.6. Traversal

Traversing a binary tree means visiting (performing some action on) each node in the tree exactly once.

For the different kinds of traversals, a simple way to remember is the timing of acting on the current node. If the node is acted upon first (before acting on the child nodes), it is called pre-order traversal. If it is acted on last (after acting on the child nodes), it is called post-order traversal. If it is acted on in-between the left and right children, it's called in-order traversal.

def traverse_preorder(self, func: Callable[[Node], None]):
    return self._traverse_preorder(self.root, func)

def _traverse_preorder(self, x: Optional[Node], func: Callable[[Node], None]):
    if x is None:
        return
    func(x)
    self._traverse_preorder(x.left, func)
    self._traverse_preorder(x.right, func)

def traverse_inorder(self, func: Callable[[Node], None]):
    return self._traverse_inorder(self.root, func)

def _traverse_inorder(self, x: Optional[Node], func: Callable[[Node], None]):
    if x is None:
        return
    self._traverse_inorder(x.left, func)
    func(x)
    self._traverse_inorder(x.right, func)

def traverse_postorder(self, func: Callable[[Node], None]):
    return self._traverse_postorder(self.root, func)

def _traverse_postorder(self, x: Optional[Node], func: Callable[[Node], None]):
    if x is None:
        return
    self._traverse_postorder(x.left, func)
    self._traverse_postorder(x.right, func)
    func(x)
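
The only difference between the three traversals is where the node itself is handled relative to the two recursive calls. The sketch below shows this with generators on a tuple-encoded tree (None or (left, val, right), an encoding assumed here for brevity); preorder(), inorder(), and postorder() are hypothetical free functions. The tree has the same shape as the one built in the traversal tests.

```python
from typing import Any, Iterator

TREE = (
    ((None, 0, None), 1, ((None, 2, None), 4, None)),
    5,
    ((None, 7, None), 9, (None, 10, None)),
)


def preorder(t: Any) -> Iterator[int]:
    if t is None:
        return
    left, val, right = t
    yield val  # node first
    yield from preorder(left)
    yield from preorder(right)


def inorder(t: Any) -> Iterator[int]:
    if t is None:
        return
    left, val, right = t
    yield from inorder(left)
    yield val  # node between the children
    yield from inorder(right)


def postorder(t: Any) -> Iterator[int]:
    if t is None:
        return
    left, val, right = t
    yield from postorder(left)
    yield from postorder(right)
    yield val  # node last


print(list(preorder(TREE)))   # [5, 1, 0, 4, 2, 9, 7, 10]
print(list(inorder(TREE)))    # [0, 1, 2, 4, 5, 7, 9, 10]
print(list(postorder(TREE)))  # [0, 2, 4, 1, 7, 10, 9, 5]
```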

"In-order" traversal is named as such, because if the tree is a binary search tree (such that all values in the left subtree are less than or equal to the current node's value, and all values in the right subtree are greater than the current node's value), then we end up visiting the nodes in sorted order.

As all trees (binary ones included) are graphs, depth-first search (DFS) and breadth-first search (BFS) apply here as well. It turns out that the three types of traversal we covered above all fall into the category of DFS, because they are all concerned with visiting children recursively (and each recursive call is a descent down a level in the tree).
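
One way to see the DFS connection is to replace the recursion with an explicit stack. The following is a sketch of an iterative pre-order traversal, a swapped-in technique rather than the recursive approach used above; it assumes a tuple encoding (None or (left, val, right)) and a hypothetical function name.

```python
from typing import Any, List


def preorder_iterative(tree: Any) -> List[int]:
    """Pre-order traversal driven by an explicit stack instead of recursion."""
    out = []
    stack = [tree]
    while stack:
        t = stack.pop()
        if t is None:
            continue
        left, val, right = t
        out.append(val)
        # Push right before left, so that the left subtree is popped (and
        # therefore fully explored) before the right one.
        stack.append(right)
        stack.append(left)
    return out


tree = (((None, 4, None), 2, None), 1, (None, 3, None))
print(preorder_iterative(tree))  # [1, 2, 4, 3]
```

The stack always holds the subtrees still waiting to be explored, which is the bookkeeping that the call stack does implicitly in the recursive versions.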

3.6.1. BFS

BFS is also possible. This is also called level-order traversal, because we visit all nodes at one level in the tree first before moving down to their children. Unlike the recursive DFS traversals above, which rely implicitly on the call stack, BFS requires an explicit auxiliary data structure to keep track of which nodes to visit next, because the left and right links included natively in each binary tree node give us no way to move sideways to the other nodes at the same level.

def bfs(self, func: Callable[[Node], None]):
    return self._bfs(self.root, func)

def _bfs(self, x: Optional[Node], func: Callable[[Node], None]):
    if x is None:
        return

    nodes_at_current_depth = [x]
    while nodes_at_current_depth:
        # Process all nodes at current depth.
        for node in nodes_at_current_depth:
            func(node)

        # Now add all nodes at the next depth.
        children = []
        for node in nodes_at_current_depth:
            if not node:
                continue
            if node.left:
                children.append(node.left)
            if node.right:
                children.append(node.right)

        # Repeat the loop at the next depth.
        nodes_at_current_depth = children

The version above needs to loop through nodes_at_current_depth twice in each outer while loop iteration. The version below only uses a single for loop inside the while loop, at the cost of needing to rename the nodes_at_current_depth variable to q (it is no longer only holding nodes at the current level, but instead at the current and next level as it gets mutated in-place).

def bfs_single_pass(self, func: Callable[[Node], None]):
    return self._bfs_single_pass(self.root, func)

def _bfs_single_pass(self, x: Optional[Node], func: Callable[[Node], None]):
    if x is None:
        return

    q = deque([x])
    while q:
        # Process the head of the queue. As we process each one, just before we
        # discard it we check if it has children, and if so, add them to the end
        # of the queue.

        node = q.popleft()
        if not node:
            continue

        func(node)

        # Add this node's children, if any.
        if node.left:
            q.append(node.left)
        if node.right:
            q.append(node.right)
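
Even with a single queue, the level boundaries are not lost: at the start of each outer iteration, the queue holds exactly the nodes of one level, so its length tells us how many nodes to take before the next level begins. The sketch below uses this to group values by level. It assumes a tuple encoding (None or (left, val, right)) and a hypothetical function levels(); it is not part of the BinaryTree class.

```python
from collections import deque
from typing import Any, List


def levels(tree: Any) -> List[List[int]]:
    """Group the values of the tree by depth, using one queue."""
    result: List[List[int]] = []
    if tree is None:
        return result
    q = deque([tree])
    while q:
        level = []
        # len(q) is evaluated once, before any children are appended, so this
        # loop consumes exactly the nodes of the current level.
        for _ in range(len(q)):
            left, val, right = q.popleft()
            level.append(val)
            if left is not None:
                q.append(left)
            if right is not None:
                q.append(right)
        result.append(level)
    return result


TREE = (
    ((None, 0, None), 1, ((None, 2, None), 4, None)),
    5,
    ((None, 7, None), 9, (None, 10, None)),
)
print(levels(TREE))  # [[5], [1, 9], [0, 4, 7, 10], [2]]
```

Flattening the result gives the same order as the BFS functions above.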

4. Tests

from __future__ import annotations
from hypothesis import given, strategies as st
import unittest

from .binary_tree import BinaryTree, Node

class Test(unittest.TestCase):
    test_cases

if __name__ == "__main__":
    unittest.main(exit=False)

4.1. Initialization

def test_init(self):
    # Initializing with no arguments results in a tree with no nodes.
    t = BinaryTree()
    self.assertEqual(t.root, None)
    self.assertEqual(t.size(), 0)

    # Initializing with "None" as an argument is basically the same thing.
    t = BinaryTree(None)
    self.assertEqual(t.root, None)
    self.assertEqual(t.size(), 0)

    # Initializing with an argument is allowed but we can only populate a single
    # node (if we want to create more nodes, we have to use insert()).
    t = BinaryTree(1)
    self.assertEqual(t.root.val, 1)
    self.assertEqual(t.size(), 1)

4.2. Insertion

binary_tree_test_insert.svg

def test_insert(self):
    t = BinaryTree(1)
    self.assertEqual(t.size(), 1)
    t.insert(2, [0])
    self.assertEqual(t.size(), 2)
    t.insert(3, [1])
    self.assertEqual(t.size(), 3)
    t.insert(4, [0, 0])
    self.assertEqual(t.size(), 4)
    # Inserting at the same location does not update the count.
    t.insert(50, [0, 0])
    self.assertEqual(t.size(), 4)

    # Check sizes at every child node as well.
    self.assertEqual(t.root.left.count, 2)
    self.assertEqual(t.root.right.count, 1)
    self.assertEqual(t.root.left.left.count, 1)

4.3. Lookup

def test_lookup(self):
    t = BinaryTree(1)
    t.insert(2, [0])
    t.insert(3, [1])
    t.insert(4, [0, 0])
    self.assertEqual(t.lookup([]), 1)
    self.assertEqual(t.lookup([0]), 2)
    self.assertEqual(t.lookup([1]), 3)
    self.assertEqual(t.lookup([0, 0]), 4)

4.4. Deletion

def test_delete(self):
    t = BinaryTree()
    self.assertRaises(ValueError, t.lookup, [])
    self.assertEqual(t.size(), 0)

    t = self.make_complete_tree(3)
    self.assertEqual(t.lookup([]), 1)
    self.assertEqual(t.lookup([0]), 2)
    self.assertEqual(t.lookup([1]), 3)
    self.assertEqual(t.size(), 3)

    t = self.make_complete_tree(1)
    self.assertEqual(t.lookup([]), 1)
    deleted = t.delete()
    self.assertEqual(deleted, 1)
    self.assertRaises(ValueError, t.lookup, [])
    self.assertEqual(t.size(), 0)

    # Delete the root (1) of a 3-node tree. Its left child (2) is a leaf, so it
    # is promoted to be the new root, leaving nothing at [0].
    t = self.make_complete_tree(3)
    self.assertEqual(t.lookup([]), 1)
    self.assertEqual(t.lookup([0]), 2)
    self.assertEqual(t.lookup([1]), 3)
    self.assertEqual(t.size(), 3)
    deleted = t.delete(self.tree_path(1))
    self.assertEqual(deleted, 1)
    self.assertEqual(t.lookup([]), 2)
    self.assertRaises(ValueError, t.lookup, [0])
    self.assertEqual(t.lookup([1]), 3)
    self.assertEqual(t.size(), 2)
    # Deleting the None object at [0] results in a NOP.
    self.assertEqual(self.tree_path(2), [0])
    t.delete(self.tree_path(2))
    self.assertEqual(t.size(), 2)

    t = self.make_complete_tree(3)
    self.assertEqual(t.size(), 3)
    t.delete([]) # Delete root node.
    self.assertEqual(t.lookup([]), 2)
    self.assertRaises(ValueError, t.lookup, [0])
    # 3 is 2's child now.
    self.assertEqual(t.lookup([1]), 3)
    self.assertEqual(t.size(), 2)

    # Delete non-root, leaf node (no children).
    t = self.make_complete_tree(3)
    self.assertEqual(t.lookup([]), 1)
    self.assertEqual(t.size(), 3)
    self.assertEqual(t.root.val, 1)
    self.assertEqual(t.lookup([0]), 2)
    t.delete([0])
    self.assertRaises(ValueError, t.lookup, [0])
    self.assertEqual(t.size(), 2)
    self.assertEqual(t.lookup([1]), 3)
    t.delete([1])
    self.assertRaises(ValueError, t.lookup, [1])
    self.assertEqual(t.size(), 1)

    # Delete non-root node (1 child). We want to delete 2, when 4 is its child.
    t = self.make_complete_tree(4)
    self.assertEqual(t.size(), 4)
    self.assertEqual(t.lookup([0]), 2)
    self.assertEqual(t.lookup([0, 0]), 4)
    t.delete([0]) # Delete 2, making its only child 4 its successor.
    self.assertEqual(t.size(), 3)
    self.assertEqual(t.lookup([0]), 4)
    t.delete([0]) # Delete the successor.
    self.assertEqual(t.size(), 2)
    self.assertRaises(ValueError, t.lookup, [0])

    # Same as above, but the successor is the right child.
    t = self.make_complete_tree(6)
    self.assertEqual(t.size(), 6)
    self.assertEqual(t.lookup([0]), 2)
    self.assertEqual(t.lookup([0, 0]), 4)
    self.assertEqual(t.lookup([0, 1]), 5)
    t.delete([0, 0])
    # Delete 2, making its only child 5 its successor (this time its child on
    # the right side).
    t.delete([0])
    self.assertEqual(t.lookup([0]), 5)
    self.assertEqual(t.size(), 4)

    # Delete non-root node which has 2 children. Expect its leftmost leaf node
    # to be its successor.
    t = self.make_complete_tree(8)
    self.assertEqual(t.size(), 8)
    # Delete 2. 8 is the leftmost leaf node under 2, making it the successor.
    t.delete([0])
    self.assertEqual(t.lookup([0]), 8)
    self.assertEqual(t.size(), 7)
    # The other nodes are untouched.
    self.assertEqual(t.lookup([]), 1)
    self.assertEqual(t.lookup([0, 0]), 4)
    self.assertEqual(t.lookup([0, 1]), 5)
    self.assertEqual(t.lookup([1]), 3)
    self.assertEqual(t.lookup([1, 0]), 6)
    self.assertEqual(t.lookup([1, 1]), 7)

    # Successor node is several levels down the tree.
    t = BinaryTree(10)
    t.insert(3, [1])
    t.insert(8, [1, 0])
    t.insert(9, [1, 0, 1])
    t.insert(30, [1, 0, 1, 0]) # Successor (quite far down).
    t.insert(17, [1, 1])
    t.insert(15, [1, 1, 1])
    t.insert(55, [1, 1, 1, 0])
    self.assertEqual(t.size(), 8)
    t.delete([1])
    self.assertEqual(t.size(), 7)
    self.assertEqual(t.lookup([1]), 30) # 30 is the successor.
    self.assertEqual(t.lookup([1, 1]), 17)
    self.assertEqual(t.lookup([1, 1, 1]), 15)
    self.assertEqual(t.lookup([1, 1, 1, 0]), 55)
    self.assertEqual(t.lookup([1, 0]), 8)
    self.assertEqual(t.lookup([1, 0, 1]), 9)
    self.assertRaises(ValueError, t.lookup, [1, 0, 1, 0])
    self.assertRaises(ValueError, t.lookup, [1, 0, 1, 1])

4.5. Traversal

For these traversals, we construct the following binary tree, in which each node stores an integer value.

binary_tree.svg

def test_traversal(self):
    traversal_history = []
    def record_traversal_history(x: Node):
        traversal_history.append(x.val)

    t = BinaryTree()
    t.insert(5, [])
    t.insert(1, [0])
    t.insert(9, [1])
    t.insert(0, [0, 0])
    t.insert(4, [0, 1])
    t.insert(2, [0, 1, 0])
    t.insert(7, [1, 0])
    t.insert(10, [1, 1])

    t.traverse_preorder(record_traversal_history)
    self.assertEqual(traversal_history, [5, 1, 0, 4, 2, 9, 7, 10])

    traversal_history = []
    t.traverse_inorder(record_traversal_history)
    self.assertEqual(traversal_history, [0, 1, 2, 4, 5, 7, 9, 10])

    traversal_history = []
    t.traverse_postorder(record_traversal_history)
    self.assertEqual(traversal_history, [0, 2, 4, 1, 7, 10, 9, 5])

    traversal_history = []
    t.bfs(record_traversal_history)
    self.assertEqual(traversal_history, [5, 1, 9, 0, 4, 7, 10, 2])

    traversal_history = []
    t.bfs_single_pass(record_traversal_history)
    self.assertEqual(traversal_history, [5, 1, 9, 0, 4, 7, 10, 2])

4.6. Property based tests

To make test data more uniform, we define make_complete_tree() to create complete binary trees. A complete binary tree grows by filling out nodes at every level (left to right) before running out of room and filling in nodes at the next level (again left to right). A perfect tree is a complete tree in which every level is completely filled. For example, a perfect tree with 3 levels (which has \(2^3 - 1 = 7\) nodes) looks like this:

binary_tree_perfect.svg

# Helper function to create complete trees of a given size.
def make_complete_tree(self, size: int):
    t = BinaryTree()
    for n in range(1, size + 1):
        t.insert(n, self.tree_path(n))
    return t

# Convert a number into binary form, but as a list of binary digits
# (instead of a string).
def bin_digits(self, n: int) -> list[int]:
    return [int(c) for c in bin(n)[2:]]

# Return a path like "[1, 0, 1, ...]" for a node name in a perfect tree.
def tree_path(self, n: int):
    # 1 is the root node.
    if n <= 1:
        n = 1
    return self.bin_digits(n)[1:]
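
The reason binary digits work as paths is that, when the nodes of a complete tree are numbered level by level starting from 1, the children of node n are numbered 2n (left) and 2n + 1 (right). Appending a 0 or a 1 to the binary form of n is therefore the same as stepping left or right. The standalone sketch below shows the mapping and its inverse; tree_path() here is a free-function rewrite of the helper above, and node_number() is a hypothetical addition for illustration.

```python
from typing import List


def tree_path(n: int) -> List[int]:
    """Path from the root to node number n (the root is node 1)."""
    # bin(5) == "0b101"; dropping "0b" and the leading 1 leaves "01".
    return [int(c) for c in bin(max(n, 1))[3:]]


def node_number(path: List[int]) -> int:
    """Inverse of tree_path(): left doubles n, right doubles n and adds 1."""
    n = 1
    for direction in path:
        n = 2 * n + direction
    return n


for n in range(1, 8):
    print(n, tree_path(n))
# 1 []
# 2 [0]
# 3 [1]
# 4 [0, 0]
# 5 [0, 1]
# 6 [1, 0]
# 7 [1, 1]

print(node_number([0, 1]))  # 5
```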

Below we create a tree and perform a random number of deletions, all at the root node.

@given(st.integers(min_value=1, max_value=30))
def test_delete_random_at_root(self, deletions: int):
    starting_size = 31
    t = self.make_complete_tree(starting_size)
    # Repeatedly delete from the root node.
    for _ in range(deletions):
        t.delete([])
    self.assertEqual(t.size(), starting_size - deletions)

The following test is similar, but we delete at random points in the tree.

@given(st.lists(st.integers(min_value=0, max_value=15),
                min_size=0,
                max_size=15))
def test_delete_random_at_random(self, deletion_paths: list[int]):
    starting_size = 15
    t = self.make_complete_tree(starting_size)
    self.assertEqual(t.size(), 15)

    # Repeatedly delete from a random node in the tree.
    deleted_count = 0
    for deletion_path in deletion_paths:
        path = self.tree_path(deletion_path)
        deleted_node = t.delete(path)
        # Only count the deletion if we actually deleted a node (maybe the path
        # was pointing to nothing).
        if deleted_node is not None:
            deleted_count += 1
    self.assertEqual(t.size(), starting_size - deleted_count)

5. Export

from __future__ import annotations
from typing import Any, Callable, Optional, List
code

6. References

Aziz, A., Lee, T.-H., & Prakash, A. (2018). Elements of Programming Interviews in Python: The Insiders’ Guide. CreateSpace Independent Publishing Platform.