I have an arbitrary NxM matrix, for example:
1 2 3 4 5 6
7 8 9 0 1 2
3 4 5 6 7 8
9 0 1 2 3 4
I want to get a list of all 3x3 submatrices in this matrix:
1 2 3     2 3 4           0 1 2
7 8 9  ;  8 9 0  ;  ...  ;  6 7 8
3 4 5     4 5 6           2 3 4
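To make "all 3x3 submatrices" concrete: a k x k window can start at any row from 0 to rows-k and any column from 0 to cols-k, so a 4x6 matrix has (4-3+1)*(6-3+1) = 8 such windows. The short sketch below checks that count and the first and last windows listed above. The names `a` and `k` are illustrative only.

```python
import numpy as np

# The example matrix from the question (4 rows x 6 columns).
a = np.array([[1, 2, 3, 4, 5, 6],
              [7, 8, 9, 0, 1, 2],
              [3, 4, 5, 6, 7, 8],
              [9, 0, 1, 2, 3, 4]])
k = 3  # window size (illustrative name)
rows, cols = a.shape

# A k x k window can start at rows 0..rows-k and columns 0..cols-k (inclusive),
# so there are (rows - k + 1) * (cols - k + 1) windows in total.
n_windows = (rows - k + 1) * (cols - k + 1)
print(n_windows)  # 8

first = a[0:k, 0:k]                     # top-left window
last = a[rows - k:rows, cols - k:cols]  # bottom-right window
print(first)
print(last)
```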
I can do this with two nested loops:
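```python
import numpy as np
rows, cols = input_matrix.shape
patches = []
# Windows start at rows 0..rows-3 and columns 0..cols-3 inclusive, hence the "+ 1".
for row in range(rows - 3 + 1):
    for col in range(cols - 3 + 1):
        patches.append(input_matrix[row:row + 3, col:col + 3])
```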
For a large input matrix, however, this is slow. Is there a faster way to do this with NumPy?
I've looked at `np.split`, but that gives me non-overlapping sub-matrices, whereas I want every possible 3x3 submatrix, overlapping or not.
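The difference between `np.split` and an overlapping window can be seen directly. The sketch below assumes NumPy 1.20 or newer, where `numpy.lib.stride_tricks.sliding_window_view` is available. It is one possible vectorized approach, not the only one. It creates all overlapping 3x3 windows as a read-only view and then checks that they match the nested-loop result in the same order. No timings are claimed here.

```python
import numpy as np
# Assumption: NumPy >= 1.20, which provides sliding_window_view.
from numpy.lib.stride_tricks import sliding_window_view

a = np.array([[1, 2, 3, 4, 5, 6],
              [7, 8, 9, 0, 1, 2],
              [3, 4, 5, 6, 7, 8],
              [9, 0, 1, 2, 3, 4]])

# np.split cuts along one axis into equal, non-overlapping pieces:
# splitting the 6 columns into 2 parts yields two 4x3 blocks.
blocks = np.split(a, 2, axis=1)
print([b.shape for b in blocks])  # [(4, 3), (4, 3)]

# sliding_window_view returns every overlapping 3x3 window as a read-only view
# of shape (rows - 3 + 1, cols - 3 + 1, 3, 3); creating the view copies no data.
windows = sliding_window_view(a, (3, 3))
print(windows.shape)  # (2, 4, 3, 3)

# Flatten the grid of windows into shape (8, 3, 3). Because the windows
# overlap, this reshape has to copy the data into a new array.
patches = windows.reshape(-1, 3, 3)

# Same windows, in the same order, as the nested loops with the "+ 1" ranges.
loop_patches = np.array([a[r:r + 3, c:c + 3]
                         for r in range(a.shape[0] - 3 + 1)
                         for c in range(a.shape[1] - 3 + 1)])
print(np.array_equal(patches, loop_patches))  # True
```

If a plain Python list is really needed, `list(patches)` gives one 3x3 array per element. Keeping the 4-D `windows` view avoids the copy when the windows are only read.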