Inverting a Binary Tree

Binary Tree Inversion
DSA
data-structures
introduction
algorithms
python
binary-tree
Published

September 1, 2024

Intuition

The intuition for inverting a binary tree is to swap the left and right children of each node recursively or iteratively. We can visualize this as flipping the tree horizontally, where each node’s left child becomes its right child and vice versa. The key part for me, which may seem obvious, was realising that a binary tree is a data structure where each node has at most two child nodes.
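To make the "horizontal flip" concrete, here is a minimal sketch that models a tree as nested tuples of the form (value, left, right). The tuple representation and the function name `mirror` are assumptions made purely for illustration; they are not the node class used in the rest of this post.

```python
def mirror(tree):
    """Return a new tree with every node's left and right children swapped.

    A tree is either None (empty) or a tuple (value, left, right).
    This tuple layout is a hypothetical stand-in for a real node class.
    """
    if tree is None:
        return None
    value, left, right = tree
    # The mirrored right subtree becomes the new left child, and vice versa.
    return (value, mirror(right), mirror(left))


tree = (4, (2, None, None), (7, (6, None, None), None))
flipped = mirror(tree)
print(flipped)
# (4, (7, None, (6, None, None)), (2, None, None))
```

Mirroring twice gives back the original tree, which is a handy sanity check for any inversion routine.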

Recursive Approach

I initially solved this with a recursive approach based on the above intuition:
1. Start with the root node and define the base cases.
   1. If the root is None (empty tree), return None.
   2. If the root has no children, it is a leaf node, so return the root as is.
2. For non-leaf nodes, recursively invert the left and right subtrees.
3. Swap the inverted left and right subtrees.
4. Return the root of the inverted tree.

Complexity

  • Time complexity: O(n)
    We visit each node in the tree exactly once, where n is the number of nodes in the tree. Therefore, the time complexity is linear.

  • Space complexity: O(h)
    The space complexity is O(h), where h is the height of the tree. This is due to the recursive call stack. In the worst case (a completely unbalanced tree), this could be O(n), but for a balanced tree, it would be O(log n).
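The O(h) bound can be observed directly by tracking how deep the recursion goes. The sketch below is only an illustration: `Node`, `chain`, `perfect` and `invert_with_depth` are hypothetical helpers, height is counted in levels, and the recursion depth over real nodes is used as a proxy for call-stack usage.

```python
from typing import Optional


class Node:
    """Hypothetical minimal node used only for this sketch."""

    def __init__(self, val: int, left: "Optional[Node]" = None, right: "Optional[Node]" = None):
        self.val = val
        self.left = left
        self.right = right


def invert_with_depth(node: Optional[Node], depth: int = 1) -> int:
    """Invert the subtree in place and return the deepest level of real nodes visited."""
    if node is None:
        return depth - 1  # only calls on real nodes are counted
    node.left, node.right = node.right, node.left
    return max(
        depth,
        invert_with_depth(node.left, depth + 1),
        invert_with_depth(node.right, depth + 1),
    )


def chain(n: int) -> Optional[Node]:
    """Build a completely unbalanced tree: n nodes, each with only a left child."""
    root = None
    for val in range(n, 0, -1):
        root = Node(val, left=root)
    return root


def perfect(levels: int) -> Optional[Node]:
    """Build a perfect binary tree with the given number of levels."""
    if levels == 0:
        return None
    return Node(levels, perfect(levels - 1), perfect(levels - 1))


skewed_depth = invert_with_depth(chain(100))    # 100 nodes, 100 levels
balanced_depth = invert_with_depth(perfect(7))  # 127 nodes, 7 levels
print(skewed_depth, balanced_depth)
# 100 7
```

The skewed tree needs one stack frame per node, while the perfect tree with more nodes needs only seven.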

Code

from typing import Optional

class TreeNode:
    def __init__(self, val: int = 0, left: Optional["TreeNode"] = None, right: Optional["TreeNode"] = None):
        self.val = val
        self.left = left
        self.right = right

class Solution:
    def invert_tree_recursive(self, root: Optional[TreeNode]) -> Optional[TreeNode]:
        
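        # Base case: an empty tree has nothing to invert.
        if not root:
            return None
        # Base case: a leaf has no children to swap.
        if not root.right and not root.left:
            return root

        # Invert both subtrees, then swap them in a single tuple assignment.
        root.left, root.right = self.invert_tree_recursive(root.right), self.invert_tree_recursive(root.left)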
        return root

def print_tree(node: Optional[TreeNode], level: int = 0, prefix: str = "Root: ") -> None:
    if node is not None:
        print(" " * (level * 4) + prefix + str(node.val))
        if node.left or node.right:
            print_tree(node.left, level + 1, "L--- ")
            print_tree(node.right, level + 1, "R--- ")

root = TreeNode(4)
root.left = TreeNode(2)
root.right = TreeNode(7)
root.left.left = TreeNode(1)
root.left.right = TreeNode(3)
root.right.left = TreeNode(6)
root.right.right = TreeNode(9)

print("Original Tree:")
print_tree(root)

solution = Solution()
inverted_root = solution.invert_tree_recursive(root)


print("\nInverted Tree:")
print_tree(inverted_root)

Output:

Original Tree:
Root: 4
    L--- 2
        L--- 1
        R--- 3
    R--- 7
        L--- 6
        R--- 9

Inverted Tree:
Root: 4
    L--- 7
        L--- 9
        R--- 6
    R--- 2
        L--- 3
        R--- 1

Iterative Approach

For an iterative approach, a stack or queue works nicely, as we can do the following:
1. Start with the root node and push it onto the stack (or queue).
2. Take a node from the stack.
3. Swap its left and right children.
4. Add the non-None children to the stack for processing.
5. Repeat from step 2 until the stack is empty.
Complexity

  • Time complexity: O(n)
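    We visit each node in the tree exactly once, where n is the number of nodes in the tree, performing a constant time operation (swapping the nodes), so the while loop runs n times. The total is linear provided that taking a node off the queue is itself constant time, as with deque.popleft(); list.pop(0) shifts every remaining element and costs time proportional to the length of the list.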

  • Space complexity: O(n)

    • The queue holds O(w) nodes, where w is the maximum width of the tree (the largest number of nodes on any single level).
    • The worst case occurs in a “perfect” binary tree, where the last level is completely full. At its maximum, the queue could contain all the leaf nodes, which in a perfect binary tree is (n+1)/2 nodes (slightly more than half of all nodes).
    • Being balanced does not reduce this: a balanced tree has height O(log n), but its width can still be proportional to n, so the queue remains O(n).
    • For a skewed tree (e.g. a linked list), it would be O(1) as we’d only ever have one node in the queue at a time.
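To check these bounds empirically, the sketch below records the largest number of nodes the queue ever holds while inverting a perfect tree and a skewed tree of the same size. The helper names (`Node`, `perfect`, `chain`, `invert_and_measure`) are hypothetical and exist only for this illustration.

```python
from collections import deque
from typing import Optional


class Node:
    """Hypothetical minimal node used only for this sketch."""

    def __init__(self, val: int, left: "Optional[Node]" = None, right: "Optional[Node]" = None):
        self.val = val
        self.left = left
        self.right = right


def invert_and_measure(root: Optional[Node]) -> int:
    """Invert in place with a FIFO queue; return the peak number of queued nodes."""
    if root is None:
        return 0
    queue = deque([root])
    peak = 1
    while queue:
        current = queue.popleft()
        current.left, current.right = current.right, current.left
        if current.left:
            queue.append(current.left)
        if current.right:
            queue.append(current.right)
        peak = max(peak, len(queue))
    return peak


def perfect(levels: int) -> Optional[Node]:
    """Build a perfect binary tree with the given number of levels."""
    if levels == 0:
        return None
    return Node(levels, perfect(levels - 1), perfect(levels - 1))


def chain(n: int) -> Optional[Node]:
    """Build a skewed tree: n nodes, each with only a left child."""
    root = None
    for val in range(n, 0, -1):
        root = Node(val, left=root)
    return root


perfect_peak = invert_and_measure(perfect(4))  # 15 nodes, 8 leaves
skewed_peak = invert_and_measure(chain(15))    # 15 nodes in a single line
print(perfect_peak, skewed_peak)
# 8 1
```

For n = 15 the perfect tree peaks at (n+1)/2 = 8 queued nodes, while the skewed tree never holds more than one.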

Code

from collections import deque
from typing import Optional

class TreeNode:
    def __init__(self, val: int = 0, left: Optional["TreeNode"] = None, right: Optional["TreeNode"] = None):
        self.val = val
        self.left = left
        self.right = right

class Solution:
    def invert_tree_iterative(self, root: Optional[TreeNode]) -> Optional[TreeNode]:
        if not root:
            return None

        # A deque removes from the front in O(1); list.pop(0) would shift every remaining element.
        nodes_to_process = deque([root])
        while nodes_to_process:
            current = nodes_to_process.popleft()
            # Swap the children, then queue whichever ones exist.
            current.left, current.right = current.right, current.left
            if current.left:
                nodes_to_process.append(current.left)
            if current.right:
                nodes_to_process.append(current.right)

        return root

def print_tree(node: Optional[TreeNode], level: int = 0, prefix: str = "Root: ") -> None:
    if node is not None:
        print(" " * (level * 4) + prefix + str(node.val))
        if node.left or node.right:
            print_tree(node.left, level + 1, "L--- ")
            print_tree(node.right, level + 1, "R--- ")

root = TreeNode(4)
root.left = TreeNode(2)
root.right = TreeNode(7)
root.left.left = TreeNode(1)
root.left.right = TreeNode(3)
root.right.left = TreeNode(6)
root.right.right = TreeNode(9)

print("Original Tree:")
print_tree(root)

solution = Solution()
inverted_root = solution.invert_tree_iterative(root)


print("\nInverted Tree:")
print_tree(inverted_root)

Output:

Original Tree:
Root: 4
    L--- 2
        L--- 1
        R--- 3
    R--- 7
        L--- 6
        R--- 9

Inverted Tree:
Root: 4
    L--- 7
        L--- 9
        R--- 6
    R--- 2
        L--- 3
        R--- 1

References

1. LeetCode 226: Invert Binary Tree