Height-balanced binary trees

1. Problem statement

Check if a binary tree is height-balanced (Aziz et al., 2018, p. 124).

2. Insights

2.1. Height-balanced

The height of a tree is the number of edges on the longest path from the root down to a leaf, or equivalently the greatest depth of any node in the tree. A tree is height-balanced if, for every node, the heights of its left and right subtrees differ by at most 1, as in the example below.

binary_tree_height_balanced.svg

Figure 1: An elaborate, but still height-balanced binary tree.
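
The definition has to hold at every node, not only at the root. The following is a minimal sketch of that point. SketchNode, sketch_height, and chain are hypothetical names introduced just for this illustration; they are not part of the binary tree library used later.

```python
from __future__ import annotations
from dataclasses import dataclass
from typing import Optional


@dataclass
class SketchNode:
    # Hypothetical stand-in for a tree node; only the child links matter here.
    left: Optional[SketchNode] = None
    right: Optional[SketchNode] = None


def sketch_height(x: Optional[SketchNode]) -> int:
    # An empty tree has height -1, so a single node has height 0.
    if x is None:
        return -1
    return max(sketch_height(x.left), sketch_height(x.right)) + 1


def chain(length: int) -> Optional[SketchNode]:
    # Build a left-leaning chain with `length` nodes.
    node = None
    for _ in range(length):
        node = SketchNode(left=node)
    return node


# The root has a 3-node chain on each side, so its two subtrees have equal height.
root = SketchNode(left=chain(3), right=chain(3))
root_delta = abs(sketch_height(root.left) - sketch_height(root.right))

# One level down, the left child has a 2-node chain on one side and nothing on the other.
child = root.left
child_delta = abs(sketch_height(child.left) - sketch_height(child.right))

print(root_delta)   # 0
print(child_delta)  # 2
```

The root looks fine when checked in isolation, but its left child violates the condition, so this tree is not height-balanced.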

3. Solution

3.1. Brute force

We reuse the binary tree library we created earlier.

Basically, we compute the height of the tree rooted at every node. Then it is a matter of traversing the entire tree and checking, at each node, that the two subtrees below it differ in height by at most 1.

We can traverse through the tree so that we visit every single node exactly once. When visiting the node, we treat that node as the root of a new binary tree. Then all we have to do is get the height of the left and right subtrees, and see if the difference is greater than 1. The missing building block here is an algorithm to find the height of a tree (which we can apply to subtrees just as well).

First let's figure out how to compute the height of a tree. We can solve this using recursion: return -1 if the node is None (base case), otherwise recurse into each child. The height of the current node is the larger of the two child heights plus 1. The extra 1 accounts for the edge between the current node and its taller child, so each level of the tree contributes exactly 1 to the height.

code(1/6) 🔗
from binary_tree.binary_tree import BinaryTree, Node

def height(x: Optional[Node]) -> int:
    if x is None:
        return -1

    height_l = height(x.left)
    height_r = height(x.right)

    return max(height_l, height_r) + 1
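
To see the order in which heights are resolved, here is a small traced variant of the same recursion. It is only a sketch: LabeledNode and traced_height are hypothetical names used for this illustration, and the labels exist purely so that the trace is readable.

```python
from __future__ import annotations
from dataclasses import dataclass
from typing import Optional


@dataclass
class LabeledNode:
    # Hypothetical stand-in for the library's Node, with a label for tracing.
    label: str
    left: Optional[LabeledNode] = None
    right: Optional[LabeledNode] = None


def traced_height(x: Optional[LabeledNode], trace: list) -> int:
    # Same recursion as height(), but records each node's height as it returns.
    if x is None:
        return -1
    h = max(traced_height(x.left, trace), traced_height(x.right, trace)) + 1
    trace.append((x.label, h))
    return h


#       R
#      / \
#     A   C
#    /
#   B
tree = LabeledNode("R", LabeledNode("A", LabeledNode("B")), LabeledNode("C"))

trace = []
root_height = traced_height(tree, trace)
print(trace)  # [('B', 0), ('A', 1), ('C', 0), ('R', 2)]
```

Each node's height is recorded only after both of its children have reported theirs, and every step up the tree adds 1.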

Now that we can determine the height of a tree rooted at any particular node, we are ready to traverse through the tree and run the height computation for every single node. As soon as we determine that the tree is not balanced, we return False.

code(2/6) 🔗
def is_height_balanced__brute_force(x: Optional[Node]) -> bool:
    # If there is no node, then just consider it height-balanced. This also
    # matters when we recurse --- we are bound to run into a None object (when
    # looking at the children of leaf nodes), and we don't want to abort the
    # traversal as soon as we encounter them. So we are forced to return True.
    if x is None:
        return True

    # Traverse into the other nodes, making sure that we call the body of this
    # function exactly once for every Node.
    if not is_height_balanced__brute_force(x.left):
        return False
    if not is_height_balanced__brute_force(x.right):
        return False

    # "Visit" this node by doing the actual subtree height check for the left
    # and right children.
    height_l = height(x.left)
    height_r = height(x.right)

    big = max(height_l, height_r)
    small = min(height_l, height_r)
    delta = big - small

    return delta <= 1

We do postorder traversal because the code looks aesthetically simpler written this way (the "visiting" bit comes at the end, after the "traversal" part at the top).

This is wasteful because we recompute the same result multiple times (that is, we end up calling height() multiple times for the same node). The time complexity is \(O(n^2)\).

For example, consider a degenerate tree that only has nodes in the left subtree, resembling a linked list. Calling height() on a node visits that node and every node below it, so it visits \(d\) nodes, where \(d\) is the number of nodes in the chain starting at the current node.

binary_tree_degenerate.svg

Figure 2: A degenerate tree that looks like a linked list.

A B C D E # determine height of tree at A
  B C D E # determine height of tree at B
    C D E # determine height of tree at C
      D E # determine height of tree at D
        E # determine height of tree at E

This is very similar to the trace of bubble sort, which makes up to \(n\) passes over an array, each pass covering one fewer element than the one before.

1 2 3 4 5 # iterate through all elements, 1 to 5
  2 3 4 5 # iterate through all elements, 2 to 5
    3 4 5 # iterate through all elements, 3 to 5
      4 5 # iterate through all elements, 4 to 5
        5 # iterate through all elements, 5 to 5

And so calling height() at every node of a degenerate tree costs \(O(n^2)\). What about other types of trees? Each of the \(n\) nodes triggers height() calls that together visit at most \(n\) nodes, so \(O(n^2)\) is an upper bound regardless of the structure of the tree. The bound is not always tight: in a balanced tree every node lies in the subtrees of only \(O(\log n)\) ancestors, so the total work there is \(O(n \log n)\).
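
The trace above adds up to 5 + 4 + 3 + 2 + 1 = 15 node visits for the five-node chain. The sketch below counts visits the same way for larger inputs and for a different shape. It is an illustration under stated assumptions, not a benchmark of the brute force function itself: SketchNode, counted_height, chain, and perfect are hypothetical helpers, and height is measured at every node with no early return.

```python
from __future__ import annotations
from dataclasses import dataclass
from typing import Optional


@dataclass
class SketchNode:
    # Hypothetical stand-in for the library's Node.
    left: Optional[SketchNode] = None
    right: Optional[SketchNode] = None


def counted_height(x: Optional[SketchNode], counter: list) -> int:
    # Same recursion as height(), but counts every non-empty node it touches.
    if x is None:
        return -1
    counter[0] += 1
    return max(counted_height(x.left, counter), counted_height(x.right, counter)) + 1


def visits_when_measuring_every_node(root: Optional[SketchNode]) -> int:
    # Call counted_height() once for each node in the tree and total the visits.
    counter = [0]
    stack = [root]
    while stack:
        node = stack.pop()
        if node is None:
            continue
        counted_height(node, counter)
        stack.append(node.left)
        stack.append(node.right)
    return counter[0]


def chain(length: int) -> Optional[SketchNode]:
    # Degenerate tree: every node has only a left child.
    node = None
    for _ in range(length):
        node = SketchNode(left=node)
    return node


def perfect(levels: int) -> Optional[SketchNode]:
    # Perfect tree with `levels` full levels, i.e. 2**levels - 1 nodes.
    if levels == 0:
        return None
    return SketchNode(perfect(levels - 1), perfect(levels - 1))


chain_5 = visits_when_measuring_every_node(chain(5))
chain_127 = visits_when_measuring_every_node(chain(127))
perfect_127 = visits_when_measuring_every_node(perfect(7))

print(chain_5)      # 15
print(chain_127)    # 8128
print(perfect_127)  # 769
```

Both 127-node trees stay within the \(n^2\) bound, but the perfect tree needs far fewer visits than the chain.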

3.1.1. Other traversals

Does it matter if the brute force approach used a different type of traversal? No, the answer is the same. The code is harder to read, though, because the traversal logic is intermixed with the "what to do with our current node" logic.

code(3/6) 🔗
def is_height_balanced__brute_force_preorder(x: Optional[Node]) -> bool:
    if x is None:
        return True

    height_l = height(x.left)
    height_r = height(x.right)

    big = max(height_l, height_r)
    small = min(height_l, height_r)
    delta = big - small

    if delta > 1:
        return False

    # Recurse with this same function so that every node is handled preorder.
    if not is_height_balanced__brute_force_preorder(x.left):
        return False
    if not is_height_balanced__brute_force_preorder(x.right):
        return False

    return True

def is_height_balanced__brute_force_inorder(x: Optional[Node]) -> bool:
    if x is None:
        return True

    if not is_height_balanced__brute_force_inorder(x.left):
        return False

    height_l = height(x.left)
    height_r = height(x.right)

    big = max(height_l, height_r)
    small = min(height_l, height_r)
    delta = big - small

    if delta > 1:
        return False

    if not is_height_balanced__brute_force_inorder(x.right):
        return False

    return True

3.2. Optimal

The optimal solution avoids using height(), and instead relies on the property of postorder traversals themselves. In postorder traversal, the leaf nodes of a subtree are always visited first, and the left subtree is fully visited before visiting the right subtree. And both the left and right subtrees are fully resolved (visited) before we visit the current node.

Using this property, we can pass the computed height back to the caller (back up the tree), so that each node obtains the heights of its subtrees in constant time instead of recomputing them. As soon as we determine that a subtree is not balanced, we return early. If both subtrees and the tree rooted at the current node are balanced, we return (True, <height>) where <height> is the height of the current tree, determined by taking the larger of the two subtree heights and adding 1.

Let's use a named tuple for this to improve readability.

code(4/6) 🔗
class NodeInfo(NamedTuple):
    balanced: bool
    height: int

Now let's write a helper function first. We want to match the type signature of the brute force approach (returning just True or False as the output), but the recursive computation described above requires us to return both the boolean and a height.

code(5/6) 🔗
def get_balanced_status(x: Optional[Node]) -> NodeInfo:
    # Base case for recursion, just like for the brute force approach.
    if x is None:
        return NodeInfo(True, -1)

    # Early return if either the left or right subtrees were not balanced.
    ni_left = get_balanced_status(x.left)
    if not ni_left.balanced:
        return ni_left

    ni_right = get_balanced_status(x.right)
    if not ni_right.balanced:
        return ni_right

    # Both the left and right subtrees are balanced. But it could be that
    # the tree rooted at the current node is not balanced.
    big = max(ni_left.height, ni_right.height)
    small = min(ni_left.height, ni_right.height)
    delta = big - small

    if delta > 1:
        return NodeInfo(False, -1)

    # Current tree is balanced. The height of the tree is the maximum of the
    # heights of either subtree, plus 1.
    return NodeInfo(True, max(ni_left.height, ni_right.height) + 1)

With the heavy lifting out of the way, our entrypoint function can just return the first boolean ("balanced") field.

code(6/6) 🔗
def is_height_balanced__optimal(x: Optional[Node]) -> bool:
    return get_balanced_status(x).balanced
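
To make the returned values concrete, the sketch below restates the helper compactly so that it runs on its own. SketchNode, SketchInfo, and sketch_status are hypothetical stand-ins for Node, NodeInfo, and get_balanced_status, used only for this illustration.

```python
from __future__ import annotations
from dataclasses import dataclass
from typing import NamedTuple, Optional


@dataclass
class SketchNode:
    # Hypothetical stand-in for the library's Node.
    left: Optional[SketchNode] = None
    right: Optional[SketchNode] = None


class SketchInfo(NamedTuple):
    balanced: bool
    height: int


def sketch_status(x: Optional[SketchNode]) -> SketchInfo:
    # Compact restatement of the postorder helper described above.
    if x is None:
        return SketchInfo(True, -1)
    left = sketch_status(x.left)
    if not left.balanced:
        return left
    right = sketch_status(x.right)
    if not right.balanced:
        return right
    if abs(left.height - right.height) > 1:
        return SketchInfo(False, -1)
    return SketchInfo(True, max(left.height, right.height) + 1)


leaf = SketchNode()
# Root with a two-node left branch and a single right child.
small = SketchNode(SketchNode(SketchNode()), SketchNode())
# Root with a two-node left branch and no right child.
lopsided = SketchNode(SketchNode(SketchNode()))

print(sketch_status(None))      # SketchInfo(balanced=True, height=-1)
print(sketch_status(leaf))      # SketchInfo(balanced=True, height=0)
print(sketch_status(small))     # SketchInfo(balanced=True, height=2)
print(sketch_status(lopsided))  # SketchInfo(balanced=False, height=-1)
```

Once a subtree is reported as unbalanced, the height field is no longer meaningful, which is why -1 is an acceptable placeholder there.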

In summary, we do a single postorder traversal, and during this traversal we return early as soon as we detect that any subtree is not balanced.

Time complexity is \(O(n)\) where \(n\) is the number of nodes in the tree, because each node is visited at most once and does a constant amount of work. Space complexity is \(O(h)\) where \(h\) is the height of the tree; here "space" means the amount of space used by the function call stack as we undergo recursion. Each stack frame holds only a constant number of NodeInfo values (ni_left, ni_right, and the NodeInfo object returned by get_balanced_status), so no memory beyond the call stack grows with the number of nodes in the tree.
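
The dependence on height rather than node count can be observed through Python's recursion limit. The following is a rough sketch, assuming a hypothetical SketchNode class and a plain recursive height function in place of the real helper: a chain slightly deeper than the limit fails, while a perfect tree with at least as many nodes is handled without trouble.

```python
from __future__ import annotations
import sys
from dataclasses import dataclass
from typing import Optional


@dataclass
class SketchNode:
    # Hypothetical stand-in for the library's Node.
    left: Optional[SketchNode] = None
    right: Optional[SketchNode] = None


def recursive_height(x: Optional[SketchNode]) -> int:
    # Recursion depth equals the number of nodes on the longest downward path.
    if x is None:
        return -1
    return max(recursive_height(x.left), recursive_height(x.right)) + 1


def chain(length: int) -> Optional[SketchNode]:
    # Built iteratively so that construction itself does not recurse.
    node = None
    for _ in range(length):
        node = SketchNode(left=node)
    return node


def perfect(levels: int) -> Optional[SketchNode]:
    # Perfect tree with 2**levels - 1 nodes; recursion depth is only `levels`.
    if levels == 0:
        return None
    return SketchNode(perfect(levels - 1), perfect(levels - 1))


n = sys.getrecursionlimit() + 100
deep = chain(n)                 # n nodes, height n - 1
wide = perfect(n.bit_length())  # at least n nodes, height n.bit_length() - 1

wide_height = recursive_height(wide)

try:
    recursive_height(deep)
    chain_hit_limit = False
except RecursionError:
    chain_hit_limit = True

print(wide_height == n.bit_length() - 1)  # True
print(chain_hit_limit)                    # True
```

For very deep trees, one option would be an explicit stack instead of recursion, though that is outside what this article implements.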

4. Tests

from __future__ import annotations
from hypothesis import given, strategies as st
from typing import Optional, NamedTuple
import unittest

# The six code blocks above (code 1/6 through 6/6) are inserted here.

class Test(unittest.TestCase):
    def test_height(self):
        t = BinaryTree()
        self.assertEqual(height(t.root), -1)

        t = BinaryTree()
        t.insert(0, [])
        self.assertEqual(height(t.root), 0)

        t = BinaryTree()
        t.insert(0, [0, 0, 0])
        self.assertEqual(height(t.root), 3)
        self.assertEqual(height(t.root.left), 2)
        self.assertEqual(height(t.root.left.left), 1)
        self.assertEqual(height(t.root.left.left.left), 0)
        self.assertEqual(height(t.root.left.left.left.left), -1)

        t = BinaryTree()
        t.insert(0, [0, 0, 0])
        t.insert(0, [1, 1, 1])
        self.assertEqual(height(t.root), 3)

        t = BinaryTree()
        t.insert(0, [0])
        t.insert(0, [1, 1, 1])
        self.assertEqual(height(t.root), 3)

    def test_is_height_balanced(self):
        t = BinaryTree()
        self.assertEqual(is_height_balanced__brute_force(t.root), True)
        self.assertEqual(is_height_balanced__brute_force_preorder(t.root), True)
        self.assertEqual(is_height_balanced__brute_force_inorder(t.root), True)
        self.assertEqual(is_height_balanced__optimal(t.root), True)

        t = BinaryTree()
        t.insert(0, [])
        self.assertEqual(is_height_balanced__brute_force(t.root), True)
        self.assertEqual(is_height_balanced__brute_force_preorder(t.root), True)
        self.assertEqual(is_height_balanced__brute_force_inorder(t.root), True)
        self.assertEqual(is_height_balanced__optimal(t.root), True)

        t = BinaryTree()
        t.insert(0, [])
        t.insert(0, [0])
        self.assertEqual(is_height_balanced__brute_force(t.root), True)
        self.assertEqual(is_height_balanced__brute_force_preorder(t.root), True)
        self.assertEqual(is_height_balanced__brute_force_inorder(t.root), True)
        self.assertEqual(is_height_balanced__optimal(t.root), True)

        t = BinaryTree()
        t.insert(0, [])
        t.insert(0, [0])
        t.insert(0, [1])
        self.assertEqual(is_height_balanced__brute_force(t.root), True)
        self.assertEqual(is_height_balanced__brute_force_preorder(t.root), True)
        self.assertEqual(is_height_balanced__brute_force_inorder(t.root), True)
        self.assertEqual(is_height_balanced__optimal(t.root), True)

        t = BinaryTree()
        t.insert(0, [])
        t.insert(0, [0, 0])
        t.insert(0, [1])
        self.assertEqual(is_height_balanced__brute_force(t.root), True)
        self.assertEqual(is_height_balanced__brute_force_preorder(t.root), True)
        self.assertEqual(is_height_balanced__brute_force_inorder(t.root), True)
        self.assertEqual(is_height_balanced__optimal(t.root), True)

        t = BinaryTree()
        t.insert(0, [])
        t.insert(0, [0, 0])
        t.insert(0, [1, 1])
        self.assertEqual(is_height_balanced__brute_force(t.root), True)
        self.assertEqual(is_height_balanced__brute_force_preorder(t.root), True)
        self.assertEqual(is_height_balanced__brute_force_inorder(t.root), True)
        self.assertEqual(is_height_balanced__optimal(t.root), True)

        # Trivial unbalanced example.
        t = BinaryTree()
        t.insert(0, [])
        t.insert(0, [0, 0, 0])
        t.insert(0, [1])
        self.assertEqual(is_height_balanced__brute_force(t.root), False)
        self.assertEqual(is_height_balanced__brute_force_preorder(t.root), False)
        self.assertEqual(is_height_balanced__brute_force_inorder(t.root), False)
        self.assertEqual(is_height_balanced__optimal(t.root), False)

        # Elaborate balanced example (same as in the diagram in Figure 1).
        t = BinaryTree()
        t.insert(1, [])
        t.insert(2, [0])
        t.insert(3, [0, 0])
        t.insert(8, [0, 1])
        t.insert(9, [0, 1, 0])
        t.insert(10, [0, 1, 1])
        t.insert(4, [0, 0, 0])
        t.insert(7, [0, 0, 1])
        t.insert(5, [0, 0, 0, 0])
        t.insert(6, [0, 0, 0, 1])
        t.insert(11, [1])
        t.insert(12, [1, 0])
        t.insert(13, [1, 0, 0])
        t.insert(14, [1, 0, 1])
        t.insert(15, [1, 1])
        self.assertEqual(is_height_balanced__brute_force(t.root), True)
        self.assertEqual(is_height_balanced__brute_force_preorder(t.root), True)
        self.assertEqual(is_height_balanced__brute_force_inorder(t.root), True)
        self.assertEqual(is_height_balanced__optimal(t.root), True)

    # Populate a random binary tree. Check that the optimal solution agrees with
    # the brute force approach.
    @given(st.lists(st.lists(st.integers(min_value=0, max_value=1),
                            min_size=0,
                            max_size=32),
                    min_size=0,
                    max_size=32))
    def test_is_height_balanced__random(self, paths: list[list[int]]):
        t = BinaryTree()
        for path in paths:
            t.insert(0, path)

        bf = is_height_balanced__brute_force(t.root)
        bf_pre = is_height_balanced__brute_force_preorder(t.root)
        bf_in = is_height_balanced__brute_force_inorder(t.root)
        o = is_height_balanced__optimal(t.root)

        self.assertEqual(bf, bf_pre)
        self.assertEqual(bf, bf_in)
        self.assertEqual(bf, o)

if __name__ == "__main__":
    unittest.main(exit=False)

5. References

Aziz, A., Lee, T.-H., & Prakash, A. (2018). Elements of Programming Interviews in Python: The Insiders’ Guide. CreateSpace Independent Publishing Platform (25 July 2018).