So to solve this problem, it helps to solve a slightly different one. We want to know the upper/lower bounds in each row for where the overall k'th cutoff is. Then we can go through and verify that the number of elements at or below the lower bounds is < k, the number of elements at or below the upper bounds is > k, and there is only one value between them.
I've come up with a strategy for doing a binary search in all rows simultaneously for those bounds. Being a binary search it "should" take O(log(n)) passes. Each pass involves O(m) work, for a total of O(m log(n)) time. I put "should" in quotes because I don't have a proof that it actually takes O(log(n)) passes. In fact it is possible to be too aggressive in a row, discover from other rows that the pivot chosen was off, and then have to back off. But I believe that it does very little backing off and really is O(m log(n)).
The strategy is to keep track, in each row, of a lower bound, an upper bound, and a mid. On each pass we build a weighted series of ranges: start to lower, lower to mid, mid to upper, and upper to the end, with the weight of each range being the number of elements in it and its value being the last element in the range. We then find the k'th value (by weight) in that data structure, and use that as a pivot for our binary search in each row.
If a pivot winds up out of the range from lower to upper, we correct by widening the interval in the direction that corrects the error.
Once every row's bounds are consistent with the pivot, we have our answer.
There are a lot of edge cases, so staring at full code may help.
I also assume that all elements of each row are distinct. If they are not, you can get into endless loops. (Solving that means even more edge cases...)
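Before diving into the full code, here is a much simpler baseline to compare answers against: lazily merge the sorted rows with a heap and take the k'th element, which is O(k log m). This is not the algorithm described above, just a sanity check (the function name is mine):

```python
import heapq
from itertools import islice

def kth_by_heap_merge(k, matrix):
    """k'th smallest (1-indexed) across all rows, via a lazy heap-merge
    of the already-sorted rows."""
    merged = heapq.merge(*matrix)
    return next(islice(merged, k - 1, None))

print(kth_by_heap_merge(4, [[1, 4, 7], [2, 5, 8], [3, 6, 9]]))  # 4
```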
import random
# This takes (k, [(value1, weight1), (value2, weight2), ...])
def weighted_kth(k, pairs):
    # This does quickselect for average O(len(pairs)).
    # Median of medians is deterministically the same, but a bit slower.
    pivot = pairs[int(random.random() * len(pairs))][0]

    # Which side of our answer is the pivot on?
    weight_under_pivot = 0
    pivot_weight = 0
    for value, weight in pairs:
        if value < pivot:
            weight_under_pivot += weight
        elif value == pivot:
            pivot_weight += weight

    if weight_under_pivot + pivot_weight < k:
        filtered_pairs = []
        for pair in pairs:
            if pivot < pair[0]:
                filtered_pairs.append(pair)
        return weighted_kth(k - weight_under_pivot - pivot_weight, filtered_pairs)
    elif k <= weight_under_pivot:
        filtered_pairs = []
        for pair in pairs:
            if pair[0] < pivot:
                filtered_pairs.append(pair)
        return weighted_kth(k, filtered_pairs)
    else:
        return pivot
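The contract of `weighted_kth` is that it returns the k'th value (1-indexed) of the multiset where each value appears `weight` times. A brute-force equivalent makes that explicit (the function name here is mine, for illustration only):

```python
def weighted_kth_brute_force(k, pairs):
    """Expand each (value, weight) into weight copies, sort, take the k'th."""
    expanded = sorted(v for v, w in pairs for _ in range(w))
    return expanded[k - 1]

# Multiset is [1, 1, 3, 3, 3, 3, 5], so the 3rd value is 3.
print(weighted_kth_brute_force(3, [(1, 2), (5, 1), (3, 4)]))  # 3
```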
# This takes (k, [[...], [...], ...])
def kth_in_row_sorted_matrix(k, matrix):
    # The strategy is to discover the k'th value, and also discover where
    # that would be in each row.
    #
    # For each row we will track what we think the lower and upper bounds
    # are on where it is. Those bounds start as the start and end and
    # will do a binary search.
    #
    # In each pass we will break each row into ranges from start to lower,
    # lower to mid, mid to upper, and upper to end. Some ranges may be
    # empty. We will then create a weighted list of ranges with the weight
    # being the length, and the value being the end of the range. We find
    # where the k'th spot is in that list, and use that approximate value
    # to refine each range. (There is a chance that a range is wrong, and
    # we will have to deal with that.)
    #
    # We finish when all of the uppers are above our k, all of the lowers
    # are below, and the upper/lower gap is more than 1 only when our
    # k'th element is in the middle.

    # Our data structure is simply [row, lower, upper, bound] for each row.
    data = [[row, 0, min(k, len(row)-1), min(k, len(row)-1)] for row in matrix]
    is_search = True
    while is_search:
        pairs = []
        for row, lower, upper, bound in data:
            # Literal edge cases
            if 0 == upper:
                pairs.append((row[upper], 1))
                if upper < bound:
                    pairs.append((row[bound], bound - upper))
            elif lower == bound:
                pairs.append((row[lower], lower + 1))
            elif lower + 1 == upper:  # No mid.
                pairs.append((row[lower], lower + 1))
                pairs.append((row[upper], 1))
                if upper < bound:
                    pairs.append((row[bound], bound - upper))
            else:
                mid = (upper + lower) // 2
                pairs.append((row[lower], lower + 1))
                pairs.append((row[mid], mid - lower))
                pairs.append((row[upper], upper - mid))
                if upper < bound:
                    pairs.append((row[bound], bound - upper))

        pivot = weighted_kth(k, pairs)

        # Now that we have our pivot, we try to adjust our parameters.
        # If any of them adjusts, we continue our search.
        is_search = False
        new_data = []
        for row, lower, upper, bound in data:
            # First the cases where our bounds weren't bounds for our pivot.
            # We rebase the interval and either:
            # - double the size of the range
            # - go halfway to the edge
            if 0 < lower and pivot <= row[lower]:
                is_search = True
                if pivot == row[lower]:
                    new_data.append((row, lower-1, min(lower+1, bound), bound))
                elif upper <= lower:
                    new_data.append((row, lower-1, lower, bound))
                else:
                    new_data.append((row, max(lower // 2, lower - 2*(upper - lower)), lower, bound))
            elif upper < bound and row[upper] <= pivot:
                is_search = True
                if pivot == row[upper]:
                    new_data.append((row, upper-1, upper+1, bound))
                elif lower < upper:
                    new_data.append((row, upper, min((upper+bound+1)//2, upper + 2*(upper - lower)), bound))
                else:
                    new_data.append((row, upper, upper+1, bound))
            elif lower + 1 < upper:
                if upper == lower+2 and pivot == row[lower+1]:
                    new_data.append((row, lower, upper, bound))  # Looks like we found the pivot.
                else:
                    # We will split this interval.
                    is_search = True
                    mid = (upper + lower) // 2
                    if row[mid] < pivot:
                        new_data.append((row, mid, upper, bound))
                    elif pivot < row[mid]:
                        new_data.append((row, lower, mid, bound))
                    else:
                        # We center our interval on the pivot.
                        new_data.append((row, (lower+mid)//2, (mid+upper+1)//2, bound))
            else:
                # It looks like we found where the pivot would be in this row.
                new_data.append((row, lower, upper, bound))
        data = new_data  # And set up the next search.
    return pivot
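What this computes is, by definition, the k'th smallest element (1-indexed) of the whole matrix. A brute-force reference to cross-check against, flattening and sorting (function name is mine, for illustration):

```python
def kth_brute_force(k, matrix):
    """Reference: k'th smallest (1-indexed) across all rows, by flattening."""
    flat = sorted(x for row in matrix for x in row)
    return flat[k - 1]

matrix = [[1, 4, 7, 11], [2, 5, 8, 12], [3, 6, 9, 10]]
print(kth_brute_force(5, matrix))  # 5
```

Note that the rows here have all-distinct elements, matching the assumption stated above.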