I need to compute the combination of many 3x3 rotation matrices.
Here is a comparison of applying functools.reduce on matmul with numpy and cupy:
import timeit
from functools import reduce
import numpy as np
import cupy as cp
from pyrr.matrix33 import create_from_axis_rotation
# generate random rotation matrices
axes = np.random.rand(10000, 3)
angles = np.pi * np.random.rand(10000)
rotations = [create_from_axis_rotation(*params) for params in zip(axes, angles)]
# then reduce with matmul
xp = np # numpy
xp_rotations = [xp.asarray(rotation) for rotation in rotations]
timexp = timeit.timeit("reduce(xp.matmul, xp_rotations)", number=10, globals=globals())
print(f"{xp.__name__}: {timexp * 1000:0.3}ms")
xp = cp # cupy
xp_rotations = [xp.asarray(rotation) for rotation in rotations]
timexp = timeit.timeit("reduce(xp.matmul, xp_rotations)", number=10, globals=globals())
print(f"{xp.__name__}: {timexp * 1000:0.3}ms")
On a good machine with a Titan GPU, this gives:
numpy: 1.63e+02ms
cupy: 8.78e+02ms
For some reason, the GPU version is about five times slower.
In any case, is there a way to compute this significantly faster?
Edit
I found a rather simple solution that works for any chain of small linear transformations (and can easily be extended to affine transformations).
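def reduce_loop(matrices):
    """ non-optimized reduce """
    mat = matrices[0]
    for _mat in matrices[1:]:
        mat = mat @ _mat
    return mat


def reduce_split(matrices):
    """ reduce by multiplying pairs of matrices recursively """
    # stack into one (n, 3, 3) array: a plain list of matrices cannot be
    # sliced and multiplied as a batch (no-op if it is already an array)
    matrices = np.asarray(matrices)
    if len(matrices) == 1:
        return matrices[0]
    neven = (len(matrices) // 2) * 2
    # batched matmul: (m0 @ m1), (m2 @ m3), ... in a single call
    reduced = matrices[:neven:2] @ matrices[1:neven:2]
    if len(matrices) > neven:  # len(matrices) is odd
        reduced[-1] = reduced[-1] @ matrices[-1]
    return reduce_split(reduced)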
time = timeit.timeit("reduce_loop(rotations)", number=10, globals=globals())
print(f"reduce_loop: {time * 1000:0.3}ms")
time = timeit.timeit("reduce_split(rotations)", number=10, globals=globals())
print(f"reduce_split: {time * 1000:0.3}ms")
Giving:
reduce_loop: 2.14e+02ms
reduce_split: 24.5ms
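The pairwise approach relies on matrix multiplication being associative: regrouping the product into pairs does not change the result as long as the left-to-right order is preserved. Here is a minimal sketch to check that, assuming only NumPy is installed. The names `random_rotations` and `reduce_pairwise` are hypothetical helpers, not part of the code above: the first stands in for pyrr's create_from_axis_rotation using Rodrigues' formula, and the second is an iterative variant of the same pairwise idea.

```python
from functools import reduce

import numpy as np


def random_rotations(n, rng):
    """Hypothetical helper: n random 3x3 rotations via Rodrigues' formula."""
    axes = rng.normal(size=(n, 3))
    axes /= np.linalg.norm(axes, axis=1, keepdims=True)  # unit axes
    angles = np.pi * rng.random(n)
    # Cross-product (skew-symmetric) matrix K of each axis.
    K = np.zeros((n, 3, 3))
    K[:, 0, 1] = -axes[:, 2]
    K[:, 0, 2] = axes[:, 1]
    K[:, 1, 0] = axes[:, 2]
    K[:, 1, 2] = -axes[:, 0]
    K[:, 2, 0] = -axes[:, 1]
    K[:, 2, 1] = axes[:, 0]
    s = np.sin(angles)[:, None, None]
    c = np.cos(angles)[:, None, None]
    return np.eye(3) + s * K + (1 - c) * (K @ K)


def reduce_pairwise(matrices):
    """Iterative pairwise reduction of a stack of square matrices."""
    mats = np.asarray(matrices)
    if len(mats) == 0:
        raise ValueError("need at least one matrix")
    while len(mats) > 1:
        n_even = (len(mats) // 2) * 2
        # One batched matmul: (m0 @ m1), (m2 @ m3), ... keeps the order.
        reduced = mats[:n_even:2] @ mats[1:n_even:2]
        if len(mats) > n_even:  # odd count: fold the leftover matrix in
            reduced[-1] = reduced[-1] @ mats[-1]
        mats = reduced
    return mats[0]


rng = np.random.default_rng(0)
rotations = random_rotations(1001, rng)  # odd count exercises the leftover branch
sequential = reduce(np.matmul, rotations)
pairwise = reduce_pairwise(rotations)
print(np.allclose(sequential, pairwise))
```

The two results agree up to floating-point rounding, since the grouping differs but the order of the factors does not.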
I'm sure it's not optimal, but it takes advantage of numpy's (and probably cupy's) optimizations.
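For the affine case, one possible extension is to pack each transformation x -> A x + t into a 4x4 homogeneous matrix, so that composing affine maps becomes a plain matrix product again and the same pairwise reduction applies. This is only a sketch under that assumption: `to_homogeneous` and `reduce_pairwise` are hypothetical names, and the random inputs are purely illustrative.

```python
import numpy as np


def to_homogeneous(linear, translation):
    """Pack linear parts (n, 3, 3) and translations (n, 3) into (n, 4, 4)."""
    out = np.zeros((len(linear), 4, 4))
    out[:, :3, :3] = linear
    out[:, :3, 3] = translation
    out[:, 3, 3] = 1.0
    return out


def reduce_pairwise(mats):
    """Multiply a stack of square matrices pairwise, keeping their order."""
    mats = np.asarray(mats)
    while len(mats) > 1:
        n_even = (len(mats) // 2) * 2
        reduced = mats[:n_even:2] @ mats[1:n_even:2]  # batched matmul
        if len(mats) > n_even:  # odd count: fold the leftover matrix in
            reduced[-1] = reduced[-1] @ mats[-1]
        mats = reduced
    return mats[0]


rng = np.random.default_rng(1)
linear = rng.normal(size=(9, 3, 3))   # illustrative linear parts
translation = rng.normal(size=(9, 3))  # illustrative translations
point = rng.normal(size=3)

# Reference: the product M0 @ M1 @ ... @ M8 applies the LAST map first.
expected = point.copy()
for A, t in zip(linear[::-1], translation[::-1]):
    expected = A @ expected + t

combined = reduce_pairwise(to_homogeneous(linear, translation))
result = (combined @ np.append(point, 1.0))[:3]
print(np.allclose(result, expected))
```

Nothing in the reduction itself depends on the matrices being 3x3, which is why the same function handles the 4x4 stack unchanged.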