Consider the easy case when all the values are distinct:
A = np.arange(25).reshape(5,5)
ans = [1,3,4]
B = A[np.ix_(ans, ans)]
In [287]: A
Out[287]:
array([[ 0,  1,  2,  3,  4],
       [ 5,  6,  7,  8,  9],
       [10, 11, 12, 13, 14],
       [15, 16, 17, 18, 19],
       [20, 21, 22, 23, 24]])
In [288]: B
Out[288]:
array([[ 6,  8,  9],
       [16, 18, 19],
       [21, 23, 24]])
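This is a quick check of what np.ix_ does with the A and ans defined above. It builds an "open mesh": a column vector of row indices and a row vector of column indices that broadcast to the full 3x3 selection. Plain fancy indexing with the same two lists behaves differently, because it pairs them up elementwise.

```python
import numpy as np

A = np.arange(25).reshape(5, 5)
ans = [1, 3, 4]

# np.ix_ returns a (3, 1) array of row indices and a (1, 3) array of
# column indices; indexing with both broadcasts to a 3x3 block.
rows, cols = np.ix_(ans, ans)
print(rows.shape, cols.shape)  # (3, 1) (1, 3)

B = A[rows, cols]
# Same result as selecting the rows first, then the columns.
assert np.array_equal(B, A[ans][:, ans])

# Without np.ix_, the two lists are paired elementwise, which picks only
# A[1, 1], A[3, 3] and A[4, 4] (the diagonal of B).
print(A[ans, ans])  # [ 6 18 24]
```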
If we compare the first row of B with each row of A, we eventually come to the
comparison of [6, 8, 9] with [5, 6, 7, 8, 9], from which we can glean the
candidate solution of indices [1, 3, 4].
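Here is the single-row comparison worked through in code. row_candidates is a hypothetical helper written only for illustration. It is a simplified relative of the find_idx_1d function in this answer, without the membership pre-check or the sorting. For each value in the row of B, it lists the positions in the row of A that hold that value, and every combination of those positions becomes a candidate. The second call uses made-up rows with a repeated value to show how repeats produce more than one candidate.

```python
import itertools
import numpy as np

def row_candidates(rowA, rowB):
    """Hypothetical helper: index tuples that map rowB's values onto rowA."""
    # For each value in rowB, the column positions in rowA holding that value.
    positions = [np.flatnonzero(rowA == b) for b in rowB]
    # Every combination of positions is a candidate. If some value of rowB
    # is absent from rowA, its position list is empty and product() yields
    # nothing, so that row of A contributes no candidates.
    return [tuple(int(i) for i in idx) for idx in itertools.product(*positions)]

# Distinct values: row 1 of A against row 0 of B gives a single candidate.
print(row_candidates(np.array([5, 6, 7, 8, 9]), np.array([6, 8, 9])))  # [(1, 3, 4)]

# Repeated values (made-up rows): 1 occurs twice in rowA, so two candidates.
print(row_candidates(np.array([1, 2, 1, 0]), np.array([1, 0])))  # [(0, 3), (2, 3)]
```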
We can generate a set of all possible candidate solutions by pairing the first
row of B with each row of A.
If there is only one candidate, then we are done, since we are given that B is a
submatrix of A and therefore there is always a solution.
If there is more than one candidate, then we can do the same thing with the
second row of B and take the intersection of the candidate solutions. After
all, a solution must be a solution for each and every row of B.
Thus we can loop through the rows of B and short-circuit as soon as only one
candidate remains. Again, we are assuming that B is always a submatrix of A.
The find_idx function below implements the idea described above:
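import itertools as IT
import numpy as np

def find_idx_1d(rowA, rowB):
    """Return candidate index tuples that map the values of rowB onto rowA."""
    result = []
    # Only a row of A that contains every value of rowB can match.
    # (np.isin replaces np.in1d, which is deprecated in NumPy 2.0.)
    if np.isin(rowB, rowA).all():
        # Every combination of positions of rowB's values in rowA is a candidate.
        # Sorting canonicalizes each tuple; this assumes B was taken with
        # increasing indices.
        result = [tuple(sorted(idx))
                  for idx in IT.product(*[np.where(rowA == b)[0] for b in rowB])]
    return result

def find_idx(A, B):
    candidates = set([idx for row in A for idx in find_idx_1d(row, B[0])])
    for Bi in B[1:]:
        if len(candidates) == 1:
            # stop when there is a unique candidate
            return candidates.pop()
        new = [idx for row in A for idx in find_idx_1d(row, Bi)]
        candidates = candidates.intersection(new)
    if candidates:
        return candidates.pop()
    raise ValueError('no solution found')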
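If more than one candidate survives every row of B, candidates.pop() returns an arbitrary element of the set. The following refinement is not part of the method above. It checks each surviving candidate directly against B and returns the first one that reproduces it. pick_verified is a hypothetical name, and np.array_equal is an exact comparison; np.allclose may be preferable for floating-point data.

```python
import numpy as np

def pick_verified(A, B, candidates):
    """Hypothetical helper: return a candidate idx with A[np.ix_(idx, idx)] == B."""
    # Sort so the choice is deterministic instead of depending on set order.
    for idx in sorted(candidates):
        if np.array_equal(A[np.ix_(idx, idx)], B):
            return idx
    raise ValueError('no candidate reproduces B')

A = np.array([[1, 2, 0],
              [2, 1, 0],
              [0, 0, 1]])
B = A[np.ix_([1, 2], [1, 2])]   # [[1, 0], [0, 1]]
print(pick_verified(A, B, {(0, 1), (0, 2), (1, 2)}))  # (0, 2)
```

The call returns (0, 2), which reproduces B just as well as the indices [1, 2] used to build it. With repeated values, the answer need not be unique.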
Correctness: The two solutions you've proposed may not always return the correct result, particularly when there are repeated values. For example,
def is_solution(A, B, idx):
    return np.allclose(A[np.ix_(idx, idx)], B)

def find_idx_orig(A, B):
    index = []
    for j in range(len(B)):
        k = 0
        while k<len(A) and set(np.intersect1d(B[j],A[k])) != set(B[j]):
            k+=1
        index.append(k)
    return index

def find_idx_diag(A, B):
    index = []
    a = np.diag(A)
    b = np.diag(B)
    for j in range(len(b)):
        k = 0
        # check the bound before indexing to avoid an IndexError
        while j+k < len(a) and a[j+k] != b[j]:
            k+=1
        index.append(k+j)
    return index

def counterexample():
    """
    Show find_idx_diag, find_idx_orig may not return the correct result
    """
    A = np.array([[1,2,0],
                  [2,1,0],
                  [0,0,1]])
    ans = [0,1]
    B = A[np.ix_(ans, ans)]
    assert not is_solution(A, B, find_idx_orig(A, B))
    assert is_solution(A, B, find_idx(A, B))

    A = np.array([[1,2,0],
                  [2,1,0],
                  [0,0,1]])
    ans = [1,2]
    B = A[np.ix_(ans, ans)]
    assert not is_solution(A, B, find_idx_diag(A, B))
    assert is_solution(A, B, find_idx(A, B))

counterexample()
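Tracing the membership test by hand on the counterexample matrix shows where the two proposed functions go wrong. The snippet below only reproduces the individual checks they perform; it does not redefine the functions.

```python
import numpy as np

A = np.array([[1, 2, 0],
              [2, 1, 0],
              [0, 0, 1]])
B = A[np.ix_([0, 1], [0, 1])]   # [[1, 2], [2, 1]]

# find_idx_orig accepts row k of A as soon as every value of B[j] occurs
# somewhere in A[k]. A set comparison ignores order and position, so row 1
# of B, [2, 1], already "matches" row 0 of A, [1, 2, 0]:
print(set(np.intersect1d(B[1], A[0])) == set(B[1]))  # True

# Both rows of B are therefore mapped to row 0, giving the index [0, 0],
# and the corresponding submatrix is not B:
print(A[np.ix_([0, 0], [0, 0])].tolist())  # [[1, 1], [1, 1]]

# find_idx_diag only compares diagonal entries, and here they are all 1,
# so the diagonal cannot tell index 0 from index 1 or 2:
print(np.diag(A).tolist())  # [1, 1, 1]
```

For the second case (ans = [1, 2]), find_idx_diag matches diag(B) = [1, 1] at positions 0 and 1 and returns [0, 1]. That selects [[1, 2], [2, 1]] instead of B.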
Benchmark: Ignoring at our peril the issue of correctness, out of curiosity
let's compare these functions on the basis of speed.
def make_AB(n, m):
    A = symmetrize(np.random.random((n, n)))
    ans = np.sort(np.random.choice(n, m, replace=False))
    B = A[np.ix_(ans, ans)]
    return A, B

def symmetrize(a):
    "http://stackoverflow.com/a/2573982/190597 (EOL)"
    return a + a.T - np.diag(a.diagonal())

if __name__ == '__main__':
    counterexample()
    A, B = make_AB(500, 450)
    assert is_solution(A, B, find_idx(A, B))
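The symmetrize helper used by make_AB relies on a small identity. a + a.T places a[i, j] + a[j, i] in both off-diagonal slots but doubles every diagonal entry, so subtracting np.diag(a.diagonal()) once restores the diagonal. The sketch below checks this on a tiny matrix and on seeded random data (the seed and sizes are arbitrary choices for illustration). It also shows that a submatrix taken with one sorted index list for rows and columns, as make_AB does, stays symmetric.

```python
import numpy as np

a = np.array([[1, 2],
              [3, 4]])
# Off-diagonal: a[i, j] + a[j, i]; diagonal: 2 * a[i, i] - a[i, i] = a[i, i].
s = a + a.T - np.diag(a.diagonal())
print(s.tolist())  # [[1, 5], [5, 4]]

# Random test data built the same way, plus a principal submatrix of it.
rng = np.random.default_rng(0)
r = rng.random((6, 6))
S = r + r.T - np.diag(r.diagonal())
idx = np.sort(rng.choice(6, 4, replace=False))
Bsub = S[np.ix_(idx, idx)]
print(np.array_equal(S, S.T), np.array_equal(Bsub, Bsub.T))  # True True
```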
In [283]: %timeit find_idx(A, B)
10 loops, best of 3: 74 ms per loop
In [284]: %timeit find_idx_orig(A, B)
1 loops, best of 3: 14.5 s per loop
In [285]: %timeit find_idx_diag(A, B)
100 loops, best of 3: 2.93 ms per loop
So find_idx is roughly 200 times faster than find_idx_orig (74 ms vs. 14.5 s),
but about 25 times slower than find_idx_diag (2.93 ms).