Validate binary tree as BST

1. Problem statement

Check if a given binary tree is also (already) a binary search tree (BST) (Aziz et al., 2018, p. 213).

2. Insights

2.1. Acceptable ranges

Each node holds a value, and there is a special relationship between the parent and its left and right children. Namely, the left child cannot hold any value that is greater than the parent. Likewise, the right child cannot hold a value that is smaller than the parent value. These ranges (called "intervals" in math) must be respected by all subsequent child nodes.
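To see why the ranges must be inherited by all descendants and not just the direct children, here is a minimal sketch. The `TreeNode` class and both function names are hypothetical stand-ins, not part of the binary tree implementation used later. It compares a naive check against direct children with a check that carries the ranges down the tree.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class TreeNode:
    # Hypothetical stand-in for a binary tree node (assumption for this sketch).
    val: int
    left: Optional["TreeNode"] = None
    right: Optional["TreeNode"] = None


def parent_only_check(x: Optional[TreeNode]) -> bool:
    # Naive check: compare each node only against its direct children.
    if x is None:
        return True
    if x.left is not None and not x.left.val < x.val:
        return False
    if x.right is not None and not x.val < x.right.val:
        return False
    return parent_only_check(x.left) and parent_only_check(x.right)


def range_check(x: Optional[TreeNode],
                lo: float = float("-inf"),
                hi: float = float("inf")) -> bool:
    # Each node must lie strictly inside the range inherited from all of its
    # ancestors, not just its parent.
    if x is None:
        return True
    if not lo < x.val < hi:
        return False
    return range_check(x.left, lo, x.val) and range_check(x.right, x.val, hi)


# 60 is a fine right child for 25, but it sits in the left subtree of 50.
bad = TreeNode(50, left=TreeNode(25, right=TreeNode(60)), right=TreeNode(75))

print(parent_only_check(bad))  # True (the violation goes unnoticed)
print(range_check(bad))        # False
```

The naive check accepts the tree because every parent-child pair looks fine in isolation; only the inherited range exposes the misplaced 60.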

bst_subtree_ranges.svg

Figure 1: Ranges for BST subtrees

In Figure 1, subtree A can contain keys from negative infinity (assuming we have infinite memory) all the way up to 24 (assuming integer keys). Similar rules apply for all other subtrees. We can summarise the acceptable intervals as follows:

\begin{align*} A &= ({-\infty},{25}) \\ B &= (25,50) \\ C &= (50,75) \\ D &= (75, {\infty}). \end{align*}

In Aziz et al. (2018, p. 213), the authors use closed interval ranges, such that the above would instead look like this:

\begin{align*} A &= ({-\infty},{25}] \\ B &= [25,50] \\ C &= [50,75] \\ D &= [75, {\infty}). \end{align*}

A square bracket denotes that the number next to it is included in the range, whereas a parenthesis means that it is excluded. For example, assuming we're using integers and not real numbers, \([25,50]\) means the values 25, 26, …, 50, whereas \((25,50)\) means the values 26, 27, …, 49.
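As a small illustration for integer keys, the two kinds of interval can be sketched with Python's `range`. The helper names `open_interval` and `closed_interval` are hypothetical and exist only for this example.

```python
from typing import List


def open_interval(lo: int, hi: int) -> List[int]:
    # Integers strictly between lo and hi, i.e. (lo, hi).
    return list(range(lo + 1, hi))


def closed_interval(lo: int, hi: int) -> List[int]:
    # Integers from lo to hi with both endpoints included, i.e. [lo, hi].
    return list(range(lo, hi + 1))


closed = closed_interval(25, 50)
opened = open_interval(25, 50)

print(closed[0], closed[-1])  # 25 50
print(opened[0], opened[-1])  # 26 49
```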

However, we don't use the closed-interval definition because, for simplicity, we assume that our BST does not allow the same key to appear on different nodes (i.e., duplicate keys are not allowed in our BST).

3. Solution

We'll be using our binary tree implementation to solve this problem.

3.1. Brute force?

In Aziz et al. (2018, p. 213), the authors describe a brute force approach where, for each node (treating it as the root of its own independent binary tree), we find the maximum key in its left subtree and the minimum key in its right subtree, and then check that the node's own value lies between them.

Ease of implementation, and the ensuing trust that the code is obviously free of bugs, is the most important consideration for any brute force solution.

This could be considered "brute force" but it seems overly contrived. The problem with this approach is that it appears more difficult to code than the better alternatives presented in the text. Therefore we skip implementing this solution.

3.2. DFS

3.2.1. By recording traversal history

A fundamental property of BSTs is that if you do an in-order traversal (a depth-first search (DFS) where we visit the left child, the current node, then the right child, recursively), you'll get the keys in sorted order. So we can do an in-order traversal and check that each key we visit is greater than or equal to the previously visited key.

We write x.val to get the "key" of the node, because this is a regular BinaryTree, not a BinarySearchTree that has both key and val fields. For a discussion about why we need keys and values for BSTs but not regular binary trees, see the discussion about binary search trees.

def is_sorted(lst: List[int]) -> bool:
    return all(a <= b for a, b in zip(lst, lst[1:]))

def dfs_inorder(t: BinaryTree) -> bool:
    traversal_history = []
    def record(x: Node):
        traversal_history.append(x.val)
    t.traverse_inorder(record)

    return is_sorted(traversal_history)

The time complexity is \(O(n)\) where \(n\) is the number of nodes in the tree. This version needs an additional \(O(n)\) space because it stores the keys of all the nodes it visits before doing the check with is_sorted().
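To see the sorted-order property in isolation, here is a minimal self-contained sketch. It assumes a hypothetical `TreeNode` class and an `inorder()` generator, neither of which is part of the binary tree implementation used in this section.

```python
from dataclasses import dataclass
from typing import Iterator, List, Optional


@dataclass
class TreeNode:
    # Hypothetical stand-in for a binary tree node (assumption for this sketch).
    val: int
    left: Optional["TreeNode"] = None
    right: Optional["TreeNode"] = None


def inorder(x: Optional[TreeNode]) -> Iterator[int]:
    # Visit the left subtree, then the node itself, then the right subtree.
    if x is None:
        return
    yield from inorder(x.left)
    yield x.val
    yield from inorder(x.right)


def is_sorted(lst: List[int]) -> bool:
    return all(a <= b for a, b in zip(lst, lst[1:]))


valid = TreeNode(50,
                 left=TreeNode(25, TreeNode(10), TreeNode(30)),
                 right=TreeNode(75, TreeNode(60), TreeNode(90)))
invalid = TreeNode(50, left=TreeNode(25, right=TreeNode(60)), right=TreeNode(75))

print(list(inorder(valid)))    # [10, 25, 30, 50, 60, 75, 90]
print(list(inorder(invalid)))  # [25, 60, 50, 75]
print(is_sorted(list(inorder(valid))), is_sorted(list(inorder(invalid))))  # True False
```

The misplaced 60 shows up as a drop from 60 to 50 in the traversal history, which is exactly what the sortedness check detects.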

3.2.2. Without recording traversal history

def dfs_inorder_manual(t: BinaryTree) -> bool:
    if t.root is None:
        return True

    last = None
    def helper(x: Optional[Node]) -> bool:
        nonlocal last

        if x is None:
            return True

        if not helper(x.left):
            return False

        if last is not None and last > x.val:
            return False
        last = x.val

        if not helper(x.right):
            return False

        return True

    return helper(t.root)

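This solution improves on the space usage by keeping a single last variable instead of the full traversal history, so the extra storage for keys drops to \(O(1)\). The recursion itself still uses \(O(h)\) call-stack space, where \(h\) is the height of the tree.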

Sadly, the problem with both dfs_inorder_manual() and dfs_inorder() is that they will traverse the left subtree completely before traversing the right subtree. So if the problematic node is near the top of the tree on the right, we might waste a lot of CPU cycles before detecting the (obvious) issue.

3.2.3. Using ranges

As mentioned in 2, the key to this problem is to check if each node satisfies the range constraints imposed by its ancestors. We can do this with a depth-first search (DFS) by recursing into the subtrees and checking that the nodes encountered fall within the expected range. As we descend the tree, we narrow the intervals: going left replaces the upper bound with the current node's value, and going right replaces the lower bound. If we traverse the entire tree without finding a violation, the tree is a valid BST.

The DFS should be a preorder traversal, because we should check the node that we are currently on as soon as possible before descending down to the subtrees.

def dfs_preorder_range(t: BinaryTree) -> bool:
    def helper(x: Optional[Node], lo=None, hi=None) -> bool:
        if x is None:
            return True

        # Root node. Check children directly, because our starting ranges lo and
        # hi are both None (so the range checks below are skipped).
        if lo is None and hi is None:
            if x.left and not x.left.val < x.val:
                return False
            if x.right and not x.val < x.right.val:
                return False

        # Non-root node. Compare against None explicitly, because a bound of 0
        # is falsy and its check would otherwise be skipped.
        if lo is not None and not lo < x.val:
            return False
        if hi is not None and not x.val < hi:
            return False
        return (helper(x.left, lo, x.val) and
                helper(x.right, x.val, hi))

    return helper(t.root)

While we are using the ranges idea here, we still perform a DFS in which the left subtree is explored first. So it suffers from the same issue described for the DFS solutions in 3.2.2 above.

3.3. Optimal (BFS)

Using breadth-first search (BFS), we can check each level of the binary tree, top to bottom. For each node we visit, we just check that its key fits within the range that we expect it to be in.

We also simplify the code a bit by using floats instead of integers, which lets us use negative and positive infinity. This avoids the extra stanza we need for the root node if we're only using integer types and using None to represent these infinite values.

def bfs(t: BinaryTree) -> bool:

    class QueueEntry(NamedTuple):
        node: Optional[Node]
        lo: float
        hi: float

    q = collections.deque([QueueEntry(t.root, float("-inf"), float("inf"))])

    while q:
        x, lo, hi = q.popleft()

        if x:
            if not lo < x.val < hi:
                return False

            q.append(QueueEntry(x.left, lo, x.val))
            q.append(QueueEntry(x.right, x.val, hi))

    return True

Time complexity is still \(O(n)\), but we solve the problem of traversing the left subtree needlessly in situations where we only need to check something in the right subtree.
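To make the difference concrete, the following sketch counts how many nodes each strategy visits before it rejects a tree whose only problem is the root's right child. Everything here (`TreeNode`, `dfs_visits`, `bfs_visits`, `left_chain`) is a hypothetical stand-in written for this illustration, not the implementation above.

```python
import collections
from dataclasses import dataclass
from typing import Optional, Tuple


@dataclass
class TreeNode:
    # Hypothetical stand-in for a binary tree node (assumption for this sketch).
    val: int
    left: Optional["TreeNode"] = None
    right: Optional["TreeNode"] = None


def dfs_visits(root: Optional[TreeNode]) -> Tuple[bool, int]:
    # Preorder range check that also counts the nodes it looks at.
    visited = 0

    def helper(x: Optional[TreeNode], lo: float, hi: float) -> bool:
        nonlocal visited
        if x is None:
            return True
        visited += 1
        if not lo < x.val < hi:
            return False
        return helper(x.left, lo, x.val) and helper(x.right, x.val, hi)

    ok = helper(root, float("-inf"), float("inf"))
    return ok, visited


def bfs_visits(root: Optional[TreeNode]) -> Tuple[bool, int]:
    # Level-by-level range check that also counts the nodes it looks at.
    visited = 0
    q = collections.deque([(root, float("-inf"), float("inf"))])
    while q:
        x, lo, hi = q.popleft()
        if x is None:
            continue
        visited += 1
        if not lo < x.val < hi:
            return False, visited
        q.append((x.left, lo, x.val))
        q.append((x.right, x.val, hi))
    return True, visited


def left_chain(start: int, length: int) -> Optional[TreeNode]:
    # Build a left-leaning chain of valid keys: start, start - 1, ...
    node = None
    for val in range(start - length + 1, start + 1):
        node = TreeNode(val, left=node)
    return node


# Root 100 with a valid 50-node left chain and an invalid right child (1 < 100).
root = TreeNode(100, left=left_chain(99, 50), right=TreeNode(1))

print(dfs_visits(root))  # (False, 52)
print(bfs_visits(root))  # (False, 3)
```

Both strategies reject the tree, but the DFS walks the whole left chain first, whereas the BFS reaches the bad node on the second level.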

4. Tests

from binary_tree.binary_tree import BinaryTree, Node
import binary_search_tree.binary_search_tree as BST
import collections
from hypothesis import given, strategies as st
from typing import List, NamedTuple, Optional
import unittest

solution

# Utility function for property-based tests.
bst_converter

class Test(unittest.TestCase):
    test_cases

if __name__ == "__main__":
    unittest.main(exit=False)

4.1. Basic tests

def test_basic(self):
    # Empty tree is a valid BST.
    t = BinaryTree()
    result = dfs_inorder(t)
    self.assertEqual(result, True)

    # Basic happy case.
    t = BinaryTree()
    t.insert(50, [])
    t.insert(25, [0])
    t.insert(75, [1])
    result = dfs_inorder(t)
    self.assertEqual(result, True)

    # Basic unhappy cases.
    t = BinaryTree()
    t.insert(1, [])
    t.insert(0, [1])
    result = dfs_inorder(t)
    self.assertEqual(result, False)
    t = BinaryTree()
    t.insert(100, [])
    t.insert(25, [0])
    t.insert(75, [1])
    result = dfs_inorder(t)
    self.assertEqual(result, False)

4.2. Property-based tests

First we need a converter that can convert a BinarySearchTree into a BinaryTree.

By using our binary search tree implementation, we can confidently generate any number of random BSTs, convert them to regular binary trees, and then check whether they pass our validation functions.

# Convert BST into a binary tree. Of the "key" and "val" of each BST node, we
# only preserve the "key".
def convert_keys_to_bt(bst: BST.BinarySearchTree) -> BinaryTree:
    bt = BinaryTree()

    if bst.root is None:
        return bt

    bt.root = Node()

    def helper(x: Optional[BST.Node], cursor: Optional[Node]):
        if x is None:
            return
        if cursor is None:
            return

        # A binary tree node's value is actually the "key" of the BST.
        cursor.val = x.key

        if x.left is not None:
            cursor.left = Node()
            helper(x.left, cursor.left)

        if x.right is not None:
            cursor.right = Node()
            helper(x.right, cursor.right)

    helper(bst.root, bt.root)

    return bt

Let's check that our converter works, by doing the same preorder traversal and checking that we get the same values (keys for BSTs) back for both trees.

@given(st.lists(st.integers(min_value=1, max_value=50),
                min_size=1,
                max_size=50))
def test_converter(self, keys: List[int]):
    bst = BST.BinarySearchTree()
    for key in keys:
        bst.insert(key)

    traversal_history_bst = []
    def record_traversal_history_bst(x: BST.Node):
        traversal_history_bst.append(x.key)
    bst.traverse_preorder(record_traversal_history_bst)

    bt = convert_keys_to_bt(bst)

    traversal_history_bt = []
    def record_traversal_history_bt(x: Node):
        traversal_history_bt.append(x.val)
    bt.traverse_preorder(record_traversal_history_bt)

    self.assertEqual(traversal_history_bt, traversal_history_bst)

Generate random BSTs, and convert them to binary trees. Our validation functions are all expected to pass for these inputs.

@given(st.lists(st.integers(min_value=1, max_value=50),
                min_size=1,
                max_size=50))
def test_random_BSTs_are_all_valid_BSTs(self, keys: List[int]):
    bst = BST.BinarySearchTree()
    for key in keys:
        bst.insert(key)

    bt = convert_keys_to_bt(bst)

    self.assertEqual(True, dfs_inorder(bt))
    self.assertEqual(True, dfs_inorder_manual(bt))
    self.assertEqual(True, dfs_preorder_range(bt))
    self.assertEqual(True, bfs(bt))

Now let's check the opposite. Construct BSTs. Convert them to binary trees. But then tweak one of the nodes in the tree at random to be out of line to break the BST property. All such trees should fail BST validation.

Note that breaking the BST property involves choosing a very high or very low value. Even if we know that our BST keys are all positive numbers, we cannot always choose a minimum value like "-1" because if we assign this value to the leftmost child, the tree will still be valid.

Note that all BSTs generated here should have at least 2 nodes, because a tree with only a root node is a valid BST no matter what value the root holds.
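The point about the leftmost child can be checked with a small self-contained sketch. The `TreeNode` class, the `is_bst()` range check and the `build()` helper are hypothetical names used only for this illustration.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class TreeNode:
    # Hypothetical stand-in for a binary tree node (assumption for this sketch).
    val: int
    left: Optional["TreeNode"] = None
    right: Optional["TreeNode"] = None


def is_bst(x: Optional[TreeNode],
           lo: float = float("-inf"),
           hi: float = float("inf")) -> bool:
    # Range-based validation with strict bounds (no duplicate keys).
    if x is None:
        return True
    if not lo < x.val < hi:
        return False
    return is_bst(x.left, lo, x.val) and is_bst(x.right, x.val, hi)


def build() -> TreeNode:
    # A valid BST whose keys all lie in the range 1 to 50.
    return TreeNode(30, left=TreeNode(20, left=TreeNode(10)), right=TreeNode(40))


# A very low value on the leftmost node leaves the tree valid.
t = build()
t.left.left.val = -1
low_on_leftmost = is_bst(t)

# A very high value on the same node breaks the BST property.
t = build()
t.left.left.val = 51
high_on_leftmost = is_bst(t)

print(low_on_leftmost)   # True
print(high_on_leftmost)  # False
```

This is why the mutation has to pick the bad value based on where the chosen node sits in the tree.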

@given(st.lists(st.integers(min_value=1, max_value=50),
                min_size=2, # The root must have at least 1 child.
                max_size=50,
                unique=True),
       st.randoms())
def test_invalid_BSTs(self, keys: List[int], rand):
    bst = BST.BinarySearchTree()
    for key in keys:
        bst.insert(key)

    bt = convert_keys_to_bt(bst)

    # Pick a random index. This index will be used to pick a node in the binary
    # tree at random, to assign a BST-property-breaking value. Because our BST
    # keys are in the range 1 to 50, we choose 0 and 51 as the special
    # BST-property-breaking values.
    chosen_idx = rand.randint(0, bst.size() - 1)

    bad_lo = 0
    bad_hi = 51
    seen_idx = 0
    stop = False

    # direction_hist is a 2-bit number, the first bit meaning "I went left", and
    # the second bit meaning "I went right". We use this to check the overall
    # history of which child nodes we've followed to reach the current node
    # under consideration.
    def mutate(x: Optional[Node], direction_hist: int):
        nonlocal seen_idx
        nonlocal stop

        if not x:
            return

        if stop:
            return

        if seen_idx == chosen_idx:
            # Our mutate() function depends on direction_hist, but for the root
            # node this will always be empty. Sometimes the root node may only
            # have one child, instead of two (as basic BSTs are not
            # self-balancing). So if we chose the root node to mutate, take care
            # to check if we only have one child, and choose the bad value
            # accordingly.
            if chosen_idx == 0:
                if x.left is None:
                    x.val = bad_hi
                else:
                    x.val = bad_lo
            else:
                # Non-root node.

                # We only went left so far. There is a chance that we're at the
                # leftmost child node. In all other cases setting a min value (0)
                # would work to break this tree's BST property, but not for this
                # leftmost leaf node. So assign the opposite (just beyond max)
                # value.
                if direction_hist == 1:
                    x.val = bad_hi
                # If we've visited both left and right child nodes so far, then we can
                # assign either bad_lo or bad_hi --- it doesn't matter. This
                # else-condition also covers the case where we only went right, so
                # our hand is forced to choose bad_lo here.
                else:
                    x.val = bad_lo

            # We've performed the mutation we wanted, so stop traversing the
            # tree.
            stop = True

        seen_idx += 1

        mutate(x.left, direction_hist | 1)
        mutate(x.right, direction_hist | 2)

    mutate(bt.root, 0)

    self.assertEqual(False, dfs_inorder(bt))
    self.assertEqual(False, dfs_inorder_manual(bt))
    self.assertEqual(False, dfs_preorder_range(bt))
    self.assertEqual(False, bfs(bt))

5. References

Aziz, A., Lee, T.-H., & Prakash, A. (2018). Elements of Programming Interviews in Python: The Insiders’ Guide. CreateSpace Independent Publishing Platform.