Dutch National Flag (3-way partitioning)

1. Problem statement

Given N objects colored red, white, or blue, sort them so that objects of the same color are grouped together. The groups should appear in the order red, white, blue (Aziz et al., 2018, p. 43). Bonus: return the number of red and white objects, if any.

2. Insights

Instead of 3 colors, we can have any other kind of property. For example, we could say "less than the middle number", "equal to the middle number", and "greater than the middle number". Seen this way, the partition step of quicksort (where the "middle number" is called the pivot) is closely related.
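As a rough sketch of this generalization, the three "colors" can be produced by comparing each number against a pivot. The helper names classify and three_way_groups are hypothetical (they are not part of the solution in this document), and this version builds new lists instead of working in place, purely to illustrate the grouping.

```python
def classify(num, pivot):
    """Map a number to 0 (less), 1 (equal), or 2 (greater) relative to pivot."""
    if num < pivot:
        return 0
    if num == pivot:
        return 1
    return 2


def three_way_groups(nums, pivot):
    """Return (less, equal, greater) as three new lists.

    Illustration only: this is not in-place, unlike a real partition step.
    """
    groups = ([], [], [])
    for num in nums:
        groups[classify(num, pivot)].append(num)
    return groups
```

Swapping classify for a function that maps red, white, and blue to 0, 1, and 2 would give back the original flag problem.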

If we just consider 2 colors, then this is essentially the same as the "Rearrange list" problem which partitioned a list of numbers into 2 groups, all even (if any) and odd (if any). There we used 2 indices to keep track of the left (even) and right (odd) groups, and we had 3 subgroups: even, unknown, odd. We shrank the unknown group to 0, and when this happened (when the even and odd indices crossed each other), we were done.
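A minimal sketch of the 2-color idea follows, assuming the two indices start at opposite ends of the list and that we always inspect the number at idx_odd. The function name partition_even_odd is hypothetical, and this is not necessarily the exact code used in "Rearrange list".

```python
def partition_even_odd(nums):
    """Move even numbers to the left and odd numbers to the right, in place.

    Returns the count of even numbers. The "unknown" group is the slice
    nums[idx_even:idx_odd + 1]; it shrinks by one on every iteration.
    """
    idx_even = 0
    idx_odd = len(nums) - 1
    while idx_even <= idx_odd:
        if nums[idx_odd] % 2 == 1:
            # Already odd, so it belongs to the right group.
            idx_odd -= 1
        else:
            # Even: hand it over to the left group.
            nums[idx_even], nums[idx_odd] = nums[idx_odd], nums[idx_even]
            idx_even += 1
    return idx_even
```

The loop stops as soon as the two indices cross, which is exactly the moment the unknown group becomes empty.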

For arranging 3 colors, we could use 3 indices (red, white, blue), and keep track of 4 subgroups: red, white, unknown, blue. Everything starts as "unknown", but we whittle this down to 0 (when the white and blue indices cross).

3. Solution

For our solution, we don't use colors and instead use quicksort's naming scheme.

def dutch_flag_partition(nums, idx_pivot):
    if not nums:
        return 0, 0

    # If the given idx_pivot is out of bounds, just use 0. This is the pivot
    # *index*, so we will use the value (whatever it is) at index 0 as the
    # *pivot*.
    if idx_pivot < 0 or idx_pivot >= len(nums):
        idx_pivot = 0

    pivot = nums[idx_pivot]
    # Track numbers less than pivot.
    idx_lt = 0
    # Track numbers equal to pivot.
    idx_eq = 0
    # Track numbers greater than pivot. Note that this index is out-of-bounds.
    idx_gt = len(nums)

    while idx_eq < idx_gt:
        if nums[idx_eq] == pivot:
            idx_eq += 1
        elif nums[idx_eq] < pivot:
            nums[idx_lt], nums[idx_eq] = nums[idx_eq], nums[idx_lt]
            idx_lt += 1
            idx_eq += 1
        else:
            idx_gt -= 1
            nums[idx_eq], nums[idx_gt] = nums[idx_gt], nums[idx_eq]

    # Return (1) the count of numbers less than the pivot, and (2) the count of
    # numbers equal to the pivot.

    return idx_lt, idx_eq - idx_lt

The while condition

    while idx_eq < idx_gt:

just checks whether the "unknown" group still contains any unchecked numbers.

As we iterate through the loop, we always check the number pointed to by idx_eq (similar to how we always kept an eye on idx_odd in "Rearrange list").

Now let's consider the simplest case:

        if nums[idx_eq] == pivot:
            idx_eq += 1

This is straightforward: if idx_eq is looking at a number that is equal to the pivot, we "count" it by incrementing idx_eq. This is the happy path.

Now let's examine the case where idx_eq is pointing to a number less than the pivot:

        elif nums[idx_eq] < pivot:
            nums[idx_lt], nums[idx_eq] = nums[idx_eq], nums[idx_lt]
            idx_lt += 1
            idx_eq += 1

In this case, idx_eq just found a number smaller than the pivot, so it "trades" it with idx_lt via a swap. We must increment idx_lt because it "grew" by 1 number.

But how come we increment idx_eq here? Don't we need to examine the number that was swapped in on the next iteration (without incrementing idx_eq)? To answer this question, let's look at the 2 possibilities for idx_lt and idx_eq: (1) these indices are equal (such as in the very first iteration of the loop when both are 0), or (2) idx_eq has gotten ahead of idx_lt by running the first if-branch when nums[idx_eq] is equal to the pivot. In both cases we want to increment idx_eq.

For the first case, the swap is a NOP. We want to increment idx_eq here because we want to examine the next number in the list.

For the second case, this means that at some point, idx_eq advanced on its own, leaving idx_lt behind. The only way that idx_eq can advance on its own is if nums[idx_eq] is equal to the pivot. This means that when idx_eq first got ahead of idx_lt, they were both pointing to a number P which was equal to the pivot (after which idx_eq advanced on its own, per the first if-branch). So when idx_eq sees a number smaller than the pivot, it knows that idx_lt, which fell behind, is already looking at a number equal to the pivot! So after the swap, idx_eq is pointing at P again (and so it must advance). This is a little tricky, so let's see an example.

Here we expect idx_lt and idx_eq to both advance in lockstep for some time, as we keep hitting 0 initially which is less than the pivot.

pivot = 1

0  0  0  1  0  0  1  (numbers)
0  1  2  3  4  5  6  (list index)
`- idx_lt
`- idx_eq

When we see 1, idx_eq advances on its own. Note how idx_lt "saves" an element equal to the pivot (by continuing to point to it) for a possible future swap with idx_eq.

0  0  0  1  0  0  1  (numbers)
0  1  2  3  4  5  6  (list index)
         `- idx_lt
         `- idx_eq

This is the critical step. idx_eq is looking at a number smaller than the pivot…

0  0  0  1  0  0  1  (numbers)
0  1  2  3  4  5  6  (list index)
         |  |
         `- idx_lt
            `- idx_eq

…so it performs a swap! idx_lt and idx_eq each got the number they wanted, so we must increment both of them.

0  0  0  0  1  0  1  (numbers)
0  1  2  3  4  5  6  (list index)
         |  |
         `- idx_lt
            `- idx_eq
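
The walk-through above can be replayed with a small tracing sketch. The generator name trace_partition is hypothetical; it takes the pivot value directly (not a pivot index) and works on a copy of the list, yielding a snapshot before every loop iteration and once more at the end.

```python
def trace_partition(nums, pivot):
    """Yield (snapshot, idx_lt, idx_eq, idx_gt) for each step of the loop.

    Works on a copy, so the caller's list is left untouched.
    """
    nums = list(nums)
    idx_lt, idx_eq, idx_gt = 0, 0, len(nums)
    while idx_eq < idx_gt:
        yield list(nums), idx_lt, idx_eq, idx_gt
        if nums[idx_eq] == pivot:
            idx_eq += 1
        elif nums[idx_eq] < pivot:
            nums[idx_lt], nums[idx_eq] = nums[idx_eq], nums[idx_lt]
            idx_lt += 1
            idx_eq += 1
        else:
            idx_gt -= 1
            nums[idx_eq], nums[idx_gt] = nums[idx_gt], nums[idx_eq]
    # Final state, after the "unknown" group has shrunk to nothing.
    yield list(nums), idx_lt, idx_eq, idx_gt


states = list(trace_partition([0, 0, 0, 1, 0, 0, 1], 1))
```

With this input, states[4] corresponds to the critical step (idx_lt at 3, idx_eq at 4), and states[5] shows the list right after the swap, with both indices already incremented.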

So much for the interplay between idx_lt and idx_eq. Now we come to the final condition where idx_eq sees a number greater than the pivot:

        else:
            idx_gt -= 1
            nums[idx_eq], nums[idx_gt] = nums[idx_gt], nums[idx_eq]

In this case, we "throw over" the number to the right side of the list (to idx_gt), and in return get whatever number idx_gt was pointing to. Because of off-by-one issues, we have to decrement idx_gt first. Note that we do not increment idx_eq here, because we have no idea what idx_gt was pointing at previously.
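
To see why idx_eq must stay put in this branch, here is a small experiment. The function partition_copy and its advance_after_throw flag are hypothetical and exist only for this illustration; setting the flag to True deliberately introduces the bug of skipping the swapped-in number.

```python
def partition_copy(nums, pivot, advance_after_throw=False):
    """Partition a copy of nums around the pivot value and return the copy.

    advance_after_throw=True is intentionally WRONG: it increments idx_eq
    after the "throw over" swap, so the swapped-in number is never examined.
    """
    nums = list(nums)
    idx_lt, idx_eq, idx_gt = 0, 0, len(nums)
    while idx_eq < idx_gt:
        if nums[idx_eq] == pivot:
            idx_eq += 1
        elif nums[idx_eq] < pivot:
            nums[idx_lt], nums[idx_eq] = nums[idx_eq], nums[idx_lt]
            idx_lt += 1
            idx_eq += 1
        else:
            idx_gt -= 1
            nums[idx_eq], nums[idx_gt] = nums[idx_gt], nums[idx_eq]
            if advance_after_throw:
                idx_eq += 1  # Deliberate bug, for demonstration only.
    return nums


good = partition_copy([2, 1, 2], 1)
bad = partition_copy([2, 1, 2], 1, advance_after_throw=True)
```

Here good is [1, 2, 2], while bad stays [2, 1, 2]: the 2 that was swapped in from the right end is skipped and ends up stranded in front of the pivot.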

4. Tests

from hypothesis import given, strategies as st
import unittest


class Test(unittest.TestCase):
    cases = [
        ([],                    0, [],                    0, 0),
        ([0],                   0, [0],                   0, 1),
        ([0, 0, 0, 0],          0, [0, 0, 0, 0],          0, 4),
        ([0, 1, 2, 0, 1, 2, 1], 1, [0, 0, 1, 1, 1, 2, 2], 2, 3),
        ([1, 1, 0],             1, [0, 1, 1],             1, 2),
    ]

    def test_simple_cases(self):
        for parts in self.cases:
            given_nums, idx_pivot, expected_nums, expected_lt, expected_eq = parts
            got_lt, got_eq = dutch_flag_partition(given_nums, idx_pivot)

            self.assertEqual(given_nums, expected_nums)
            self.assertEqual(got_lt, expected_lt)
            self.assertEqual(got_eq, expected_eq)

    @given(st.lists(st.integers(min_value=0, max_value=100), min_size=16),
           st.integers(min_value=0, max_value=15))
    def test_random(self, given_nums, idx_pivot):
        # Save the pivot now, because given_nums will get modified in-place (which
        # means given_nums[idx_pivot] could point to a different value later).
        pivot = given_nums[idx_pivot]
        got_lt, got_eq = dutch_flag_partition(given_nums, idx_pivot)
        len_given_nums = len(given_nums)

        # If all elements are the same, then that means that all of them are equal
        # to the pivot (no matter which idx_pivot we choose, we'll always end up
        # with a pivot that is the same as all other numbers in the list). This
        # means that the count of numbers equal to the pivot must be same as the
        # length of the entire list.
        if all(num == given_nums[0] for num in given_nums):
            self.assertEqual(got_eq, len_given_nums)

        # The subtotals of elements less than and equal to the pivot must never
        # exceed the length of the entire list.
        self.assertLessEqual(got_lt + got_eq, len_given_nums)

        # The numbers greater than the pivot occupy the tail of the list. We can
        # deduce the size of that tail by subtracting got_lt and got_eq from the
        # list length; every number in it must indeed be greater than the pivot.
        if got_lt + got_eq < len_given_nums:
            expected_gt = len_given_nums - got_lt - got_eq
            for i in range(len_given_nums - expected_gt, len_given_nums):
                self.assertGreater(given_nums[i], pivot)

if __name__ == "__main__":
    unittest.main()

5. Final remarks

The idea of partitioning is useful for implementing quicksort, because we can recursively call this function to sort a large list. There are two variations to do quicksort-style partitioning:

  1. (Dutch National Flag style, aka "fat pivot" (Bentley, 2000, p. 123)) elements equal to the pivot are gathered together in a single partitioning step, and
  2. (traditional quicksort) there are 2 partitions (elements less than, elements greater than), one on each side of a single "pivot" element.

The "fat pivot" style as presented here performs better when the input contains a significant number of duplicate elements (Sedgewick & Wayne, 2011, p. 296). For example, if the entire list consists of duplicates (e.g., every element is "7"), the traditional 2-way partitioning step of quicksort will still recurse repeatedly (because each step only puts 1 element, the pivot, into its final position), whereas the "fat pivot" style will not need to recurse at all, as it gathers all the duplicate elements in a single invocation.
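
As a hedged sketch of how the partition step could drive a quicksort, the following adapts the same loop to work on a sub-range of the list. The names partition3 and quicksort3, the choice of the first element as pivot, and the returned step count are all assumptions made for this illustration.

```python
def partition3(nums, lo, hi):
    """Fat-pivot partition of nums[lo:hi] around nums[lo], in place.

    Returns (idx_lt, idx_gt): nums[lo:idx_lt] is less than the pivot,
    nums[idx_lt:idx_gt] equals it, and nums[idx_gt:hi] is greater.
    """
    pivot = nums[lo]
    idx_lt, idx_eq, idx_gt = lo, lo, hi
    while idx_eq < idx_gt:
        if nums[idx_eq] == pivot:
            idx_eq += 1
        elif nums[idx_eq] < pivot:
            nums[idx_lt], nums[idx_eq] = nums[idx_eq], nums[idx_lt]
            idx_lt += 1
            idx_eq += 1
        else:
            idx_gt -= 1
            nums[idx_eq], nums[idx_gt] = nums[idx_gt], nums[idx_eq]
    return idx_lt, idx_gt


def quicksort3(nums, lo=0, hi=None):
    """Sort nums[lo:hi] in place; return how many partition steps ran."""
    if hi is None:
        hi = len(nums)
    if hi - lo < 2:
        return 0
    idx_lt, idx_gt = partition3(nums, lo, hi)
    # The middle (equal) block is already in its final position.
    return 1 + quicksort3(nums, lo, idx_lt) + quicksort3(nums, idx_gt, hi)


sevens = [7] * 8
steps_for_sevens = quicksort3(sevens)
```

For the all-sevens list, steps_for_sevens is 1: the single partition step finds that everything equals the pivot, so both recursive calls receive empty ranges and return immediately.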

A variation of the "fat pivot" algorithm puts the elements equal to the pivot on both the left and right ends of the array (Sedgewick & Wayne, 2011, p. 306), and does a final swap of these elements into the middle of the list (to sit between the less-than and greater-than elements). This variant "appears to be best in practice" (Knuth, 1997, p. 635), and emerged from "15 further years of experience" following Robert Sedgewick's analysis of traditional quicksort in the late 1970s (Knuth, 1997, p. 122).

(Cormen et al., 2009, p. 186) sadly only poses the Dutch National Flag as a problem and leaves the analysis of it mostly up to the reader.

(Skiena, 2008, pp. 125–126) gives a good intuitive explanation of traditional quicksort's performance characteristics.

