Binary trees
1. Problem statement
Implement a binary tree (Aziz et al., 2018, p. 123).
2. Insights
2.1. Similarity to linked lists
A binary tree is eerily similar to a linked list; whereas a node in a linked list can only point to a single "next" node, a node in a binary tree points to up to two other nodes, called its "left" and "right" children (or subtrees). Indeed, if a binary tree only has left children (or only right children), the structure devolves into a linked list.
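To make the comparison concrete, here is a minimal sketch (separate from the implementation below; the names ListNode and TreeNode are made up for illustration) contrasting the two node shapes. A tree whose nodes only ever use one of their two links has, shape-wise, the same structure as the list:

class ListNode:
    def __init__(self, val):
        self.val = val
        self.next = None    # exactly one outgoing link

class TreeNode:
    def __init__(self, val):
        self.val = val
        self.left = None    # two outgoing links instead of one
        self.right = None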
As we will see, even though they are "similar" in the above sense, they have wildly different performance properties and require much more complicated algorithms than linked lists.
2.2. Relation to binary search trees (BSTs)
Binary search trees, or BSTs, are binary trees with additional ordering guarantees. Whereas BSTs are all about maintaining this ordering guarantee (namely that every value in a node's left subtree is smaller than the node's own value, and every value in its right subtree is larger), plain binary trees are concerned more with the actual structure (or position) of the nodes in relation to each other. This latter point is important because whereas the same data can be represented by many different BSTs (depending on the order in which the values are inserted into the BST), with plain binary trees the structure itself is what matters.
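To pin down the ordering guarantee, here is a hedged sketch of a BST invariant check (one common convention, assuming distinct values, and assuming a node type with val, left, and right fields like the Node class defined below); plain binary trees impose no such constraint:

def is_bst(node, lo=float("-inf"), hi=float("inf")) -> bool:
    # An empty subtree trivially satisfies the invariant.
    if node is None:
        return True
    # The node's value must lie strictly between the bounds inherited from
    # its ancestors.
    if not (lo < node.val < hi):
        return False
    # Left descendants must stay below node.val; right descendants above it.
    return is_bst(node.left, lo, node.val) and is_bst(node.right, node.val, hi)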
3. Solution
We could define Node as a nested class inside the BinaryTree class, but for simplicity we choose not to. In a "production" implementation, we would have to concern ourselves with API boundaries and encapsulation, but those concerns are beyond the scope of our discussion, so we ignore them here.
We use a two-class scheme. The Node class stores the links to the children (and additional metadata about the node and its children). The BinaryTree class has all of the interesting methods that act on the overall tree, because it's expected that users will never have to deal with the Node class directly.
from collections import deque

class Node:
    node_class_methods

class BinaryTree:
    binary_tree_class_methods
3.1. Initialization
3.1.1. Node initialization
We set the left and right links to None. These point to other Node instances, if any. Later, when we add more nodes into this tree, these links will become populated.
We also have the val field to actually store the value we want inside this node, as well as the count field to tell us how large the binary tree is (if we treat the current node as the root).
def __init__(self, val: Any=None):
    self.val = val
    self.left: Optional[Node] = None
    self.right: Optional[Node] = None
    self.count = 1
3.1.2. Tree initialization
When creating the tree, we just need to make sure that it starts out with either a single node or nothing (an empty tree).
def __init__(self, val: Any=None):
    if val is None:
        self.root = None
    else:
        self.root = Node(val)
3.2. Insertion
For each value we want to insert, we take a list of directions. These directions are a list of 0s and 1s that tell us to go either left or right (0 means left, 1 means right), starting from the root node. When we run out of directions, we insert the value at the node we've navigated to.
If there are any nodes missing on the way to the final destination, create nodes along the way. This way, we don't force our callers to only create new nodes exactly at a particular leaf node (they could choose to, for example, create the leaf node first, followed by filling up the parent nodes).
The main workhorse is _insert(), a private method (insert() is meant for external use).
def insert(self, val: Any, directions: List[int]):
    self.root = self._insert(self.root, val, directions)
We navigate down the tree if there are directions. As we do so, we make note of all the nodes we've seen, adding them to hist.
def _insert(self, x: Optional[Node], val: Any, directions: List[int]):
    if x is None:
        x = Node()
    start = x
    hist = deque([(x, False)])
    for direction in directions:
        new_node = False
        if direction:
            if x.right is None:
                x.right = Node()
                new_node = True
            x = x.right
        else:
            if x.left is None:
                x.left = Node()
                new_node = True
            x = x.left
        hist.append((x, new_node))
    # Update counts back up the tree (by visiting seen nodes) by processing
    # hist from the deepest node upward. delta tracks how many new nodes lie
    # strictly below the node currently being processed, so each node's count
    # grows by exactly that amount. The last (deepest) node is effectively
    # skipped because delta is still 0 when we pop it; it already has a count
    # of 1 by default.
    delta = 0
    while hist:
        node, new_node = hist.pop()
        node.count += delta
        if new_node:
            delta += 1
    # Finally, set the value on the node.
    x.val = val
    return start
Note that it is important that we return the initial node (start) we started from. This way we can assign it back into self.root in insert().
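As a quick illustration of the directions scheme (mirroring the test cases later on), here is how a caller might build up a small tree:

t = BinaryTree(1)     # 1 becomes the root
t.insert(2, [0])      # left child of the root
t.insert(3, [1])      # right child of the root
t.insert(4, [0, 0])   # left child of node 2
t.insert(50, [0, 0])  # same position: overwrites the value there; size unchanged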
3.3. Size
The size of the tree is the size of all the nodes in the tree, including its
children. We update the count
field every time we insert new nodes, including
the counts for all parents. As we shall see later, we also update the count
for all relevant nodes whenever we perform a deletion. This is why we don't
have to do any extra work when trying to figure out the size of the tree —
it's just a direct lookup of the x.count
field itself.
def size(self):
    return self._size(self.root)

def _size(self, x: Optional[Node]):
    if x is None:
        return 0
    return x.count
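For contrast, if we did not maintain count during insertions and deletions, computing the size would require a full traversal. A hedged sketch of that alternative (the name _size_by_traversal is made up for illustration) shows the O(n) work we avoid:

def _size_by_traversal(self, x: Optional[Node]) -> int:
    # Visit every node: count this node plus everything in both subtrees.
    if x is None:
        return 0
    return 1 + self._size_by_traversal(x.left) + self._size_by_traversal(x.right)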
3.4. Lookup
Looking up a node is similar to an insertion operation, because we have to navigate to the node by following the provided directions. However, if there is an error at any point during this navigation, we raise an exception instead. From the caller's point of view, a value is returned as a successful lookup only if we were able to navigate all the way to the desired node.
In some sense this is a lot like a pointer dereference operation: we are merely checking that we aren't "dereferencing" (looking up) an invalid "address" (the directions to get to the node).
def lookup(self, directions: List[int]):
    return self._lookup(self.root, directions)

def _lookup(self, x: Optional[Node], directions: List[int]):
    if x is None:
        raise ValueError("no root node to start navigation from")
    for direction in directions:
        if x is None:
            raise ValueError("could not navigate to node")
        if direction:
            x = x.right
        else:
            x = x.left
    if x is None:
        raise ValueError("navigated to a None type, not a node")
    return x.val
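A brief usage sketch of this "dereference" behavior (mirroring the lookup tests below):

t = BinaryTree(1)
t.insert(2, [0])
t.lookup([])    # 1 (the root)
t.lookup([0])   # 2
t.lookup([1])   # raises ValueError: the root has no right child to land on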
3.5. Deletion
Deletion is tricky. Not only do we have to update the counts, but we could be moving nodes around in order to preserve the overall structure of the tree. Let's think about preserving the structure first (and see why this would even be a concern).
If we want to delete a node in the tree, we must consider 3 possible cases:
- the to-be-deleted node has 0 children (leaf node),
- the to-be-deleted node has 1 child, or
- the to-be-deleted node has 2 children.
The first case is easy: from its parent's point of view, we just set the corresponding left or right link (the pointer to the deleted node) to None.
The second case requires us to promote the only child of the node to be deleted to be the replacement. It's just like deleting a node in a linked list (where the "next" node becomes the current node).
The third case is tricky. We have two child nodes (and they may also have children of their own), but we can only assign one child to be the replacement. Obviously we cannot make both children the replacement because then it won't be a binary tree any more.
Here we just do the simplest thing possible: we descend down the left subtree (falling back to the right child whenever there is no left child) until we reach a leaf node, and promote that leaf to be the replacement. This way we don't have to affect the structure of the other children that much. We could also recurse down the right subtree to get the rightmost leaf node, or perhaps even pick a leaf node at random, but choosing the leftmost leaf seems like the simplest thing and is what we do in _delete().
def _delete(self, x: Optional[Node], directions: List[int]) -> Optional[Node]:
    # If the starting node is not a node, we cannot do any navigation on it to
    # find the node we want to delete. Abort.
    if x is None:
        return None

    # Find the node to delete. We recurse down here, but the main point is
    # that only the last call to _delete() will do the actual deletion. All
    # calls to _delete() leading up to the last one just perform navigation
    # and assignment back up the hierarchy.
    if directions:
        if directions[0]:
            x.right = self._delete(x.right, directions[1:])
        else:
            x.left = self._delete(x.left, directions[1:])
        x.count = self._size(x.left) + self._size(x.right) + 1
        return x

    # We've arrived at the node we want to delete. We deal with the 3 possible
    # cases.

    # If there are two children, we have to pick a replacement from the
    # leftmost leaf node. Picking a replacement is a bit tricky because the
    # current parent of the replacement needs to no longer point to it
    # (otherwise we'll end up just duplicating the replacement).
    if x.left is not None and x.right is not None:
        y = x
        # Overwrite x, "deleting" it by replacing it with its leftmost leaf
        # node. But make this leaf node a non-leaf node by making it point to
        # x's children (saved in y).
        x = self._pop_leftmost_leaf_node(x.left)
        # Only assign the previous left child as the replacement's left child
        # if the replacement is not itself the left child. I.e., avoid a
        # possible cycle.
        if x != y.left:
            x.left = y.left
        x.right = y.right
        # We have 1 less node than before. We can't use x.count because x is
        # the leaf node (and its count is 1).
        x.count = y.count - 1
    # If there is just one child, then make that one the replacement.
    elif x.left is not None:
        x = x.left
    elif x.right is not None:
        x = x.right
    # No children. Easy case: we just return None, so that the parent's link
    # to this node gets cleared.
    else:
        return None

    return x
Now it may very well be the case that getting the leftmost leaf node results in getting the node that we started searching from. That is, if we have 2 children and we start the "leftmost leaf node" search on the left child, it may be that that child is already a leaf node. In that case the while loop would break immediately and we would return the argument we got (x) as is.
def _pop_leftmost_leaf_node(self, x: Node) -> Node:
    hist = []
    went_left = False
    while True:
        if x.left is not None:
            hist.append(x)
            x = x.left
            went_left = True
            continue
        if x.right is not None:
            hist.append(x)
            x = x.right
            went_left = False
            continue
        break
    # If we had any children (descended down the tree), we have to modify the
    # tree so that the parent of the leaf node is no longer pointing to this
    # leaf node. Otherwise, the leaf node would still stay in the tree and
    # would not be "popped" out of it.
    if hist:
        if went_left:
            hist[-1].left = None
        else:
            hist[-1].right = None
    for h in hist:
        h.count -= 1
    return x
As for the public method delete(), it first tries a lookup to check whether the node to be deleted can be navigated to. That is, it checks whether the node we want to delete even exists. If it does not exist, then we abort and do nothing, returning None.
def delete(self, directions: Optional[List[int]]=None):
    if directions is None:
        directions = []
    to_delete = None
    try:
        to_delete = self.lookup(directions)
    except ValueError:
        # The directions are either bogus, or at best point us to a None (not
        # a node). In both cases there is nothing more to do.
        return None
    else:
        self.root = self._delete(self.root, directions)
        # Return the value of the node we wanted to delete. Because the
        # operation above completed successfully, it is guaranteed to have
        # removed the node from the tree (no link to this object remains
        # anywhere in the tree).
        return to_delete
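To tie the cases together, here is a small worked example of the two-children case (mirroring test_delete below): deleting the root of a three-node tree promotes the leftmost leaf (the left child) and hands it the old root's right child.

t = BinaryTree(1)
t.insert(2, [0])
t.insert(3, [1])
t.delete([])    # delete the root (value 1), which has two children
t.lookup([])    # 2 -- the leftmost leaf was promoted to the root
t.lookup([1])   # 3 -- kept as the new root's right child
t.size()        # 2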
3.6. Traversal
Traversing a binary tree means visiting (performing some action on) each node in the tree exactly once.
For the different kinds of traversals, a simple way to remember is the timing of acting on the current node. If the node is acted upon first (before acting on the child nodes), it is called pre-order traversal. If it is acted on last (after acting on the child nodes), it is called post-order traversal. If it is acted on in-between the left and right children, it's called in-order traversal.
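For example, on a three-node tree with root a, left child b, and right child c:

  a
 / \
b   c

pre-order visits a, b, c (node first); in-order visits b, a, c (node in the middle); post-order visits b, c, a (node last).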
def traverse_preorder(self, func: Callable[[Node], None]): return self._traverse_preorder(self.root, func) def _traverse_preorder(self, x: Optional[Node], func: Callable[[Node], None]): if x is None: return func(x) self._traverse_preorder(x.left, func) self._traverse_preorder(x.right, func) def traverse_inorder(self, func: Callable[[Node], None]): return self._traverse_inorder(self.root, func) def _traverse_inorder(self, x: Optional[Node], func: Callable[[Node], None]): if x is None: return self._traverse_inorder(x.left, func) func(x) self._traverse_inorder(x.right, func) def traverse_postorder(self, func: Callable[[Node], None]): return self._traverse_postorder(self.root, func) def _traverse_postorder(self, x: Optional[Node], func: Callable[[Node], None]): if x is None: return self._traverse_postorder(x.left, func) self._traverse_postorder(x.right, func) func(x)
"In-order" traversal is named as such, because if the tree is a binary search tree (such that all values in the left subtree are less than or equal to the current node's value, and all values in the right subtree are greater than the current node's value), then we end up visiting the nodes in sorted order.
As all trees (binary ones included) are graphs, depth-first search (DFS) and breadth-first search (BFS) apply here as well. It turns out that the three types of traversal we covered above all fall into the category of DFS, because they are all concerned with visiting children recursively (and each recursive call is a descent down a level in the tree).
3.6.1. BFS
BFS is also possible. This is also called level-order traversal, because we visit all children at a level in the tree first before moving down to their children. Unlike DFS, BFS requires an auxiliary data structure to keep track of which nodes to visit, because we need a way to visit the nodes beyond the left and right links included natively in each binary tree node.
def bfs(self, func: Callable[[Node], None]): return self._bfs(self.root, func) def _bfs(self, x: Optional[Node], func: Callable[[Node], None]): if x is None: return nodes_at_current_depth = [x] while nodes_at_current_depth: # Process all nodes at current depth. for node in nodes_at_current_depth: func(node) # Now add all nodes at the next depth. children = [] for node in nodes_at_current_depth: if not node: continue if node.left: children.append(node.left) if node.right: children.append(node.right) # Repeat the loop at the next depth. nodes_at_current_depth = children
The version above loops through nodes_at_current_depth twice in each iteration of the outer while loop. The version below does away with the inner for loops entirely and processes one node per iteration of the while loop, at the cost of renaming the nodes_at_current_depth variable to q (it no longer holds only the nodes at the current level; it holds nodes from both the current and the next level as it gets mutated in place).
def bfs_single_pass(self, func: Callable[[Node], None]):
    return self._bfs_single_pass(self.root, func)

def _bfs_single_pass(self, x: Optional[Node], func: Callable[[Node], None]):
    if x is None:
        return
    q = deque([x])
    while q:
        # Process the head of the queue. As we process each one, just before
        # we discard it we check if it has children, and if so, add them to
        # the end of the queue.
        node = q.popleft()
        if not node:
            continue
        func(node)
        # Add this node's children, if any.
        if node.left:
            q.append(node.left)
        if node.right:
            q.append(node.right)
4. Tests
from __future__ import annotations

from hypothesis import given, strategies as st
import unittest

from .binary_tree import BinaryTree, Node

class Test(unittest.TestCase):
    test_cases

if __name__ == "__main__":
    unittest.main(exit=False)
4.1. Initialization
def test_init(self):
    # Initializing with no arguments results in a tree with no nodes.
    t = BinaryTree()
    self.assertEqual(t.root, None)
    self.assertEqual(t.size(), 0)

    # Initializing with "None" as an argument is basically the same thing.
    t = BinaryTree(None)
    self.assertEqual(t.root, None)
    self.assertEqual(t.size(), 0)

    # Initializing with an argument is allowed, but this only populates a
    # single node (if we want to create more nodes, we have to use insert()).
    t = BinaryTree(1)
    self.assertEqual(t.root.val, 1)
    self.assertEqual(t.size(), 1)
4.2. Insertion
def test_insert(self):
    t = BinaryTree(1)
    self.assertEqual(t.size(), 1)
    t.insert(2, [0])
    self.assertEqual(t.size(), 2)
    t.insert(3, [1])
    self.assertEqual(t.size(), 3)
    t.insert(4, [0, 0])
    self.assertEqual(t.size(), 4)

    # Inserting at the same location does not update the count.
    t.insert(50, [0, 0])
    self.assertEqual(t.size(), 4)

    # Check sizes at every child node as well.
    self.assertEqual(t.root.left.count, 2)
    self.assertEqual(t.root.right.count, 1)
    self.assertEqual(t.root.left.left.count, 1)
4.3. Lookup
def test_lookup(self):
    t = BinaryTree(1)
    t.insert(2, [0])
    t.insert(3, [1])
    t.insert(4, [0, 0])
    self.assertEqual(t.lookup([]), 1)
    self.assertEqual(t.lookup([0]), 2)
    self.assertEqual(t.lookup([1]), 3)
    self.assertEqual(t.lookup([0, 0]), 4)
4.4. Deletion
def test_delete(self):
    t = BinaryTree()
    self.assertRaises(ValueError, t.lookup, [])
    self.assertEqual(t.size(), 0)

    t = self.make_complete_tree(3)
    self.assertEqual(t.lookup([]), 1)
    self.assertEqual(t.lookup([0]), 2)
    self.assertEqual(t.lookup([1]), 3)
    self.assertEqual(t.size(), 3)

    t = self.make_complete_tree(1)
    self.assertEqual(t.lookup([]), 1)
    deleted = t.delete()
    self.assertEqual(deleted, 1)
    self.assertRaises(ValueError, t.lookup, [])
    self.assertEqual(t.size(), 0)

    # Deleting the child of a leaf node is a NOP (there's nothing to delete).
    t = self.make_complete_tree(3)
    self.assertEqual(t.lookup([]), 1)
    self.assertEqual(t.lookup([0]), 2)
    self.assertEqual(t.lookup([1]), 3)
    self.assertEqual(t.size(), 3)
    deleted = t.delete(self.tree_path(1))
    self.assertEqual(deleted, 1)
    self.assertEqual(t.lookup([]), 2)
    self.assertRaises(ValueError, t.lookup, [0])
    self.assertEqual(t.lookup([1]), 3)
    self.assertEqual(t.size(), 2)
    # Deleting the None object at [0] results in a NOP.
    self.assertEqual(self.tree_path(2), [0])
    t.delete(self.tree_path(2))
    self.assertEqual(t.size(), 2)

    t = self.make_complete_tree(3)
    self.assertEqual(t.size(), 3)
    t.delete([])  # Delete root node.
    self.assertEqual(t.lookup([]), 2)
    self.assertRaises(ValueError, t.lookup, [0])  # 3 is 2's child now.
    self.assertEqual(t.lookup([1]), 3)
    self.assertEqual(t.size(), 2)

    # Delete non-root, leaf node (no children).
    t = self.make_complete_tree(3)
    self.assertEqual(t.lookup([]), 1)
    self.assertEqual(t.size(), 3)
    self.assertEqual(t.root.val, 1)
    self.assertEqual(t.lookup([0]), 2)
    t.delete([0])
    self.assertRaises(ValueError, t.lookup, [0])
    self.assertEqual(t.size(), 2)
    self.assertEqual(t.lookup([1]), 3)
    t.delete([1])
    self.assertRaises(ValueError, t.lookup, [1])
    self.assertEqual(t.size(), 1)

    # Delete non-root node (1 child). We want to delete 2, when 4 is its
    # child.
    t = self.make_complete_tree(4)
    self.assertEqual(t.size(), 4)
    self.assertEqual(t.lookup([0]), 2)
    self.assertEqual(t.lookup([0, 0]), 4)
    t.delete([0])  # Delete 2, making its only child 4 its successor.
    self.assertEqual(t.size(), 3)
    self.assertEqual(t.lookup([0]), 4)
    t.delete([0])  # Delete the successor.
    self.assertEqual(t.size(), 2)
    self.assertRaises(ValueError, t.lookup, [0])

    # Same as above, but the successor is the right child.
    t = self.make_complete_tree(6)
    self.assertEqual(t.size(), 6)
    self.assertEqual(t.lookup([0]), 2)
    self.assertEqual(t.lookup([0, 0]), 4)
    self.assertEqual(t.lookup([0, 1]), 5)
    t.delete([0, 0])
    # Delete 2, making its only child 5 its successor (this time its child on
    # the right side).
    t.delete([0])
    self.assertEqual(t.lookup([0]), 5)
    self.assertEqual(t.size(), 4)

    # Delete non-root node which has 2 children. Expect its leftmost leaf
    # node to be its successor.
    t = self.make_complete_tree(8)
    self.assertEqual(t.size(), 8)
    # Delete 2. This makes 8 the leftmost leaf node, making it its successor.
    t.delete([0])
    self.assertEqual(t.lookup([0]), 8)
    self.assertEqual(t.size(), 7)
    # The other nodes are untouched.
    self.assertEqual(t.lookup([]), 1)
    self.assertEqual(t.lookup([0, 0]), 4)
    self.assertEqual(t.lookup([0, 1]), 5)
    self.assertEqual(t.lookup([1]), 3)
    self.assertEqual(t.lookup([1, 0]), 6)
    self.assertEqual(t.lookup([1, 1]), 7)

    # Successor node is several levels down the tree.
    t = BinaryTree(10)
    t.insert(3, [1])
    t.insert(8, [1, 0])
    t.insert(9, [1, 0, 1])
    t.insert(30, [1, 0, 1, 0])  # Successor (quite far down).
    t.insert(17, [1, 1])
    t.insert(15, [1, 1, 1])
    t.insert(55, [1, 1, 1, 0])
    self.assertEqual(t.size(), 8)
    t.delete([1])
    self.assertEqual(t.size(), 7)
    self.assertEqual(t.lookup([1]), 30)  # 30 is the successor.
    self.assertEqual(t.lookup([1, 1]), 17)
    self.assertEqual(t.lookup([1, 1, 1]), 15)
    self.assertEqual(t.lookup([1, 1, 1, 0]), 55)
    self.assertEqual(t.lookup([1, 0]), 8)
    self.assertEqual(t.lookup([1, 0, 1]), 9)
    self.assertRaises(ValueError, t.lookup, [1, 0, 1, 0])
    self.assertRaises(ValueError, t.lookup, [1, 0, 1, 1])
4.5. Traversal
For these traversals, we construct the following binary tree (each node stores an integer value; only these values are shown in the illustration):
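        5
      /   \
     1     9
    / \   / \
   0   4 7   10
      /
     2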
def test_traversal(self):
    traversal_history = []

    def record_traversal_history(x: Node):
        traversal_history.append(x.val)

    t = BinaryTree()
    t.insert(5, [])
    t.insert(1, [0])
    t.insert(9, [1])
    t.insert(0, [0, 0])
    t.insert(4, [0, 1])
    t.insert(2, [0, 1, 0])
    t.insert(7, [1, 0])
    t.insert(10, [1, 1])

    t.traverse_preorder(record_traversal_history)
    self.assertEqual(traversal_history, [5, 1, 0, 4, 2, 9, 7, 10])

    traversal_history = []
    t.traverse_inorder(record_traversal_history)
    self.assertEqual(traversal_history, [0, 1, 2, 4, 5, 7, 9, 10])

    traversal_history = []
    t.traverse_postorder(record_traversal_history)
    self.assertEqual(traversal_history, [0, 2, 4, 1, 7, 10, 9, 5])

    traversal_history = []
    t.bfs(record_traversal_history)
    self.assertEqual(traversal_history, [5, 1, 9, 0, 4, 7, 10, 2])

    traversal_history = []
    t.bfs_single_pass(record_traversal_history)
    self.assertEqual(traversal_history, [5, 1, 9, 0, 4, 7, 10, 2])
4.6. Property-based tests
To make the test data more uniform, we define make_complete_tree() to create complete binary trees. A complete binary tree fills in nodes at each level (left to right) before moving on to the next level (again left to right). A perfect tree goes further: every level is completely filled. For example, a perfect tree of depth 3 (which has \(2^3 - 1 = 7\) nodes) looks like this:
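        1
      /   \
     2     3
    / \   / \
   4   5 6   7

(The nodes are labeled 1 through 7 in level order, which is the numbering scheme that tree_path() below relies on.)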
# Helper function to create complete trees of a given size.
def make_complete_tree(self, size: int):
    t = BinaryTree()
    for n in range(1, size + 1):
        t.insert(n, self.tree_path(n))
    return t

# Convert a number into binary form, but as a list of binary digits (instead
# of a string).
def bin_digits(self, n: int) -> list[int]:
    return [int(c) for c in bin(n)[2:]]

# Return a path like "[1, 0, 1, ...]" for a node name in a perfect tree.
def tree_path(self, n: int):
    # 1 is the root node.
    if n <= 1:
        n = 1
    return self.bin_digits(n)[1:]
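For instance, bin_digits(6) returns [1, 1, 0]; dropping the leading 1 gives tree_path(6) == [1, 0], i.e. "go right, then left", which is exactly where node 6 sits in the perfect tree illustrated above.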
Below we create a tree and perform a random number of deletions, all at the root node.
@given(st.integers(min_value=1, max_value=30))
def test_delete_random_at_root(self, deletions: int):
    starting_size = 31
    t = self.make_complete_tree(starting_size)
    # Repeatedly delete from the root node.
    for _ in range(deletions):
        t.delete([])
    self.assertEqual(t.size(), starting_size - deletions)
The following test is similar, but we delete at random points in the tree.
@given(st.lists(st.integers(min_value=0, max_value=15),
                min_size=0, max_size=15))
def test_delete_random_at_random(self, deletion_paths: list[int]):
    starting_size = 15
    t = self.make_complete_tree(starting_size)
    self.assertEqual(t.size(), 15)
    # Repeatedly delete from a random node in the tree.
    deleted_count = 0
    for deletion_path in deletion_paths:
        path = self.tree_path(deletion_path)
        deleted_node = t.delete(path)
        # Only count the deletion if we actually deleted a node (maybe the
        # path was pointing to nothing).
        if deleted_node is not None:
            deleted_count += 1
    self.assertEqual(t.size(), starting_size - deleted_count)
5. Export
from __future__ import annotations

from typing import Any, Callable, Optional, List

code