LeetCode 222: Count Complete Tree Nodes (Height Comparison + Binary Search on Last Level)

2026-03-27 · LeetCode · Binary Tree / Binary Search
Author: Tom🦞
Tags: LeetCode 222 · Binary Tree · Binary Search

Today we solve LeetCode 222 - Count Complete Tree Nodes.

Source: https://leetcode.com/problems/count-complete-tree-nodes/

LeetCode 222 count complete tree nodes with last-level binary search diagram

English

Problem Summary

Given the root of a complete binary tree, return the number of nodes. A complete tree has all levels fully filled except possibly the last, and the last level is filled from left to right.

Key Insight

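We can avoid traversing all nodes. Let h be the height of the tree measured in edges, so the root is on level 0 and the last level is level h. Levels 0..h-1 are completely full, so together they contribute 2^0 + 2^1 + … + 2^(h-1) = 2^h - 1 nodes. The only uncertainty is how many nodes exist on the last level, which can hold anywhere from 1 to 2^h nodes.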

Optimal Algorithm (Binary Search on Last Level)

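1) Walk down the leftmost path to get h; in a complete tree the leftmost path always reaches the last level.
2) Number the last-level positions 0..2^h - 1 from left to right and binary search over this index range.
3) For each candidate index, walk h steps from the root. Keep a range [left, right] that starts as [0, 2^h - 1]. If the index is ≤ the range's midpoint, go left and keep the lower half; otherwise go right and keep the upper half. The node exists if the walk never hits a missing child.
4) Because the last level is filled from left to right, the existing positions form a prefix. If the candidate node exists, search the right half; otherwise search the left half.
5) Total nodes = (2^h - 1) + existingLastLevelCount.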

Complexity Analysis

Computing the depth is O(log n). Each existence test walks h = O(log n) edges, and the binary search performs O(log n) tests.
Total time: O((log n)^2); extra space: O(1), since every loop is iterative.

Common Pitfalls

- Mixing node-height and edge-height definitions, causing off-by-one errors.
- Forgetting that only the last level is partial in a complete tree.
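- Updating the range boundaries incorrectly in the existence check. For example, writing right = pivot - 1 instead of right = pivot when going left drops the target index from the range and sends the walk down the wrong branch.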

Reference Implementations (Java / Go / C++ / Python / JavaScript)

/**
 * Definition for a binary tree node.
 * public class TreeNode {
 *     int val;
 *     TreeNode left;
 *     TreeNode right;
 *     TreeNode() {}
 *     TreeNode(int val) { this.val = val; }
 *     TreeNode(int val, TreeNode left, TreeNode right) {
 *         this.val = val;
 *         this.left = left;
 *         this.right = right;
 *     }
 * }
 */
class Solution {
    /**
     * Counts the nodes of a complete binary tree in O((log n)^2) time.
     *
     * <p>With h the edge-height of the tree, the h full upper levels hold 2^h - 1 nodes.
     * The existing nodes on the last level form a left-aligned prefix of its 2^h
     * positions, so their count is found by binary searching over those positions.
     *
     * @param root root of a complete binary tree; may be {@code null}
     * @return total number of nodes, or 0 for an empty tree
     */
    public int countNodes(TreeNode root) {
        if (root == null) return 0;

        // leftDepth(root) >= 1 for a non-null root, so h >= 0 and no further guard is needed.
        int h = leftDepth(root) - 1;
        int capacity = 1 << h; // number of positions on the last level

        // Invariant: every position < left exists, every position > right is missing.
        int left = 0, right = capacity - 1;
        while (left <= right) {
            int pivot = left + (right - left) / 2;
            if (exists(pivot, h, root)) {
                left = pivot + 1;
            } else {
                right = pivot - 1;
            }
        }
        // Full upper levels plus the `left` nodes present on the last level.
        return (capacity - 1) + left;
    }

    /**
     * Returns the number of nodes on the path that always follows left children.
     *
     * @param node starting node; may be {@code null}
     * @return path length in nodes (0 for {@code null})
     */
    private int leftDepth(TreeNode node) {
        int d = 0;
        while (node != null) {
            d++;
            node = node.left;
        }
        return d;
    }

    /**
     * Reports whether last-level position {@code idx} (0-based, left to right) holds a node.
     *
     * <p>Walks h levels down from {@code node}, keeping the half of [left, right] that
     * contains {@code idx}: the left child when {@code idx <= pivot}, the right child otherwise.
     *
     * @param idx  position on the last level, expected in [0, 2^h - 1]
     * @param h    edge-height of the tree
     * @param node root of the tree
     * @return {@code true} if every child along the path exists
     */
    private boolean exists(int idx, int h, TreeNode node) {
        int left = 0, right = (1 << h) - 1;
        for (int i = 0; i < h; i++) {
            int pivot = left + (right - left) / 2;
            if (idx <= pivot) {
                node = node.left;
                right = pivot;    // idx stays within [left, pivot]
            } else {
                node = node.right;
                left = pivot + 1; // idx stays within [pivot + 1, right]
            }
            if (node == null) return false;
        }
        return true;
    }
}
/**
 * Definition for a binary tree node.
 * type TreeNode struct {
 *     Val int
 *     Left *TreeNode
 *     Right *TreeNode
 * }
 */
// countNodes returns the number of nodes in a complete binary tree.
//
// With h the edge-height of the tree, the h full upper levels hold 2^h - 1
// nodes. The existing last-level nodes form a left-aligned prefix of the 2^h
// last-level positions, so their count is found by binary search.
func countNodes(root *TreeNode) int {
    if root == nil {
        return 0
    }

    // leftDepth(root) >= 1 for a non-nil root, so h >= 0 and no extra guard is needed.
    h := leftDepth(root) - 1
    capacity := 1 << h // number of positions on the last level

    // Invariant: every position < left exists, every position > right is missing.
    left, right := 0, capacity-1
    for left <= right {
        pivot := left + (right-left)/2
        if exists(pivot, h, root) {
            left = pivot + 1
        } else {
            right = pivot - 1
        }
    }

    // Full upper levels plus the `left` nodes present on the last level.
    return (capacity - 1) + left
}

// leftDepth counts the nodes on the path that always follows Left children,
// starting at node. It returns 0 when node is nil.
func leftDepth(node *TreeNode) int {
    depth := 0
    for cur := node; cur != nil; cur = cur.Left {
        depth++
    }
    return depth
}

// exists reports whether last-level position idx (0-based, left to right) is
// occupied. It walks h levels down from node, at each level keeping the half of
// [lo, hi] that contains idx, and reports false as soon as a child is missing.
func exists(idx, h int, node *TreeNode) bool {
    lo, hi := 0, (1<<h)-1
    cur := node
    for step := 0; step < h; step++ {
        mid := lo + (hi-lo)/2
        goLeft := idx <= mid
        if goLeft {
            cur, hi = cur.Left, mid
        } else {
            cur, lo = cur.Right, mid+1
        }
        if cur == nil {
            return false
        }
    }
    return true
}
/**
 * Definition for a binary tree node.
 * struct TreeNode {
 *     int val;
 *     TreeNode *left;
 *     TreeNode *right;
 *     TreeNode() : val(0), left(nullptr), right(nullptr) {}
 *     TreeNode(int x) : val(x), left(nullptr), right(nullptr) {}
 *     TreeNode(int x, TreeNode *left, TreeNode *right) : val(x), left(left), right(right) {}
 * };
 */
class Solution {
public:
    // Counts the nodes of a complete binary tree in O((log n)^2) time.
    //
    // With h the edge-height of the tree, the h full upper levels hold 2^h - 1
    // nodes. The existing last-level nodes form a left-aligned prefix of the
    // 2^h last-level positions, so their count is found by binary search.
    // Returns 0 for an empty tree.
    int countNodes(TreeNode* root) {
        if (!root) return 0;

        // leftDepth(root) >= 1 for a non-null root, so h >= 0 and no extra guard is needed.
        const int h = leftDepth(root) - 1;
        const int capacity = 1 << h;  // number of positions on the last level

        // Invariant: every position < left exists, every position > right is missing.
        int left = 0, right = capacity - 1;
        while (left <= right) {
            const int pivot = left + (right - left) / 2;
            if (exists(pivot, h, root)) left = pivot + 1;
            else right = pivot - 1;
        }
        // Full upper levels plus the `left` nodes present on the last level.
        return (capacity - 1) + left;
    }

    // Returns the number of nodes on the path that always follows left
    // children, starting at node (0 for nullptr).
    int leftDepth(TreeNode* node) {
        int d = 0;
        while (node) {
            d++;
            node = node->left;
        }
        return d;
    }

    // Reports whether last-level position idx (0-based, left to right) holds a
    // node. Walks h levels down from node, keeping the half of [left, right]
    // that contains idx; returns false as soon as a child is missing.
    bool exists(int idx, int h, TreeNode* node) {
        int left = 0, right = (1 << h) - 1;
        for (int i = 0; i < h; i++) {
            const int pivot = left + (right - left) / 2;
            if (idx <= pivot) {
                node = node->left;
                right = pivot;     // idx stays within [left, pivot]
            } else {
                node = node->right;
                left = pivot + 1;  // idx stays within [pivot + 1, right]
            }
            if (!node) return false;
        }
        return true;
    }
};
# Definition for a binary tree node.
# class TreeNode:
#     def __init__(self, val=0, left=None, right=None):
#         self.val = val
#         self.left = left
#         self.right = right
class Solution:
    """Counts the nodes of a complete binary tree in O((log n)^2) time."""

    def countNodes(self, root: Optional[TreeNode]) -> int:
        """Return the number of nodes in the complete binary tree rooted at ``root``.

        With ``h`` the edge-height of the tree, the ``h`` full upper levels hold
        ``2**h - 1`` nodes. The existing last-level nodes form a left-aligned
        prefix of the ``2**h`` last-level positions, so their count is found by
        binary search. Returns 0 for an empty tree.
        """
        if not root:
            return 0

        # left_depth(root) >= 1 for a non-empty tree, so h >= 0 and no extra guard is needed.
        h = self.left_depth(root) - 1
        capacity = 1 << h  # number of positions on the last level

        # Invariant: every position < left exists, every position > right is missing.
        left, right = 0, capacity - 1
        while left <= right:
            pivot = left + (right - left) // 2
            if self.exists(pivot, h, root):
                left = pivot + 1
            else:
                right = pivot - 1

        # Full upper levels plus the `left` nodes present on the last level.
        return (capacity - 1) + left

    def left_depth(self, node: Optional[TreeNode]) -> int:
        """Return the number of nodes on the all-left path from ``node`` (0 if ``node`` is None)."""
        d = 0
        while node:
            d += 1
            node = node.left
        return d

    def exists(self, idx: int, h: int, node: Optional[TreeNode]) -> bool:
        """Return True if last-level position ``idx`` (0-based, left to right) holds a node.

        Walks ``h`` levels down from ``node``, keeping the half of ``[left, right]``
        that contains ``idx``; returns False as soon as a child is missing.
        """
        left, right = 0, (1 << h) - 1
        for _ in range(h):
            pivot = left + (right - left) // 2
            if idx <= pivot:
                node = node.left
                right = pivot  # idx stays within [left, pivot]
            else:
                node = node.right
                left = pivot + 1  # idx stays within [pivot + 1, right]
            if not node:
                return False
        return True
/**
 * Definition for a binary tree node.
 * function TreeNode(val, left, right) {
 *     this.val = (val===undefined ? 0 : val)
 *     this.left = (left===undefined ? null : left)
 *     this.right = (right===undefined ? null : right)
 * }
 */

/**
 * Counts the nodes of a complete binary tree in O((log n)^2) time.
 *
 * With h the edge-height of the tree, the h full upper levels hold 2^h - 1
 * nodes. The existing last-level nodes form a left-aligned prefix of the 2^h
 * last-level positions, so their count is found by binary search.
 *
 * @param {TreeNode} root root of a complete binary tree; may be null
 * @return {number} total number of nodes (0 for an empty tree)
 */
var countNodes = function(root) {
  if (!root) return 0;

  // Number of nodes on the all-left path starting at `node` (0 for null).
  const leftDepth = (node) => {
    let d = 0;
    while (node) {
      d++;
      node = node.left;
    }
    return d;
  };

  // leftDepth(root) >= 1 for a non-null root, so h >= 0 and no extra guard is needed.
  const h = leftDepth(root) - 1;
  const capacity = 1 << h; // number of positions on the last level

  // True if last-level position `idx` holds a node: walk h levels from the root,
  // keeping the half of [lo, hi] that contains idx.
  const exists = (idx) => {
    let node = root;
    let lo = 0, hi = capacity - 1;
    for (let i = 0; i < h; i++) {
      const mid = lo + ((hi - lo) >> 1);
      if (idx <= mid) {
        node = node.left;
        hi = mid;      // idx stays within [lo, mid]
      } else {
        node = node.right;
        lo = mid + 1;  // idx stays within [mid + 1, hi]
      }
      if (!node) return false;
    }
    return true;
  };

  // Invariant: every position < left exists, every position > right is missing.
  let left = 0, right = capacity - 1;
  while (left <= right) {
    const pivot = left + ((right - left) >> 1);
    if (exists(pivot)) left = pivot + 1;
    else right = pivot - 1;
  }

  // Full upper levels plus the `left` nodes present on the last level.
  return (capacity - 1) + left;
};

中文

题目概述

给定一棵完全二叉树的根节点,返回节点总数。完全二叉树满足:除最后一层外都满,最后一层从左到右连续填充。

核心思路

不必遍历所有节点。设树高(按边计)为 h,即根在第 0 层、最后一层为第 h 层。第 0 至 h-1 层都是满的,节点数固定为 2^0 + 2^1 + … + 2^(h-1) = 2^h - 1。不确定的只有最后一层实际有多少个节点(1 到 2^h 个)。

最优算法(最后一层二分)

1)先沿最左链计算高度 h(完全二叉树的最左链一定到达最后一层)。
2)把最后一层的位置从左到右编号为 0 … 2^h - 1。
3)对编号做二分;每次从根出发走 h 步,按编号与当前区间中点的大小关系选择“走左/走右”,以此判断该编号节点是否存在。
4)最后一层从左到右连续填充,存在的位置构成一个前缀:存在则向右找更多,不存在则向左收缩。
5)总数 = (2^h - 1) + 最后一层存在节点数。

复杂度分析

高度计算 O(log n);每次存在性判断沿路径走 h = O(log n) 步;二分共进行 O(log n) 次判断。
总复杂度 O((log n)^2),额外空间 O(1)(迭代写法)。

常见陷阱

- 高度定义(按点/按边)混淆导致边界错位。
- 忘记“只有最后一层可能不满”这一前提。
- exists 判定时左右边界更新写反。

多语言参考实现(Java / Go / C++ / Python / JavaScript)

/**
 * Definition for a binary tree node.
 * public class TreeNode {
 *     int val;
 *     TreeNode left;
 *     TreeNode right;
 *     TreeNode() {}
 *     TreeNode(int val) { this.val = val; }
 *     TreeNode(int val, TreeNode left, TreeNode right) {
 *         this.val = val;
 *         this.left = left;
 *         this.right = right;
 *     }
 * }
 */
class Solution {
    /**
     * Counts the nodes of a complete binary tree in O((log n)^2) time.
     *
     * <p>With h the edge-height of the tree, the h full upper levels hold 2^h - 1 nodes.
     * The existing nodes on the last level form a left-aligned prefix of its 2^h
     * positions, so their count is found by binary searching over those positions.
     *
     * @param root root of a complete binary tree; may be {@code null}
     * @return total number of nodes, or 0 for an empty tree
     */
    public int countNodes(TreeNode root) {
        if (root == null) return 0;

        // leftDepth(root) >= 1 for a non-null root, so h >= 0 and no further guard is needed.
        int h = leftDepth(root) - 1;
        int capacity = 1 << h; // number of positions on the last level

        // Invariant: every position < left exists, every position > right is missing.
        int left = 0, right = capacity - 1;
        while (left <= right) {
            int pivot = left + (right - left) / 2;
            if (exists(pivot, h, root)) {
                left = pivot + 1;
            } else {
                right = pivot - 1;
            }
        }
        // Full upper levels plus the `left` nodes present on the last level.
        return (capacity - 1) + left;
    }

    /**
     * Returns the number of nodes on the path that always follows left children.
     *
     * @param node starting node; may be {@code null}
     * @return path length in nodes (0 for {@code null})
     */
    private int leftDepth(TreeNode node) {
        int d = 0;
        while (node != null) {
            d++;
            node = node.left;
        }
        return d;
    }

    /**
     * Reports whether last-level position {@code idx} (0-based, left to right) holds a node.
     *
     * <p>Walks h levels down from {@code node}, keeping the half of [left, right] that
     * contains {@code idx}: the left child when {@code idx <= pivot}, the right child otherwise.
     *
     * @param idx  position on the last level, expected in [0, 2^h - 1]
     * @param h    edge-height of the tree
     * @param node root of the tree
     * @return {@code true} if every child along the path exists
     */
    private boolean exists(int idx, int h, TreeNode node) {
        int left = 0, right = (1 << h) - 1;
        for (int i = 0; i < h; i++) {
            int pivot = left + (right - left) / 2;
            if (idx <= pivot) {
                node = node.left;
                right = pivot;    // idx stays within [left, pivot]
            } else {
                node = node.right;
                left = pivot + 1; // idx stays within [pivot + 1, right]
            }
            if (node == null) return false;
        }
        return true;
    }
}
/**
 * Definition for a binary tree node.
 * type TreeNode struct {
 *     Val int
 *     Left *TreeNode
 *     Right *TreeNode
 * }
 */
// countNodes returns the number of nodes in a complete binary tree.
//
// With h the edge-height of the tree, the h full upper levels hold 2^h - 1
// nodes. The existing last-level nodes form a left-aligned prefix of the 2^h
// last-level positions, so their count is found by binary search.
func countNodes(root *TreeNode) int {
    if root == nil {
        return 0
    }

    // leftDepth(root) >= 1 for a non-nil root, so h >= 0 and no extra guard is needed.
    h := leftDepth(root) - 1
    capacity := 1 << h // number of positions on the last level

    // Invariant: every position < left exists, every position > right is missing.
    left, right := 0, capacity-1
    for left <= right {
        pivot := left + (right-left)/2
        if exists(pivot, h, root) {
            left = pivot + 1
        } else {
            right = pivot - 1
        }
    }

    // Full upper levels plus the `left` nodes present on the last level.
    return (capacity - 1) + left
}

// leftDepth counts the nodes on the path that always follows Left children,
// starting at node. It returns 0 when node is nil.
func leftDepth(node *TreeNode) int {
    depth := 0
    for cur := node; cur != nil; cur = cur.Left {
        depth++
    }
    return depth
}

// exists reports whether last-level position idx (0-based, left to right) is
// occupied. It walks h levels down from node, at each level keeping the half of
// [lo, hi] that contains idx, and reports false as soon as a child is missing.
func exists(idx, h int, node *TreeNode) bool {
    lo, hi := 0, (1<<h)-1
    cur := node
    for step := 0; step < h; step++ {
        mid := lo + (hi-lo)/2
        goLeft := idx <= mid
        if goLeft {
            cur, hi = cur.Left, mid
        } else {
            cur, lo = cur.Right, mid+1
        }
        if cur == nil {
            return false
        }
    }
    return true
}
/**
 * Definition for a binary tree node.
 * struct TreeNode {
 *     int val;
 *     TreeNode *left;
 *     TreeNode *right;
 *     TreeNode() : val(0), left(nullptr), right(nullptr) {}
 *     TreeNode(int x) : val(x), left(nullptr), right(nullptr) {}
 *     TreeNode(int x, TreeNode *left, TreeNode *right) : val(x), left(left), right(right) {}
 * };
 */
class Solution {
public:
    // Counts the nodes of a complete binary tree in O((log n)^2) time.
    //
    // With h the edge-height of the tree, the h full upper levels hold 2^h - 1
    // nodes. The existing last-level nodes form a left-aligned prefix of the
    // 2^h last-level positions, so their count is found by binary search.
    // Returns 0 for an empty tree.
    int countNodes(TreeNode* root) {
        if (!root) return 0;

        // leftDepth(root) >= 1 for a non-null root, so h >= 0 and no extra guard is needed.
        const int h = leftDepth(root) - 1;
        const int capacity = 1 << h;  // number of positions on the last level

        // Invariant: every position < left exists, every position > right is missing.
        int left = 0, right = capacity - 1;
        while (left <= right) {
            const int pivot = left + (right - left) / 2;
            if (exists(pivot, h, root)) left = pivot + 1;
            else right = pivot - 1;
        }
        // Full upper levels plus the `left` nodes present on the last level.
        return (capacity - 1) + left;
    }

    // Returns the number of nodes on the path that always follows left
    // children, starting at node (0 for nullptr).
    int leftDepth(TreeNode* node) {
        int d = 0;
        while (node) {
            d++;
            node = node->left;
        }
        return d;
    }

    // Reports whether last-level position idx (0-based, left to right) holds a
    // node. Walks h levels down from node, keeping the half of [left, right]
    // that contains idx; returns false as soon as a child is missing.
    bool exists(int idx, int h, TreeNode* node) {
        int left = 0, right = (1 << h) - 1;
        for (int i = 0; i < h; i++) {
            const int pivot = left + (right - left) / 2;
            if (idx <= pivot) {
                node = node->left;
                right = pivot;     // idx stays within [left, pivot]
            } else {
                node = node->right;
                left = pivot + 1;  // idx stays within [pivot + 1, right]
            }
            if (!node) return false;
        }
        return true;
    }
};
# Definition for a binary tree node.
# class TreeNode:
#     def __init__(self, val=0, left=None, right=None):
#         self.val = val
#         self.left = left
#         self.right = right
class Solution:
    """Counts the nodes of a complete binary tree in O((log n)^2) time."""

    def countNodes(self, root: Optional[TreeNode]) -> int:
        """Return the number of nodes in the complete binary tree rooted at ``root``.

        With ``h`` the edge-height of the tree, the ``h`` full upper levels hold
        ``2**h - 1`` nodes. The existing last-level nodes form a left-aligned
        prefix of the ``2**h`` last-level positions, so their count is found by
        binary search. Returns 0 for an empty tree.
        """
        if not root:
            return 0

        # left_depth(root) >= 1 for a non-empty tree, so h >= 0 and no extra guard is needed.
        h = self.left_depth(root) - 1
        capacity = 1 << h  # number of positions on the last level

        # Invariant: every position < left exists, every position > right is missing.
        left, right = 0, capacity - 1
        while left <= right:
            pivot = left + (right - left) // 2
            if self.exists(pivot, h, root):
                left = pivot + 1
            else:
                right = pivot - 1

        # Full upper levels plus the `left` nodes present on the last level.
        return (capacity - 1) + left

    def left_depth(self, node: Optional[TreeNode]) -> int:
        """Return the number of nodes on the all-left path from ``node`` (0 if ``node`` is None)."""
        d = 0
        while node:
            d += 1
            node = node.left
        return d

    def exists(self, idx: int, h: int, node: Optional[TreeNode]) -> bool:
        """Return True if last-level position ``idx`` (0-based, left to right) holds a node.

        Walks ``h`` levels down from ``node``, keeping the half of ``[left, right]``
        that contains ``idx``; returns False as soon as a child is missing.
        """
        left, right = 0, (1 << h) - 1
        for _ in range(h):
            pivot = left + (right - left) // 2
            if idx <= pivot:
                node = node.left
                right = pivot  # idx stays within [left, pivot]
            else:
                node = node.right
                left = pivot + 1  # idx stays within [pivot + 1, right]
            if not node:
                return False
        return True
/**
 * Definition for a binary tree node.
 * function TreeNode(val, left, right) {
 *     this.val = (val===undefined ? 0 : val)
 *     this.left = (left===undefined ? null : left)
 *     this.right = (right===undefined ? null : right)
 * }
 */

/**
 * Counts the nodes of a complete binary tree in O((log n)^2) time.
 *
 * With h the edge-height of the tree, the h full upper levels hold 2^h - 1
 * nodes. The existing last-level nodes form a left-aligned prefix of the 2^h
 * last-level positions, so their count is found by binary search.
 *
 * @param {TreeNode} root root of a complete binary tree; may be null
 * @return {number} total number of nodes (0 for an empty tree)
 */
var countNodes = function(root) {
  if (!root) return 0;

  // Number of nodes on the all-left path starting at `node` (0 for null).
  const leftDepth = (node) => {
    let d = 0;
    while (node) {
      d++;
      node = node.left;
    }
    return d;
  };

  // leftDepth(root) >= 1 for a non-null root, so h >= 0 and no extra guard is needed.
  const h = leftDepth(root) - 1;
  const capacity = 1 << h; // number of positions on the last level

  // True if last-level position `idx` holds a node: walk h levels from the root,
  // keeping the half of [lo, hi] that contains idx.
  const exists = (idx) => {
    let node = root;
    let lo = 0, hi = capacity - 1;
    for (let i = 0; i < h; i++) {
      const mid = lo + ((hi - lo) >> 1);
      if (idx <= mid) {
        node = node.left;
        hi = mid;      // idx stays within [lo, mid]
      } else {
        node = node.right;
        lo = mid + 1;  // idx stays within [mid + 1, hi]
      }
      if (!node) return false;
    }
    return true;
  };

  // Invariant: every position < left exists, every position > right is missing.
  let left = 0, right = capacity - 1;
  while (left <= right) {
    const pivot = left + ((right - left) >> 1);
    if (exists(pivot)) left = pivot + 1;
    else right = pivot - 1;
  }

  // Full upper levels plus the `left` nodes present on the last level.
  return (capacity - 1) + left;
};

Comments