LintCode/Median of Two Sorted Arrays

Problem Summary

Given two sorted arrays A and B of sizes m and n respectively, find the median of the two arrays combined.

Solution

Clearly, we can merge A and B into a new sorted array C; the answer is then the median of C.
Both the time and space complexity of this approach are O(m+n), so let us look for a better solution.
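
For reference, the merge-based baseline might look like the sketch below. The function name medianByMerge is made up for illustration, and the code assumes both inputs are sorted in ascending order and not both empty. It is mainly useful as a simple oracle for checking a faster solution.

```cpp
#include <cstddef>
#include <vector>

// Baseline sketch: merge A and B into C, then read the median off C.
// Assumes A and B are sorted ascending and are not both empty.
// Runs in O(m + n) time and uses O(m + n) extra space.
double medianByMerge(const std::vector<int>& A, const std::vector<int>& B) {
    const std::size_t total = A.size() + B.size();
    std::vector<int> C;
    C.reserve(total);

    std::size_t i = 0, j = 0;
    // Standard two-pointer merge of two sorted sequences.
    while (i < A.size() && j < B.size()) {
        C.push_back(A[i] <= B[j] ? A[i++] : B[j++]);
    }
    while (i < A.size()) C.push_back(A[i++]);
    while (j < B.size()) C.push_back(B[j++]);

    if (total % 2 == 1) {
        return C[total / 2];  // odd length: the single middle element
    }
    // Even length: average the two middle elements in double precision.
    return (static_cast<double>(C[total / 2 - 1]) + C[total / 2]) / 2.0;
}
```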

Suppose m+n is odd. Then the median of C is simply the ((m+n+1)/2)th number of C. (When m+n is even, the median is the average of the ((m+n)/2)th and ((m+n)/2+1)th numbers.) So we need a way to find the kth number of C in less than linear time, which suggests an algorithm with O(log(m+n)) time complexity, such as binary search.

The point of binary search is to cut the search range in half at each step. Can we do the same thing in this problem? To answer this question, let us look at A[k/2-1] and B[k/2-1], which we denote by M1 and M2 respectively. For simplicity, assume for now that k is even and that both arrays contain at least k/2 numbers.

In array A, k/2-1 numbers come before M1 and m-k/2 numbers come after it. The situation is similar for B and M2. If M1 and M2 are equal, the k-2 numbers that precede them in A and B can be placed first in C, followed by M1 and M2 themselves, so M1 is the kth number of C. If M1 < M2, at most k-2 numbers of C can come before M1 (the k/2-1 numbers before it in A, plus at most k/2-1 numbers from B), so M1 is at best the (k-1)th number of C. Therefore neither M1 nor any number before it in A can be the target, and it is safe to discard the first k/2 numbers of A and then look for the (k-k/2)th number among what remains. The case M1 > M2 is symmetric. Since each step removes about k/2 numbers, the search finishes in O(log k) steps.
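
The discard rule can also be written as a loop. The following is a minimal iterative sketch, not the recursive solution given in the Code section: the name kthSmallest and its parameters are assumptions made for illustration. It clamps the probe length when an array has fewer than k/2 numbers left, and it assumes both arrays are sorted ascending with 1 <= k <= m+n.

```cpp
#include <algorithm>
#include <cstddef>
#include <vector>

// Returns the k-th smallest number (1-based) in the merged order of A and B.
// Assumes A and B are sorted ascending and 1 <= k <= A.size() + B.size().
int kthSmallest(const std::vector<int>& A, const std::vector<int>& B,
                std::size_t k) {
    std::size_t i = 0, j = 0;  // A[0..i) and B[0..j) have been discarded
    while (true) {
        // One array is exhausted: the answer is in the other one.
        if (i == A.size()) return B[j + k - 1];
        if (j == B.size()) return A[i + k - 1];
        if (k == 1) return std::min(A[i], B[j]);

        // Probe up to k/2 numbers in each array (fewer if it is short).
        const std::size_t stepA = std::min(k / 2, A.size() - i);
        const std::size_t stepB = std::min(k / 2, B.size() - j);

        if (A[i + stepA - 1] <= B[j + stepB - 1]) {
            // At most stepA + stepB - 2 <= k - 2 numbers precede the probe
            // in A, so it and everything before it cannot be the k-th.
            i += stepA;
            k -= stepA;
        } else {
            j += stepB;
            k -= stepB;
        }
    }
}
```

With A = {1, 3, 5, 7}, B = {2, 4, 6, 8} and k = 4, the first step discards 1 and 3 from A, the second discards 2 from B, and the remaining comparison of 5 and 4 yields 4.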

Code

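#include <algorithm>
#include <vector>
using namespace std;

class Solution {
public:
    /*
     * @param A: An integer array
     * @param B: An integer array
     * @return: a double whose format is *.5 or *.0
     */
    double findMedianSortedArrays(vector<int> &A, vector<int> &B) {
        int m = A.size(), n = B.size();
        int tot = m + n;
        if (tot == 0)
            return 0;
        if (tot % 2)
            return find_kth(0, m, A, 0, n, B, tot / 2 + 1);
        // Even total: average the two middle numbers. Each term is halved
        // separately so the sum of two large ints cannot overflow.
        return 0.5 * find_kth(0, m, A, 0, n, B, tot / 2)
             + 0.5 * find_kth(0, m, A, 0, n, B, tot / 2 + 1);
    }

    // Returns the kth (1-based) number of the merged order of
    // A[base1 .. base1+len1) and B[base2 .. base2+len2).
    int find_kth(int base1, int len1, vector<int> &A,
                 int base2, int len2, vector<int> &B, int k) {
        // Keep the shorter range first so that half2 below never exceeds len2.
        if (len1 > len2)
            return find_kth(base2, len2, B, base1, len1, A, k);
        if (len1 == 0)
            return B[base2 + k - 1];
        if (k == 1)
            return min(A[base1], B[base2]);

        // Probe half1 numbers of A and half2 of B, with half1 + half2 == k.
        int half1 = min(k / 2, len1);
        int half2 = k - half1;
        int m1 = A[base1 + half1 - 1];
        int m2 = B[base2 + half2 - 1];

        if (m1 == m2)
            return m1;  // exactly k numbers are <= m1, so it is the kth
        if (m1 < m2)    // the first half1 numbers of A cannot be the kth
            return find_kth(base1 + half1, len1 - half1, A, base2, len2, B, k - half1);
        // Otherwise the first half2 numbers of B cannot be the kth.
        return find_kth(base1, len1, A, base2 + half2, len2 - half2, B, k - half2);
    }
};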