If you want the total memory to be a multiple of 512, then the number of elements in the tensor must be a multiple of 512 // DATA_TYPE_MULTIPLIER, e.g. 128 in your case. Provided DATA_TYPE_MULTIPLIER is itself a power of two, that number has a prime factorization of the form 2**n. The number of elements in the tensor is given by s[0]*s[1]*...*s[d-1], where s is a sequence containing the shape of the tensor and d is an integer, the number of dimensions. The product s[0]*s[1]*...*s[d-1] also has some prime factorization, and it is a multiple of 2**n if and only if that factorization contains at least n factors of 2. I.e. the task is to pad the individual dimensions s[i] such that the resulting prime factorization of the product s[0]*s[1]*...*s[d-1] contains 2**n.
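As a small illustration of this divisibility condition, the sketch below counts the factors of two contributed by each dimension. The helper names `twos_in` and `is_size_multiple` are made up for this example, and the element size of 4 bytes is an assumption chosen to match the figure of 128 above.

```python
def twos_in(x: int) -> int:
    """Return the exponent of 2 in the prime factorization of x (requires x >= 1)."""
    count = 0
    while x % 2 == 0:
        x //= 2
        count += 1
    return count


def is_size_multiple(shape, n: int) -> bool:
    """Check whether the product of `shape` is a multiple of 2**n.

    The exponents of 2 add up under multiplication, so it is enough
    to sum them over the dimensions.
    """
    return sum(twos_in(s) for s in shape) >= n


bytes_per_element = 4                   # assumption: e.g. a 32-bit data type
target = 512 // bytes_per_element       # 128 == 2**7
n = twos_in(target)                     # 7

print(is_size_multiple((16, 51, 1, 4), n))   # 4 + 0 + 0 + 2 = 6 < 7  -> False
print(is_size_multiple((16, 52, 1, 4), n))   # 4 + 2 + 0 + 2 = 8 >= 7 -> True
```

Padding 51 to 52 is enough here because 52 = 2*2*13 contributes the missing factors of two.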
If the goal is to reach the minimum possible size of the padded tensor, one can iterate through the multiples of the target number of elements, starting with the first multiple that is not smaller than the current size, and look for the first one that can be reached by padding (increasing) the individual dimensions of the tensor (1). For a given multiple, each dimension is increased for as long as it contains at least one prime factor that does not occur in that multiple. Once every dimension consists only of prime factors of the multiple, the size of the resulting candidate shape is compared with the multiple. If the two sizes match, we are done. If the candidate's prime factors, counted with multiplicity, are a strict subset of those of the multiple, the missing prime factors can be multiplied into any of the dimensions (e.g. the first). If the candidate has excess prime factors but misses none, its size is itself a larger multiple of the target, so the candidate shape is stored as a solution for that future (larger) multiplier. The first such future multiplier marks an upper bound for the iteration over all possible multipliers, i.e. the algorithm will terminate.
However, if the candidate shape (after adjusting all the dimensions) has an excess of some prime factors w.r.t. the multiple and at the same time misses others, the only way forward is to iterate over all possible padded shapes whose size is bounded by that multiple.
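The comparison between a candidate shape and a multiple of the target can be sketched with multiset subtraction on `collections.Counter`. The function name `classify` and the hard-coded factor lists are only for illustration. They correspond to the shape (3, 5, 7, 11) with target 128 and multiplier 15, where adjusting the dimensions gives the candidate (3, 5, 8, 12).

```python
from collections import Counter


def classify(candidate_factors, multiple_factors):
    """Compare two lists of prime factors as multisets.

    Returns (missing, excess): the factors the candidate lacks and the
    factors it has beyond what the multiple contains.
    """
    cand = Counter(candidate_factors)
    mult = Counter(multiple_factors)
    # Counter subtraction keeps only strictly positive counts.
    return mult - cand, cand - mult


candidate_factors = [3, 5, 2, 2, 2, 2, 2, 3]   # 3 * 5 * 8 * 12
multiple_factors = [3, 5] + [2] * 7            # 15 * 128

missing, excess = classify(candidate_factors, multiple_factors)
print(dict(missing))   # {2: 2}
print(dict(excess))    # {3: 1}
```

Both results are non-empty, so this is the case in which only an exhaustive search over padded shapes remains.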
The following is an example implementation:
from collections import Counter
import itertools as it
import math
from typing import Iterator, Sequence


def pad(shape: Sequence[int], target: int) -> tuple[int,...]:
    """Pad the given `shape` such that the total number of elements
    is a multiple of the given `target`.
    """
    size = math.prod(shape)
    if size % target == 0:
        # Already a multiple of `target`: nothing to pad.
        return tuple(shape)

    target_prime_factors = get_prime_factors(target)

    solutions: dict[int, tuple[int,...]] = {}  # maps `target` multipliers to corresponding padded shapes

    # Try the multiples of `target` in increasing order, starting with the first one >= `size`.
    for multiplier in it.count(math.ceil(size / target)):

        if multiplier in solutions:
            return solutions[multiplier]

        prime_factors = [*get_prime_factors(multiplier), *target_prime_factors]

        def good(x):
            # True if every prime factor of `x` also occurs in `multiplier*target`.
            return all(f in prime_factors for f in get_prime_factors(x))

        # Increase each dimension until it only contains allowed prime factors.
        candidate = list(shape)
        for i, x in enumerate(candidate):
            while not good(x):
                x += 1
            candidate[i] = x

        if math.prod(candidate) == multiplier*target:
            return tuple(candidate)

        candidate_prime_factor_counts = Counter(f for x in candidate for f in get_prime_factors(x))
        target_prime_factor_counts = Counter(prime_factors)

        # Counter subtraction keeps only positive counts.
        missing = target_prime_factor_counts - candidate_prime_factor_counts
        excess = candidate_prime_factor_counts - target_prime_factor_counts

        if not excess:
            # Only factors are missing: multiply them into the first dimension.
            return (
                candidate[0] * math.prod(k**v for k, v in missing.items()),
                *candidate[1:],
            )
        elif not missing:
            # The candidate's size is a larger multiple of `target`: remember it for later.
            solutions[multiplier * math.prod(k**v for k, v in excess.items())] = tuple(candidate)
        else:
            # Both missing and excess factors: fall back to exhaustive search.
            for padded_shape in generate_all_padded_shapes(shape, bound=multiplier*target):
                padded_size = math.prod(padded_shape)
                if padded_size == multiplier*target:
                    return padded_shape
                elif padded_size % target == 0:
                    solutions[padded_size // target] = padded_shape


def generate_all_padded_shapes(shape: Sequence[int], *, bound: int) -> Iterator[tuple[int,...]]:
    """Yield all shapes that are elementwise >= `shape` and whose size is <= `bound`."""
    head, *tail = shape
    # Largest value for the first dimension that keeps the size within `bound`
    # when the remaining dimensions are left unpadded.
    max_value = bound // math.prod(tail)
    for x in range(head, max_value+1):
        if tail:
            yield from ((x, *other) for other in generate_all_padded_shapes(tail, bound=bound // x))
        else:
            yield (x,)


def get_prime_factors(n: int) -> list[int]:
    """From: https://stackoverflow.com/a/16996439/3767239

    Replace with your favorite prime factorization method.
    """
    primfac = []
    d = 2
    while d*d <= n:
        while (n % d) == 0:
            primfac.append(d)  # supposing you want multiple factors repeated
            n //= d
        d += 1
    if n > 1:
        primfac.append(n)
    return primfac
Here are a few examples:
pad((16, 1, 1), 128) = (128, 1, 1)
pad((16, 51, 1, 4), 128) = (16, 52, 1, 4)
pad((80, 240, 1, 1), 128) = (80, 240, 1, 1)
pad((3, 5, 7, 11), 128) = (3, 5, 8, 16)
pad((3, 3, 3, 1), 128) = (8, 4, 4, 1)
pad((7, 7, 7, 7), 128) = (7, 8, 8, 8)
pad((9, 9, 9, 9), 128) = (10, 10, 10, 16)
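The results listed above can be sanity-checked independently of the implementation. The following sketch only verifies that each padded shape grows no dimension downwards and that its size is a multiple of 128. It does not check minimality, and the name `is_valid_padding` is made up for this example.

```python
import math


def is_valid_padding(shape, padded, target):
    """True if `padded` only increases dimensions of `shape`
    and its total size is a multiple of `target`.
    """
    return (
        len(shape) == len(padded)
        and all(p >= s for s, p in zip(shape, padded))
        and math.prod(padded) % target == 0
    )


# Input shapes and padded results as listed in the examples.
examples = {
    (16, 1, 1): (128, 1, 1),
    (16, 51, 1, 4): (16, 52, 1, 4),
    (80, 240, 1, 1): (80, 240, 1, 1),
    (3, 5, 7, 11): (3, 5, 8, 16),
    (3, 3, 3, 1): (8, 4, 4, 1),
    (7, 7, 7, 7): (7, 8, 8, 8),
    (9, 9, 9, 9): (10, 10, 10, 16),
}

for shape, padded in examples.items():
    print(shape, "->", padded, is_valid_padding(shape, padded, 128))
```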
Footnotes:
(1) In fact, we need to find the roots of the polynomial (s[0]+x[0])*(s[1]+x[1])*...*(s[d-1]+x[d-1]) - multiple*target for x[i] >= 0 over the domain of integers. However, I am not aware of any algorithm to solve this problem.
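For small shapes, the equation in this footnote can at least be attacked by exhaustive search. The following is a brute-force sketch under that assumption, not the method used in the implementation above: every padded dimension s[i]+x[i] has to divide multiple*target, so only divisors that are at least s[i] are tried. The function names `exact_paddings` and `minimal_padding` are hypothetical, and all dimensions are assumed to be >= 1.

```python
import itertools
import math
from typing import Iterator, Sequence


def exact_paddings(shape: Sequence[int], total: int) -> Iterator[tuple[int, ...]]:
    """Yield every shape that is elementwise >= `shape` and whose
    product is exactly `total`.
    """
    if not shape:
        if total == 1:
            yield ()
        return
    head, *tail = shape
    min_rest = math.prod(tail)  # smallest possible product of the remaining dimensions
    for value in range(head, total // min_rest + 1):
        if total % value == 0:  # a padded dimension must divide the total
            for rest in exact_paddings(tail, total // value):
                yield (value, *rest)


def minimal_padding(shape: Sequence[int], target: int) -> tuple[int, ...]:
    """Return a padded shape of minimal size that is a multiple of `target`."""
    size = math.prod(shape)
    first = (size + target - 1) // target  # ceiling division
    for multiplier in itertools.count(first):
        solution = next(exact_paddings(shape, multiplier * target), None)
        if solution is not None:
            return solution


print(minimal_padding((16, 51, 1, 4), 128))   # (16, 52, 1, 4)
print(minimal_padding((3, 5, 7, 11), 128))    # (3, 5, 8, 16)
```

The loop always terminates, because rounding a single dimension up to a multiple of `target` already gives a valid shape. Several shapes can share the same minimal size, so the shape returned by this sketch may differ from the one returned by `pad` while having the same number of elements.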