LeetCode Median of Two Sorted Arrays

LeetCode Median of Two Sorted Arrays
There are two sorted arrays nums1 and nums2 of size m and n respectively. Find the median of the two sorted arrays. The overall run time complexity should be O(log (m+n)).


这道题很有意思,虽然本科的时候就做过练习题,用二分查找递归的减少数据规模,但是当时只写过伪代码,这次真正写代码提交的时候,发现有很多BUG。

这个题的思路有3种:1)两个数组放一起排序,取中位数O((m+n)lg(m+n));2)用归并排序的方式把他们Merge到一块,取中间数O(m+n),其实边Merge边数数,Merge到(m+n)/2个数的时候就已经是中位数了;3)二分的思路,比较A,B数组的中位数,根据大小关系选择在前半段或后半段继续递归。

我开始尝试了二分的思路,很多测试样例不能通过,要考虑奇偶性等问题也导致代码不够优雅。查看网络发现可以把这一题转换为从两个已排序数组A,B中取第k小的元素的问题,如果是m+n是奇数,则取中间那个数,偶数则取中间两个数求平均。

关于从两个已排序数组A,B中取第k小的元素的问题,这篇文章有较详细的讨论。下面简要讲一下最后一种O(lgm+lgn)的方法。

如果A的某个元素A_i和B的某两个连续元素B_jB_{j-1}有关系B_{j-1}\leq A_i\leq B_j,那么可以确定A_i就是AB合起来的第i+j+1小的元素,因为A_i大于A中前i个元素,A_i又大于B_{j-1},而B_{j-1}又大于B前面j-1个元素,所以在AB合起来的数组中,A_i大于A前面i个元素和B前面j个元素,排在第i+j+1的位置。同理如果A_{i-1}\leq B_j\leq A_i,则B_j是第i+j+1小的元素。

如果我们令k=i+j+1,则可以容易得到第k小的元素。关于i,j的取值问题,理论上i,j可以取任意值,只要满足i+j=k-1就可以了。不过常见的取法是先取i为A的中点,然后取j=k-1-i;或者按A,B元素个数的比例取值,比如上述链接中的代码。

上述链接中给出了如果A,B中不存在相同元素时的代码,如下:

int findKthSmallest(int A[], int m, int B[], int n, int k) {
  assert(m >= 0); assert(n >= 0); assert(k > 0); assert(k <= m+n);
  
  int i = (int)((double)m / (m+n) * (k-1));
  int j = (k-1) - i;

  assert(i >= 0); assert(j >= 0); assert(i <= m); assert(j <= n);
  // invariant: i + j = k-1
  // Note: A[-1] = -INF and A[m] = +INF to maintain invariant
  int Ai_1 = ((i == 0) ? INT_MIN : A[i-1]);
  int Bj_1 = ((j == 0) ? INT_MIN : B[j-1]);
  int Ai   = ((i == m) ? INT_MAX : A[i]);
  int Bj   = ((j == n) ? INT_MAX : B[j]);

  if (Bj_1 < Ai && Ai < Bj)
    return Ai;
  else if (Ai_1 < Bj && Bj < Ai)
    return Bj;

  assert((Ai > Bj && Ai_1 > Bj) || 
         (Ai < Bj && Ai < Bj_1));

  // if none of the cases above, then it is either:
  if (Ai < Bj)
    // exclude Ai and below portion
    // exclude Bj and above portion
    return findKthSmallest(A+i+1, m-i-1, B, j, k-i-1);
  else /* Bj < Ai */
    // exclude Ai and above portion
    // exclude Bj and below portion
    return findKthSmallest(A, i, B+j+1, n-j-1, k-j-1);
}

针对本题C++代码如下:

class Solution {
public:
	double findKthSmallest(vector<int>& nums1, vector<int>& nums2, int k)//find kth smallest
	{
		int m = nums1.size(), n = nums2.size();
		//always assume that m is equal or smaller than n
		if (m > n)
			return findKthSmallest(nums2, nums1, k);
		if (m == 0)
			return nums2[k - 1];
		if (k == 1)
			return min(nums1[0], nums2[0]);
		//divide k into two parts
		int i = (int)((double)m / (m + n)*(k - 1));
		int j = (k - 1) - i;
		int Ai_1 = ((i == 0) ? INT_MIN : nums1[i - 1]);
		int Bj_1 = ((j == 0) ? INT_MIN : nums2[j - 1]);
		int Ai = ((i == m) ? INT_MAX : nums1[i]);
		int Bj = ((j == n) ? INT_MAX : nums2[j]);
		if (Bj_1 < Ai && Ai < Bj)
			return Ai;
		else if (Ai_1 < Bj && Bj < Ai)
			return Bj;
		if (Ai < Bj)
		{
			vector<int> v1(nums1.begin() + i + 1, nums1.end());
			vector<int> v2(nums2.begin(), nums2.begin() + j);
			return findKthSmallest(v1, v2, k - i - 1);
		}
		else if (Ai > Bj)
		{
			vector<int> v1(nums1.begin(), nums1.begin() + i);
			vector<int> v2(nums2.begin() + j + 1, nums2.end());
			return findKthSmallest(v1, v2, k - j - 1);
		}
		else
			return Ai;
	}
	double findMedianSortedArrays(vector<int>& nums1, vector<int>& nums2) {
		int total = nums1.size() + nums2.size();
		if (total & 0x01)
			return findKthSmallest(nums1, nums2, total / 2 + 1);
		else
			return (findKthSmallest(nums1, nums2, total / 2) + findKthSmallest(nums1, nums2, total / 2 + 1)) / 2;
	}
};

本代码提交AC,用时52MS。

二刷。

上面的解法真的太复杂了,完全记不住。从两个有序数组中求第K大的数,在http://www.geeksforgeeks.org/k-th-element-two-sorted-arrays/有非常详细的解法,归并O(m+n)的解法就不说了。

O(lgn+lgm)的解法。假设arr1和arr2的中位数分别是arr1[mid1]和arr2[mid2],如果mid1+mid2arr2[mid2]的情况类似。

如果mid1+mid2>k,说明第K大的数比较小,根据上面的分析,下次递归时,可以舍弃掉[start1,mid1], (mid1,end1], [start2,mid2], (mid2,end2]中较大的那块。比如arr1[mid1]>arr2[mid2],则可以舍弃掉(mid2,end2]。

这样每次都可以舍弃掉两个数组中的某个的一半,时间复杂度是O(lgm+lgn),代码在GeeksforGeeks中。

更优的O(lg(k))的方法是,每次我们不是取中值,而是取arr1[k/2]和arr2[k/2],直接根据这两个值的大小关系去分割递归。代码见GeeksforGeeks中最后一个版本的代码,简洁。

针对这一题,求两个有序数组的中位数,如果两个有序数组的总长度len是奇数,则中位数就是两个有序数组中的第len/2大数;如果len是偶数,则中位数是第len/2和len/2的平均数,所以最多需要两次调用findKth就可以得到结果。完整代码如下:

class Solution {
    private:
    int findKth(vector<int>& nums1, int s1, int e1, vector<int>& nums2, int s2, int e2, int k) {
        int len1 = e1 - s1 + 1, len2 = e2 - s2 + 1;
        if(len1 > len2)
            return findKth(nums2, s2, e2, nums1, s1, e1, k);
        if(len1 == 0)
            return nums2[s2 + k - 1];
        if(k == 1)
            return min(nums1[s1], nums2[s2]);
        int i = s1 + min(k / 2, len1) - 1, j = s2 + min(k / 2, len2) - 1;
        if(nums1[i] > nums2[j])
            return findKth(nums1, s1, e1, nums2, j + 1, e2, k - (j - s2 + 1));
        else
            return findKth(nums1, i + 1, e1, nums2, s2, e2, k - (i - s1 + 1));
    }
public:
    double findMedianSortedArrays(vector<int>& nums1, vector<int>& nums2) {
        int m = nums1.size(), n = nums2.size();
        if((m + n) % 2) {
            return findKth(nums1, 0, m - 1, nums2, 0, n - 1, (m + n) / 2 + 1);
        } else {
            return (findKth(nums1, 0, m - 1, nums2, 0, n - 1, (m + n) / 2) + findKth(nums1, 0, m - 1, nums2, 0, n - 1, (m + n) / 2 + 1)) / 2.0;
        }
    }
};

本代码提交AC,用时65MS。

Leave a Reply

Your email address will not be published. Required fields are marked *