I want to convert an RGB image to one with a single channel, where each pixel's value is the integer index of that pixel's color in a palette (which has already been extracted).
An example:
import tensorflow as tf
# image shape (height=2, width=2, channels=3)
image = tf.constant([
    [[1., 1., 1.], [1., 0., 0.]],
    [[0., 0., 1.], [1., 0., 0.]],
])
# palette is a tensor with the extracted colors
# palette shape (num_colors_in_palette, 3)
palette = tf.constant([
    [1., 0., 0.],
    [0., 0., 1.],
    [1., 1., 1.],
])
indexed_image = rgb_to_indexed(image, palette)
# desired result: [[2, 0], [1, 0]]
# result shape (height, width)
I can imagine a few ways to implement rgb_to_indexed(image, palette) in pure Python, but I'm having trouble figuring out how to implement it the TensorFlow way (using @tf.function for AutoGraph and avoiding for loops), using only (or mostly) vectorized operations.
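A minimal sketch of one vectorized approach, assuming every pixel color appears exactly once in the palette and matches it exactly (no floating-point noise, as in the example above): broadcast each pixel against every palette entry, require all three channels to be equal, then take `tf.argmax` over the palette axis. The example data is repeated so the block runs on its own.

```python
import tensorflow as tf

@tf.function
def rgb_to_indexed(image, palette):
    # image: (height, width, 3); palette: (num_colors, 3)
    # (height, width, 1, 3) == (num_colors, 3) broadcasts to (height, width, num_colors, 3)
    matches = tf.equal(image[:, :, tf.newaxis, :], palette)
    # A pixel matches a palette entry only if all 3 channels are equal:
    # shape (height, width, num_colors)
    is_color = tf.reduce_all(matches, axis=-1)
    # argmax does not accept booleans, so cast first; with unique palette
    # entries there is exactly one 1 per pixel, giving its palette index
    return tf.argmax(tf.cast(is_color, tf.int32), axis=-1, output_type=tf.int32)

image = tf.constant([
    [[1., 1., 1.], [1., 0., 0.]],
    [[0., 0., 1.], [1., 0., 0.]],
])
palette = tf.constant([
    [1., 0., 0.],
    [0., 0., 1.],
    [1., 1., 1.],
])
indexed_image = rgb_to_indexed(image, palette)  # [[2, 0], [1, 0]], shape (2, 2)
```

Note that a pixel whose color is missing from the palette does not raise an error here; `argmax` just returns some index for an all-zero row. The intermediate tensor has `height * width * num_colors * 3` elements, so memory grows with palette size.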
Edit 1: sample Python/NumPy code
If the code doesn't need to use TensorFlow, a non-vectorized NumPy implementation could be:
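import numpy as np

def rgb_to_indexed(image, palette):
    # One integer palette index per pixel, shape (height, width)
    result = np.empty(shape=(image.shape[0], image.shape[1]), dtype=np.int64)
    for i, row in enumerate(image):
        for j, color in enumerate(row):
            # np.where returns a tuple of index arrays; take the first (only) match
            (indices,) = np.where(np.all(palette == color, axis=1))
            result[i, j] = indices[0]
    return result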
indexed_image = rgb_to_indexed(image.numpy(), palette.numpy())
# indexed_image is [[2, 0], [1, 0]]
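For comparison, the same broadcasting idea removes both loops in NumPy. This is a sketch under the same exact-match assumption; `rgb_to_indexed_vectorized` is a hypothetical name, and unlike the loop version it raises an explicit error when a pixel color is not in the palette.

```python
import numpy as np

def rgb_to_indexed_vectorized(image, palette):
    # (h, w, 1, 3) == (n, 3) broadcasts to (h, w, n, 3); all channels must match
    is_color = np.all(image[:, :, np.newaxis, :] == palette, axis=-1)  # (h, w, n)
    # Guard: every pixel must match at least one palette entry
    if not is_color.any(axis=-1).all():
        raise ValueError("image contains a color that is not in the palette")
    # np.argmax returns the index of the first True along the palette axis
    return np.argmax(is_color, axis=-1)

image = np.array([
    [[1., 1., 1.], [1., 0., 0.]],
    [[0., 0., 1.], [1., 0., 0.]],
])
palette = np.array([
    [1., 0., 0.],
    [0., 0., 1.],
    [1., 1., 1.],
])
indexed_image = rgb_to_indexed_vectorized(image, palette)  # [[2, 0], [1, 0]]
```

Each NumPy call here has a direct TensorFlow counterpart (`tf.equal`, `tf.reduce_all`, `tf.reduce_any`, `tf.argmax`), which is what makes this formulation easy to port.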