I don't think vmap is going to work with a tuple of scalars. What you need is to put the indices into an array and vmap over it.
I am not sure whether this solution satisfies you, because we have to get rid of the empty index pairs ().
import jax.numpy as jnp
from jax import jit, vmap

# X and Y are the (n, m) matrices from the question
idxs_pairs = jnp.array([[7, 8], [7, 9]])  # put the index pairs into an array

@jit
def distance(X, Y):
    """Compute distance between two matrices X and Y.
    Args:
        X (jax.numpy.ndarray): matrix of shape (n, m)
        Y (jax.numpy.ndarray): matrix of shape (n, m)
    Returns:
        float: distance
    """
    return jnp.mean(jnp.abs(X - Y))

@jit
def compute_metrics(idxs, X, Y):
    return distance(X[:, idxs], Y[:, idxs])

vmap(compute_metrics, in_axes=(0, None, None))(idxs_pairs, X, Y)
You can also jit everything:
jit(vmap(compute_metrics, in_axes=(0, None, None)))(idxs_pairs, X, Y)
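To make the snippet above easy to try, here is a minimal self-contained sketch. The random X and Y (4 rows, 10 columns) are placeholder data, not the matrices from the question, and the Python loop is only there as a reference to compare against.

```python
import jax
import jax.numpy as jnp
from jax import jit, vmap

# Placeholder data (an assumption): 10 columns, so indices 7, 8, 9 are valid.
key_x, key_y = jax.random.split(jax.random.PRNGKey(0))
X = jax.random.normal(key_x, (4, 10))
Y = jax.random.normal(key_y, (4, 10))

idxs_pairs = jnp.array([[7, 8], [7, 9]])

def compute_metrics(idxs, X, Y):
    # Mean absolute difference over the selected columns.
    return jnp.mean(jnp.abs(X[:, idxs] - Y[:, idxs]))

# Batched version: map over rows of idxs_pairs, share X and Y.
batched = jit(vmap(compute_metrics, in_axes=(0, None, None)))(idxs_pairs, X, Y)

# Reference: plain Python loop over the same index pairs.
looped = jnp.stack(
    [compute_metrics(jnp.array(pair), X, Y) for pair in [(7, 8), (7, 9)]]
)
```

Both `batched` and `looped` should hold one value per index pair, and they should agree up to floating-point tolerance.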
Update 19/05/2023:
The question is how to make this more general, i.e. how to support a variable number of indices per row. The problem is that JAX needs static shapes for inputs and outputs, so we need a trick to handle this. The most obvious one is to use jnp.where to express the conditional behavior; the other option is jax.lax.cond. As before, we put the indices into an array, but this time we use -1 as a special flag marking an empty slot (this is like zero-padding, but with -1 instead of 0, since 0 is a valid column index). Because arrays have a static shape, the number of columns in idxs_pairs must equal the largest number of indices in any row.
For example:
# 7, 8, -1 -> we only use indices: 7, 8
# 7, 9, -1 -> we only use indices: 7, 9
# 7, 5, 6 -> we use indices: 7, 5, 6
# 1, -1, -1 -> we use only index: 1
idxs_pairs = jnp.array([[7, 8, -1], [7, 9, -1], [7, 5, 6], [1, -1, -1]]) # put the indices pairs into array
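If the index groups start out as Python tuples of different lengths, the padded array can be built programmatically. The helper below is a hypothetical sketch (pad_indices is not a JAX function); it runs in plain Python before anything is jitted, so the variable lengths are not a problem at this stage.

```python
import jax.numpy as jnp

def pad_indices(index_groups, fill=-1):
    """Pad variable-length index tuples into a rectangular int array.

    Hypothetical helper: every row is right-padded with `fill` up to the
    length of the longest group.
    """
    width = max(len(group) for group in index_groups)
    rows = [list(group) + [fill] * (width - len(group)) for group in index_groups]
    return jnp.array(rows, dtype=jnp.int32)

idxs_pairs = pad_indices([(7, 8), (7, 9), (7, 5, 6), (1,)])
```

Note that an empty group would become a row consisting only of -1, and the averaging step would then divide by zero, so empty groups are best filtered out beforehand.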
We now redefine the functions:
def distance_vectors(idx, X, Y):
    """Compute the element-wise absolute difference between one column of X and Y.
    Args:
        idx (jax.numpy.ndarray): scalar column index
        X (jax.numpy.ndarray): matrix of shape (n, m)
        Y (jax.numpy.ndarray): matrix of shape (n, m)
    Returns:
        jax.numpy.ndarray: absolute differences, shape (n,)
    """
    return jnp.abs(X[:, idx] - Y[:, idx])

def compute_metrics(idxs, X, Y):
    # distances has shape (k, n): one row per index, including padded ones
    distances = vmap(distance_vectors, in_axes=(0, None, None))(idxs, X, Y)
    # zero out the columns that belong to padded (-1) indices
    distances = distances.T * jnp.where(idxs >= 0, 1, 0)
    n_of_actual_indices = jnp.sum(jnp.where(idxs >= 0, 1, 0))
    # mean over the n * n_of_actual_indices real entries
    output = 1/n_of_actual_indices * 1/X.shape[0] * jnp.sum(distances)
    return output

output = jit(vmap(compute_metrics, in_axes=(0, None, None)))(idxs_pairs, X, Y)
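As a sanity check of the masking idea, the sketch below uses a more compact variant of the same computation (not the exact functions above) and compares it against a plain loop over the unpadded groups. The random X and Y are placeholder data. A -1 index simply reads the last column, and the mask then discards that contribution.

```python
import jax
import jax.numpy as jnp
from jax import jit, vmap

# Placeholder data (an assumption): 4 rows, 10 columns.
key_x, key_y = jax.random.split(jax.random.PRNGKey(0))
X = jax.random.normal(key_x, (4, 10))
Y = jax.random.normal(key_y, (4, 10))

def masked_metric(idxs, X, Y):
    mask = idxs >= 0                              # True for real indices
    diffs = jnp.abs(X[:, idxs] - Y[:, idxs])      # shape (n, k); -1 reads the last column
    total = jnp.sum(jnp.where(mask, diffs, 0.0))  # padded columns contribute 0
    return total / (jnp.sum(mask) * X.shape[0])   # mean over real entries only

idxs_pairs = jnp.array([[7, 8, -1], [7, 9, -1], [7, 5, 6], [1, -1, -1]])
padded = jit(vmap(masked_metric, in_axes=(0, None, None)))(idxs_pairs, X, Y)

# Reference: loop over the original variable-length groups.
groups = [(7, 8), (7, 9), (7, 5, 6), (1,)]
reference = jnp.stack(
    [jnp.mean(jnp.abs(X[:, jnp.array(g)] - Y[:, jnp.array(g)])) for g in groups]
)
```

If the masking is right, `padded` and `reference` agree up to floating-point tolerance.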
I am not sure this is the optimal way of doing it: it depends on whether the XLA compiler can exploit the fact that the distances for -1 indices are zeroed out, but I am not an XLA expert. I will later provide another solution based on jax.lax.cond, which may be faster, so we can benchmark both.
Update 22/05/2023:
With jax.lax.cond, the implementation can look like this:
from jax import lax

def distance_vectors(idx, X, Y):
    """Compute the element-wise absolute difference between one column of X and Y.
    Args:
        idx (jax.numpy.ndarray): scalar column index, or -1 for a padded slot
        X (jax.numpy.ndarray): matrix of shape (n, m)
        Y (jax.numpy.ndarray): matrix of shape (n, m)
    Returns:
        jax.numpy.ndarray: absolute differences, shape (n,); zeros if idx < 0
    """
    return lax.cond(
        idx >= 0,
        lambda: jnp.abs(X[:, idx] - Y[:, idx]),
        lambda: jnp.zeros_like(X[:, idx]),
    )

def compute_metrics(idxs, X, Y):
    distances = vmap(distance_vectors, in_axes=(0, None, None))(idxs, X, Y)
    n_of_actual_indices = jnp.sum(jnp.where(idxs >= 0, 1, 0))
    output = 1/n_of_actual_indices * 1/X.shape[0] * jnp.sum(distances)
    return output

output = jit(vmap(compute_metrics, in_axes=(0, None, None)))(idxs_pairs, X, Y)
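I tested it, and the execution times are the same as in the jnp.where case. This is expected: under vmap with a batched predicate, lax.cond is lowered to a select, so both branches are computed for every index anyway.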
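For anyone who wants to repeat the comparison, here is a rough timing sketch. The data sizes, the function names, and the repeat count are arbitrary assumptions, and the measured numbers will depend on the hardware and JAX version. The important details are the warm-up call (so compilation is not timed) and block_until_ready (so asynchronous dispatch does not hide the real run time).

```python
import timeit
import jax
import jax.numpy as jnp
from jax import jit, lax, vmap

# Placeholder data (an assumption): 1000 rows, 10 columns.
key_x, key_y = jax.random.split(jax.random.PRNGKey(0))
X = jax.random.normal(key_x, (1000, 10))
Y = jax.random.normal(key_y, (1000, 10))
idxs_pairs = jnp.array([[7, 8, -1], [7, 9, -1], [7, 5, 6], [1, -1, -1]])

def abs_column(idx, X, Y):
    return jnp.abs(X[:, idx] - Y[:, idx])

def abs_column_cond(idx, X, Y):
    return lax.cond(
        idx >= 0,
        lambda: jnp.abs(X[:, idx] - Y[:, idx]),
        lambda: jnp.zeros_like(X[:, idx]),
    )

def metrics_where(idxs, X, Y):
    d = vmap(abs_column, in_axes=(0, None, None))(idxs, X, Y)
    d = d.T * jnp.where(idxs >= 0, 1, 0)  # mask out padded indices
    return jnp.sum(d) / (jnp.sum(idxs >= 0) * X.shape[0])

def metrics_cond(idxs, X, Y):
    d = vmap(abs_column_cond, in_axes=(0, None, None))(idxs, X, Y)
    return jnp.sum(d) / (jnp.sum(idxs >= 0) * X.shape[0])

f_where = jit(vmap(metrics_where, in_axes=(0, None, None)))
f_cond = jit(vmap(metrics_cond, in_axes=(0, None, None)))

# Warm-up calls trigger compilation so it is excluded from the timings.
out_where = f_where(idxs_pairs, X, Y).block_until_ready()
out_cond = f_cond(idxs_pairs, X, Y).block_until_ready()

t_where = timeit.timeit(lambda: f_where(idxs_pairs, X, Y).block_until_ready(), number=100)
t_cond = timeit.timeit(lambda: f_cond(idxs_pairs, X, Y).block_until_ready(), number=100)
print(f"where: {t_where:.4f}s  cond: {t_cond:.4f}s  (100 calls each)")
```

The two versions should return the same values; only the timings are of interest here.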