Here's my attempt on a grayscale image using only NumPy, based on the solution from here. With some minor tweaks (adding channels), it might suit your needs:
import numpy as np
# Separate a grayscale image into h x w tiles, zero-padding the image if its size is not a multiple of the tile size
def to_tiles(arr: np.ndarray, tilesize: tuple[int, int]) -> np.ndarray:
    """Split a 2-D (grayscale) image into non-overlapping tiles.

    When an image dimension is not an exact multiple of the tile size, the
    image is zero-padded along that axis until it is. The padding is split
    between both sides of the axis; if the amount is odd, the extra pixel
    goes to the leading (top / left) side.

    Args:
        arr: 2-D array of shape ``(height, width)``.
        tilesize: ``(h, w)`` -- tile height and tile width.

    Returns:
        Array of shape ``(rows * cols, h, w)`` holding the tiles in
        row-major order (left to right, then top to bottom), where
        ``rows = ceil(height / h)`` and ``cols = ceil(width / w)``.

    Raises:
        AssertionError: If ``arr`` is not 2-D, if a tile dimension is not
            positive, or if a tile dimension exceeds the matching image
            dimension.
    """
    assert arr.ndim == 2, "array must be 2d (grayscale) image"
    a_h, a_w = arr.shape
    h, w = tilesize
    assert h > 0 and w > 0, "tile height and width must be positive"
    # A tile exactly as large as the image is valid (it yields one tile).
    assert a_h >= h, "tile height is larger than arr height"
    assert a_w >= w, "tile width is larger than arr width"

    # Ceiling division: number of tiles needed to cover each axis.
    rows = -(-a_h // h)
    cols = -(-a_w // w)

    # Total padding per axis, kept as plain integers so large tile sizes
    # cannot overflow a fixed-width dtype.
    pad_h = rows * h - a_h
    pad_w = cols * w - a_w

    arr = np.pad(
        arr,
        [
            # (before, after): the leading side receives ceil(pad / 2).
            (pad_h - pad_h // 2, pad_h // 2),
            (pad_w - pad_w // 2, pad_w // 2),
        ],
        "constant",
        constant_values=0,
    )

    # (rows*h, cols*w) -> (rows, h, cols, w) -> (rows, cols, h, w) -> (N, h, w)
    arr = arr.reshape(rows, h, cols, w).swapaxes(1, 2)
    return arr.reshape(-1, h, w)
Here's an example of the result. The image is from the FUNSD dataset.