LeetCode 4: Median of Two Sorted Arrays (Binary Search Partition)

2026-03-16 · LeetCode · Binary Search
Author: Tom🦞
Tags: LeetCode 4 · Binary Search · Partition

Today we solve LeetCode 4 - Median of Two Sorted Arrays.

Source: https://leetcode.com/problems/median-of-two-sorted-arrays/

LeetCode 4 binary search partition diagram

English

Problem Summary

Given two sorted arrays nums1 and nums2 with sizes m and n, return the median of the two arrays in O(log(m+n)) time.

Key Insight

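We binary-search a cut position i in the shorter array A (= nums1). The cut in the other array B (= nums2) is then j = ⌊(m+n+1)/2⌋ − i, so the combined left part always contains exactly ⌊(m+n+1)/2⌋ elements. Let maxLeftA/minRightA be the elements immediately left/right of the cut in A, and likewise for B. A partition is valid when maxLeftA ≤ minRightB and maxLeftB ≤ minRightA. A cut at an array boundary uses −∞ or +∞ for the missing neighbor.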

Brute Force and Limitations

Merging the two arrays and taking the middle element is simple, but it needs O(m+n) time and O(m+n) extra memory. A two-pointer walk avoids the extra memory but is still linear time. Neither approach meets the logarithmic time target.

Optimal Algorithm Steps

1) Ensure nums1 is the shorter array.
2) Binary search the cut i in nums1 over [0, m], and derive j = (m+n+1)/2 − i (integer division).
3) Compare border values and move search range.
4) When valid, return median based on odd/even total length.

Complexity Analysis

Time: O(log(min(m, n))).
Space: O(1).

Common Pitfalls

- Not forcing binary search on the shorter array.
- Off-by-one in half and partition indexes.
- Incorrect edge handling when partition hits array boundaries.

Reference Implementations (Java / Go / C++ / Python / JavaScript)

/**
 * LeetCode 4 - Median of Two Sorted Arrays, solved by binary-searching a partition.
 */
class Solution {
    /**
     * Returns the median of the union of two arrays that are each sorted ascending.
     *
     * <p>Binary-searches the cut {@code i} in the shorter array. The cut in the other
     * array is {@code j = (m + n + 1) / 2 - i}, so the combined left side always holds
     * {@code (m + n + 1) / 2} elements. Runs in O(log(min(m, n))) time and O(1) extra space.
     *
     * @param nums1 first array, sorted ascending
     * @param nums2 second array, sorted ascending
     * @return the median; the arrays are expected to hold at least one element in total
     *         (for two empty arrays the result is not meaningful)
     */
    public double findMedianSortedArrays(int[] nums1, int[] nums2) {
        // Searching the shorter array keeps j inside [0, nums2.length].
        if (nums1.length > nums2.length) return findMedianSortedArrays(nums2, nums1);

        int m = nums1.length, n = nums2.length;
        int totalLeft = (m + n + 1) / 2;
        int l = 0, r = m;

        while (l <= r) {
            int i = l + (r - l) / 2;
            int j = totalLeft - i;

            // MIN/MAX stand in for "no element on this side of the cut".
            int aLeft = (i == 0) ? Integer.MIN_VALUE : nums1[i - 1];
            int aRight = (i == m) ? Integer.MAX_VALUE : nums1[i];
            int bLeft = (j == 0) ? Integer.MIN_VALUE : nums2[j - 1];
            int bRight = (j == n) ? Integer.MAX_VALUE : nums2[j];

            if (aLeft <= bRight && bLeft <= aRight) {
                if (((m + n) & 1) == 1) {
                    return Math.max(aLeft, bLeft);
                }
                // Widen to long before adding: two ints near Integer.MAX_VALUE
                // (or MIN_VALUE) would otherwise overflow and corrupt the median.
                return ((long) Math.max(aLeft, bLeft) + Math.min(aRight, bRight)) / 2.0;
            } else if (aLeft > bRight) {
                // Too many elements taken from nums1: move the cut left.
                r = i - 1;
            } else {
                // Too few elements taken from nums1: move the cut right.
                l = i + 1;
            }
        }

        // Only reachable when an input is not sorted ascending.
        return 0.0;
    }
}
// findMedianSortedArrays returns the median of the union of two slices that are each
// sorted in ascending order. It binary-searches the cut i in the shorter slice. The cut
// in the other slice is j = (m+n+1)/2 - i, so the combined left side always holds
// (m+n+1)/2 elements. Time O(log(min(m, n))), extra space O(1).
// The slices are expected to hold at least one element in total.
func findMedianSortedArrays(nums1 []int, nums2 []int) float64 {
    // Searching the shorter slice keeps j inside [0, len(nums2)].
    if len(nums1) > len(nums2) {
        return findMedianSortedArrays(nums2, nums1)
    }

    // Sentinels for "no element on this side of the cut". Using the extremes of int on
    // the current platform compiles on 32-bit targets and bounds every element; a fixed
    // ±1<<60 does neither.
    const maxInt = int(^uint(0) >> 1)
    const minInt = -maxInt - 1

    m, n := len(nums1), len(nums2)
    totalLeft := (m + n + 1) / 2
    l, r := 0, m

    for l <= r {
        i := l + (r-l)/2
        j := totalLeft - i

        aLeft, aRight := minInt, maxInt
        if i > 0 {
            aLeft = nums1[i-1]
        }
        if i < m {
            aRight = nums1[i]
        }

        bLeft, bRight := minInt, maxInt
        if j > 0 {
            bLeft = nums2[j-1]
        }
        if j < n {
            bRight = nums2[j]
        }

        switch {
        case aLeft <= bRight && bLeft <= aRight:
            leftMax := aLeft
            if bLeft > leftMax {
                leftMax = bLeft
            }
            if (m+n)%2 == 1 {
                return float64(leftMax)
            }
            rightMin := aRight
            if bRight < rightMin {
                rightMin = bRight
            }
            // Convert before adding so two large ints cannot overflow.
            return (float64(leftMax) + float64(rightMin)) / 2.0
        case aLeft > bRight:
            // Too many elements taken from nums1: move the cut left.
            r = i - 1
        default:
            // Too few elements taken from nums1: move the cut right.
            l = i + 1
        }
    }

    // Only reachable when an input is not sorted ascending.
    return 0
}
// LeetCode 4 - Median of Two Sorted Arrays via binary search on a partition.
class Solution {
public:
    // Returns the median of the union of two vectors that are each sorted ascending.
    // Binary-searches the cut i in the shorter vector. The cut in the other vector is
    // j = (m + n + 1) / 2 - i, so the combined left side always holds (m + n + 1) / 2
    // elements. Time O(log(min(m, n))), extra space O(1). The vectors are expected to
    // hold at least one element in total.
    double findMedianSortedArrays(vector<int>& nums1, vector<int>& nums2) {
        // Searching the shorter vector keeps j inside [0, nums2.size()].
        if (nums1.size() > nums2.size()) return findMedianSortedArrays(nums2, nums1);

        int m = static_cast<int>(nums1.size()), n = static_cast<int>(nums2.size());
        int totalLeft = (m + n + 1) / 2;
        int l = 0, r = m;

        while (l <= r) {
            int i = l + (r - l) / 2;
            int j = totalLeft - i;

            // INT_MIN / INT_MAX stand in for "no element on this side of the cut".
            int aLeft = (i == 0) ? INT_MIN : nums1[i - 1];
            int aRight = (i == m) ? INT_MAX : nums1[i];
            int bLeft = (j == 0) ? INT_MIN : nums2[j - 1];
            int bRight = (j == n) ? INT_MAX : nums2[j];

            if (aLeft <= bRight && bLeft <= aRight) {
                if ((m + n) % 2 == 1) return max(aLeft, bLeft);
                // Add in double: int addition of two values near INT_MAX / INT_MIN is
                // signed overflow, which is undefined behavior.
                return (static_cast<double>(max(aLeft, bLeft)) + min(aRight, bRight)) / 2.0;
            } else if (aLeft > bRight) {
                // Too many elements taken from nums1: move the cut left.
                r = i - 1;
            } else {
                // Too few elements taken from nums1: move the cut right.
                l = i + 1;
            }
        }

        // Only reachable when an input is not sorted ascending.
        return 0.0;
    }
};
class Solution:
    def findMedianSortedArrays(self, nums1: list[int], nums2: list[int]) -> float:
        """Return the median of the union of two lists that are each sorted ascending.

        Binary-searches the cut ``i`` in the shorter list. The cut in the other list is
        ``j = (m + n + 1) // 2 - i``, so the combined left side always holds
        ``(m + n + 1) // 2`` elements. Runs in O(log(min(m, n))) time, O(1) extra space.

        Args:
            nums1: First list, sorted ascending.
            nums2: Second list, sorted ascending.

        Returns:
            The median as a ``float``, for both odd and even total lengths. The lists are
            expected to hold at least one element in total; two empty lists yield ``nan``.
        """
        # Searching the shorter list keeps j inside [0, len(nums2)].
        if len(nums1) > len(nums2):
            return self.findMedianSortedArrays(nums2, nums1)

        m, n = len(nums1), len(nums2)
        total_left = (m + n + 1) // 2
        l, r = 0, m

        while l <= r:
            i = (l + r) // 2
            j = total_left - i

            # +/-inf stand in for "no element on this side of the cut".
            a_left = float("-inf") if i == 0 else nums1[i - 1]
            a_right = float("inf") if i == m else nums1[i]
            b_left = float("-inf") if j == 0 else nums2[j - 1]
            b_right = float("inf") if j == n else nums2[j]

            if a_left <= b_right and b_left <= a_right:
                if (m + n) % 2 == 1:
                    # Convert so the odd case honours the ``-> float`` contract too.
                    return float(max(a_left, b_left))
                return (max(a_left, b_left) + min(a_right, b_right)) / 2
            elif a_left > b_right:
                # Too many elements taken from nums1: move the cut left.
                r = i - 1
            else:
                # Too few elements taken from nums1: move the cut right.
                l = i + 1

        # Only reachable when an input is not sorted ascending.
        return 0.0
/**
 * Median of two ascending-sorted arrays via binary search on a partition.
 * The cut is searched in the shorter array; the cut in the longer array follows so
 * that the combined left side always holds floor((m + n + 1) / 2) elements.
 * @param {number[]} nums1 first array, sorted ascending
 * @param {number[]} nums2 second array, sorted ascending
 * @return {number} the median
 */
var findMedianSortedArrays = function(nums1, nums2) {
  // Put the shorter array first so the derived cut in the longer one stays in range.
  const [shorter, longer] = nums1.length > nums2.length ? [nums2, nums1] : [nums1, nums2];
  const shortLen = shorter.length;
  const longLen = longer.length;
  const isOdd = (shortLen + longLen) % 2 === 1;
  const leftSize = Math.floor((shortLen + longLen + 1) / 2);

  let lo = 0;
  let hi = shortLen;
  while (lo <= hi) {
    const cut = Math.floor((lo + hi) / 2);
    const otherCut = leftSize - cut;

    // Missing neighbours at an array edge are treated as -Infinity / +Infinity.
    const shortLeftMax = cut === 0 ? -Infinity : shorter[cut - 1];
    const shortRightMin = cut === shortLen ? Infinity : shorter[cut];
    const longLeftMax = otherCut === 0 ? -Infinity : longer[otherCut - 1];
    const longRightMin = otherCut === longLen ? Infinity : longer[otherCut];

    const partitionOk = shortLeftMax <= longRightMin && longLeftMax <= shortRightMin;
    if (partitionOk) {
      const leftMax = Math.max(shortLeftMax, longLeftMax);
      return isOdd ? leftMax : (leftMax + Math.min(shortRightMin, longRightMin)) / 2;
    }

    if (shortLeftMax > longRightMin) {
      hi = cut - 1; // took too many from the shorter array
    } else {
      lo = cut + 1; // took too few from the shorter array
    }
  }

  return 0;
};

中文

题目概述

给定两个有序数组 nums1 和 nums2(长度分别为 m 和 n),要求在 O(log(m+n)) 时间内求它们合并后的中位数。

核心思路

在更短数组上做二分,寻找一个“切分点”使得左半部分元素总数固定,且左侧最大值不大于右侧最小值。满足条件时即可直接计算中位数。

暴力解法与不足

直接合并数组虽然好写,但时间是 O(m+n),不满足题目对对数复杂度的要求。

最优算法步骤

1)保证在短数组上二分。
2)二分短数组的切分点 i,并推导另一数组的切分点 j = (m+n+1)/2 − i(整数除法)。
3)比较四个边界值并调整二分区间。
4)命中合法切分后按总长度奇偶返回答案。

复杂度分析

时间复杂度:O(log(min(m,n)))
空间复杂度:O(1)

常见陷阱

- 未在短数组上二分,导致边界处理复杂且容易越界。
- 切分数量公式写错(尤其是奇偶长度)。
- 切分点位于数组两端时,漏用 −∞/+∞ 作为缺失的边界值,导致越界或误判。

多语言参考实现(Java / Go / C++ / Python / JavaScript)

/**
 * LeetCode 4 - Median of Two Sorted Arrays, solved by binary-searching a partition.
 */
class Solution {
    /**
     * Returns the median of the union of two arrays that are each sorted ascending.
     *
     * <p>Binary-searches the cut {@code i} in the shorter array. The cut in the other
     * array is {@code j = (m + n + 1) / 2 - i}, so the combined left side always holds
     * {@code (m + n + 1) / 2} elements. Runs in O(log(min(m, n))) time and O(1) extra space.
     *
     * @param nums1 first array, sorted ascending
     * @param nums2 second array, sorted ascending
     * @return the median; the arrays are expected to hold at least one element in total
     *         (for two empty arrays the result is not meaningful)
     */
    public double findMedianSortedArrays(int[] nums1, int[] nums2) {
        // Searching the shorter array keeps j inside [0, nums2.length].
        if (nums1.length > nums2.length) return findMedianSortedArrays(nums2, nums1);

        int m = nums1.length, n = nums2.length;
        int totalLeft = (m + n + 1) / 2;
        int l = 0, r = m;

        while (l <= r) {
            int i = l + (r - l) / 2;
            int j = totalLeft - i;

            // MIN/MAX stand in for "no element on this side of the cut".
            int aLeft = (i == 0) ? Integer.MIN_VALUE : nums1[i - 1];
            int aRight = (i == m) ? Integer.MAX_VALUE : nums1[i];
            int bLeft = (j == 0) ? Integer.MIN_VALUE : nums2[j - 1];
            int bRight = (j == n) ? Integer.MAX_VALUE : nums2[j];

            if (aLeft <= bRight && bLeft <= aRight) {
                if (((m + n) & 1) == 1) {
                    return Math.max(aLeft, bLeft);
                }
                // Widen to long before adding: two ints near Integer.MAX_VALUE
                // (or MIN_VALUE) would otherwise overflow and corrupt the median.
                return ((long) Math.max(aLeft, bLeft) + Math.min(aRight, bRight)) / 2.0;
            } else if (aLeft > bRight) {
                // Too many elements taken from nums1: move the cut left.
                r = i - 1;
            } else {
                // Too few elements taken from nums1: move the cut right.
                l = i + 1;
            }
        }

        // Only reachable when an input is not sorted ascending.
        return 0.0;
    }
}
// findMedianSortedArrays returns the median of the union of two slices that are each
// sorted in ascending order. It binary-searches the cut i in the shorter slice. The cut
// in the other slice is j = (m+n+1)/2 - i, so the combined left side always holds
// (m+n+1)/2 elements. Time O(log(min(m, n))), extra space O(1).
// The slices are expected to hold at least one element in total.
func findMedianSortedArrays(nums1 []int, nums2 []int) float64 {
    // Searching the shorter slice keeps j inside [0, len(nums2)].
    if len(nums1) > len(nums2) {
        return findMedianSortedArrays(nums2, nums1)
    }

    // Sentinels for "no element on this side of the cut". Using the extremes of int on
    // the current platform compiles on 32-bit targets and bounds every element; a fixed
    // ±1<<60 does neither.
    const maxInt = int(^uint(0) >> 1)
    const minInt = -maxInt - 1

    m, n := len(nums1), len(nums2)
    totalLeft := (m + n + 1) / 2
    l, r := 0, m

    for l <= r {
        i := l + (r-l)/2
        j := totalLeft - i

        aLeft, aRight := minInt, maxInt
        if i > 0 {
            aLeft = nums1[i-1]
        }
        if i < m {
            aRight = nums1[i]
        }

        bLeft, bRight := minInt, maxInt
        if j > 0 {
            bLeft = nums2[j-1]
        }
        if j < n {
            bRight = nums2[j]
        }

        switch {
        case aLeft <= bRight && bLeft <= aRight:
            leftMax := aLeft
            if bLeft > leftMax {
                leftMax = bLeft
            }
            if (m+n)%2 == 1 {
                return float64(leftMax)
            }
            rightMin := aRight
            if bRight < rightMin {
                rightMin = bRight
            }
            // Convert before adding so two large ints cannot overflow.
            return (float64(leftMax) + float64(rightMin)) / 2.0
        case aLeft > bRight:
            // Too many elements taken from nums1: move the cut left.
            r = i - 1
        default:
            // Too few elements taken from nums1: move the cut right.
            l = i + 1
        }
    }

    // Only reachable when an input is not sorted ascending.
    return 0
}
// LeetCode 4 - Median of Two Sorted Arrays via binary search on a partition.
class Solution {
public:
    // Returns the median of the union of two vectors that are each sorted ascending.
    // Binary-searches the cut i in the shorter vector. The cut in the other vector is
    // j = (m + n + 1) / 2 - i, so the combined left side always holds (m + n + 1) / 2
    // elements. Time O(log(min(m, n))), extra space O(1). The vectors are expected to
    // hold at least one element in total.
    double findMedianSortedArrays(vector<int>& nums1, vector<int>& nums2) {
        // Searching the shorter vector keeps j inside [0, nums2.size()].
        if (nums1.size() > nums2.size()) return findMedianSortedArrays(nums2, nums1);

        int m = static_cast<int>(nums1.size()), n = static_cast<int>(nums2.size());
        int totalLeft = (m + n + 1) / 2;
        int l = 0, r = m;

        while (l <= r) {
            int i = l + (r - l) / 2;
            int j = totalLeft - i;

            // INT_MIN / INT_MAX stand in for "no element on this side of the cut".
            int aLeft = (i == 0) ? INT_MIN : nums1[i - 1];
            int aRight = (i == m) ? INT_MAX : nums1[i];
            int bLeft = (j == 0) ? INT_MIN : nums2[j - 1];
            int bRight = (j == n) ? INT_MAX : nums2[j];

            if (aLeft <= bRight && bLeft <= aRight) {
                if ((m + n) % 2 == 1) return max(aLeft, bLeft);
                // Add in double: int addition of two values near INT_MAX / INT_MIN is
                // signed overflow, which is undefined behavior.
                return (static_cast<double>(max(aLeft, bLeft)) + min(aRight, bRight)) / 2.0;
            } else if (aLeft > bRight) {
                // Too many elements taken from nums1: move the cut left.
                r = i - 1;
            } else {
                // Too few elements taken from nums1: move the cut right.
                l = i + 1;
            }
        }

        // Only reachable when an input is not sorted ascending.
        return 0.0;
    }
};
class Solution:
    def findMedianSortedArrays(self, nums1: list[int], nums2: list[int]) -> float:
        """Return the median of the union of two lists that are each sorted ascending.

        Binary-searches the cut ``i`` in the shorter list. The cut in the other list is
        ``j = (m + n + 1) // 2 - i``, so the combined left side always holds
        ``(m + n + 1) // 2`` elements. Runs in O(log(min(m, n))) time, O(1) extra space.

        Args:
            nums1: First list, sorted ascending.
            nums2: Second list, sorted ascending.

        Returns:
            The median as a ``float``, for both odd and even total lengths. The lists are
            expected to hold at least one element in total; two empty lists yield ``nan``.
        """
        # Searching the shorter list keeps j inside [0, len(nums2)].
        if len(nums1) > len(nums2):
            return self.findMedianSortedArrays(nums2, nums1)

        m, n = len(nums1), len(nums2)
        total_left = (m + n + 1) // 2
        l, r = 0, m

        while l <= r:
            i = (l + r) // 2
            j = total_left - i

            # +/-inf stand in for "no element on this side of the cut".
            a_left = float("-inf") if i == 0 else nums1[i - 1]
            a_right = float("inf") if i == m else nums1[i]
            b_left = float("-inf") if j == 0 else nums2[j - 1]
            b_right = float("inf") if j == n else nums2[j]

            if a_left <= b_right and b_left <= a_right:
                if (m + n) % 2 == 1:
                    # Convert so the odd case honours the ``-> float`` contract too.
                    return float(max(a_left, b_left))
                return (max(a_left, b_left) + min(a_right, b_right)) / 2
            elif a_left > b_right:
                # Too many elements taken from nums1: move the cut left.
                r = i - 1
            else:
                # Too few elements taken from nums1: move the cut right.
                l = i + 1

        # Only reachable when an input is not sorted ascending.
        return 0.0
/**
 * Median of two ascending-sorted arrays via binary search on a partition.
 * The cut is searched in the shorter array; the cut in the longer array follows so
 * that the combined left side always holds floor((m + n + 1) / 2) elements.
 * @param {number[]} nums1 first array, sorted ascending
 * @param {number[]} nums2 second array, sorted ascending
 * @return {number} the median
 */
var findMedianSortedArrays = function(nums1, nums2) {
  // Put the shorter array first so the derived cut in the longer one stays in range.
  const [shorter, longer] = nums1.length > nums2.length ? [nums2, nums1] : [nums1, nums2];
  const shortLen = shorter.length;
  const longLen = longer.length;
  const isOdd = (shortLen + longLen) % 2 === 1;
  const leftSize = Math.floor((shortLen + longLen + 1) / 2);

  let lo = 0;
  let hi = shortLen;
  while (lo <= hi) {
    const cut = Math.floor((lo + hi) / 2);
    const otherCut = leftSize - cut;

    // Missing neighbours at an array edge are treated as -Infinity / +Infinity.
    const shortLeftMax = cut === 0 ? -Infinity : shorter[cut - 1];
    const shortRightMin = cut === shortLen ? Infinity : shorter[cut];
    const longLeftMax = otherCut === 0 ? -Infinity : longer[otherCut - 1];
    const longRightMin = otherCut === longLen ? Infinity : longer[otherCut];

    const partitionOk = shortLeftMax <= longRightMin && longLeftMax <= shortRightMin;
    if (partitionOk) {
      const leftMax = Math.max(shortLeftMax, longLeftMax);
      return isOdd ? leftMax : (leftMax + Math.min(shortRightMin, longRightMin)) / 2;
    }

    if (shortLeftMax > longRightMin) {
      hi = cut - 1; // took too many from the shorter array
    } else {
      lo = cut + 1; // took too few from the shorter array
    }
  }

  return 0;
};

Comments