
I have a tensor with shape [None, 128, 128, n_classes]. This is a one-hot tensor, where the last axis holds the categorical values for multiple classes (n_classes in total). In practice, each entry along the last axis is binary and indicates the class of a pixel: if a pixel has a 1 in channel C, it belongs to class C, and it has 0 in every other channel.
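A minimal NumPy sketch of the layout described above, using a tiny made-up batch (1 image, 2x2 pixels, 3 classes; shapes and values are illustrative only). Taking argmax over the last axis recovers the class index of each pixel, which is the starting point for any colouring scheme.

```python
import numpy as np

# Hypothetical toy data: the class index of each pixel, shape [1, 2, 2].
class_map = np.array([[[0, 2],
                       [1, 0]]])
n_classes = 3

# One-hot encode: pixel (i, j) gets a 1 in channel class_map[0, i, j], 0 elsewhere.
one_hot = np.eye(n_classes, dtype=np.float32)[class_map]

print(one_hot.shape)     # (1, 2, 2, 3)
print(one_hot[0, 0, 1])  # [0. 0. 1.]  -> this pixel belongs to class 2

# argmax over the last axis turns the one-hot tensor back into class indices.
recovered = np.argmax(one_hot, axis=-1)
print(recovered[0])
```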

Now I want to convert this one-hot tensor to an RGB image that I can display in TensorBoard. Every class should be associated with a different colour so that the image is easier to interpret.

Any idea on how to do that?

Thanks, G.


Edit 2:

Solution added in the answers.


Edit 1:

My current implementation (not working):

def from_one_hot_to_rgb(incoming, palette=None):
    """ Assign a different color to each class in the input tensor """
    if palette is None:
        palette = {
            0: (0, 0, 0),
            1: (31, 12, 33),
            2: (13, 26, 33),
            3: (21, 76, 22),
            4: (22, 54, 66)
        }

    def _colorize(value):
        return palette[value]

    # from one-hot to grayscale:
    cmap = tf.expand_dims(tf.argmax(incoming, axis=-1), axis=-1)

    # flatten input tensor (pixels on the first axis):
    B, W, H, C = get_shape(cmap)  # this returns batch_size, 128, 128, 1
    cmap_flat = tf.reshape(cmap, shape=[B * W * H, C])

    # assign a different color to each class:
    cmap = tf.map_fn(lambda pixel:
                     tf.py_func(_colorize, inp=[pixel], Tout=tf.int64),
                     cmap_flat)

    # back to original shape, but RGB output:
    cmap = tf.reshape(cmap, shape=[B, W, H, 3])

    return tf.cast(cmap, dtype=tf.float32)
gab

2 Answers


I would use imshow or matshow from matplotlib to create the plot, and then use this answer (or the other answers to the same question) to display it in TensorBoard.

import matplotlib.pyplot as plt
import tensorflow as tf

# imgs: one-hot masks with shape [batch, 128, 128, n_classes]
plt.imshow(tf.argmax(imgs[0], axis=-1))

One of the upsides of this approach is that you don't have to define the class-to-color mapping yourself: imshow maps the class indices through a colormap automatically.


To fix the code you already have, first note that the argument passed to _colorize is a NumPy array of length 1 rather than an int. Arrays are not hashable, so it can't be used as a dictionary key. You can convert it to an int simply with palette[int(value)].
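To illustrate the hashability problem outside TensorFlow, here is a NumPy-only sketch. The three-entry palette is a trimmed copy of the one in the question, and the length-1 array stands in for the value that reaches _colorize. It indexes with value[0] instead of calling int(value) directly, because NumPy 1.25 and later emit a DeprecationWarning when converting an array with ndim > 0 to a scalar.

```python
import numpy as np

# Trimmed copy of the palette from the question.
palette = {0: (0, 0, 0), 1: (31, 12, 33), 2: (13, 26, 33)}

# The pixel value arrives as a length-1 NumPy array, not as an int.
value = np.array([2])

lookup_failed = False
try:
    palette[value]
except TypeError as err:
    lookup_failed = True
    print("lookup failed:", err)  # lookup failed: unhashable type: 'numpy.ndarray'

# A plain Python int is hashable, so it works as a dictionary key.
color = palette[int(value[0])]
print(color)  # (13, 26, 33)
```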

I have changed a few things here and there in your code and tested it on a random batch of size 1; the final code looks like this:

def from_one_hot_to_rgb(incoming, palette=None):
    """ Assign a different color to each class in the input tensor """
    if palette is None:
        palette = {i: tf.constant(color, dtype='int64') for i, color in enumerate(
            ((0, 0, 0),
            (31, 12, 33),
            (13, 26, 33),
            (21, 76, 22),
            (22, 54, 66))
        )}

    # from one-hot to grayscale:
    B, W, H, _ = incoming.get_shape()   # this returns batch_size, 128, 128, 5
    cmap = tf.reshape(tf.argmax(incoming, axis=-1), [-1, 1])
    cmap = tf.map_fn(lambda value: palette[int(value)], cmap)

    # back to original shape, but RGB output:
    cmap = tf.reshape(cmap, shape=[B, W, H, 3])

    return tf.cast(cmap, dtype=tf.float32)
Mohammad Jafar Mashhadi
  • thank you for the suggestion. However, I find the solution not very intuitive... I've been working on the code and updated the question with my current implementation (not working) – gab Mar 19 '20 at 18:27
  • Thank you for the hints! I had to modify a bit your answer to solve a bug (you cannot just call "tf.map_fn(lambda: palette[int(value)], camp)" as is). I'm trying to make my code run so that I can test if the solution is correct. I will then update the answer :) – gab Mar 20 '20 at 11:59
  • I added the solution to the problem in a new answer if you are interested :) I still upvoted your answer because it was useful to understand one problem in my old version of the code. Thanks a lot for the help! ;) – gab Mar 23 '20 at 09:46
  • @gabriele I'm glad that you found the solution :) Take care, wash your hands frequently – Mohammad Jafar Mashhadi Mar 23 '20 at 22:15

Solution 1 (slow)

A possible solution, similar to the initial code, is the following. Note that it can be very slow because of a known performance problem with TensorFlow's tf.map_fn.

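import numpy as np
import tensorflow as tf


def from_one_hot_to_rgb_bkup(incoming, palette=None):
    """Slow version: colorizes one pixel at a time with tf.map_fn + tf.py_func."""
    if palette is None:
        # NumPy arrays, not tf.constant: the function wrapped by tf.py_func
        # must return NumPy values, and their dtype must match Tout (int32)
        palette = {i: np.array(color, dtype=np.int32) for i, color in enumerate(
            ((0, 0, 0),
             (31, 12, 33),
             (13, 26, 33),
             (21, 76, 22),
             (22, 54, 66))
        )}

    # from one-hot to grayscale (one class index per pixel), flattened:
    _, W, H, _ = incoming.get_shape().as_list()
    gray = tf.reshape(tf.argmax(incoming, axis=-1, output_type=tf.int32), [-1, 1], name='flatten')

    # assign colors to each class; each pixel reaches the lambda as a
    # length-1 array, so take its first element before the dict lookup
    # (tf.py_func is the TF1 API; in TF2 it is tf.compat.v1.py_func)
    rgb = tf.map_fn(lambda pixel:
                    tf.py_func(lambda value: palette[int(value[0])], inp=[pixel], Tout=tf.int32),
                    gray, name='colorize')

    # back to original shape, but RGB output (-1 allows an unknown batch size):
    rgb = tf.reshape(rgb, shape=[-1, W, H, 3], name='back_to_rgb')

    return tf.cast(rgb, dtype=tf.float32)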
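For reference, here is what Solution 1 computes, written as a plain NumPy/Python sketch (no TensorFlow; the function name one_hot_to_rgb_loop and the random input are hypothetical): one dictionary lookup per pixel. Routed through tf.map_fn and tf.py_func, that becomes one Python callback per pixel, 16,384 of them for a single 128x128 image, which plausibly accounts for much of the slowdown.

```python
import numpy as np

PALETTE = {0: (0, 0, 0), 1: (31, 12, 33), 2: (13, 26, 33),
           3: (21, 76, 22), 4: (22, 54, 66)}


def one_hot_to_rgb_loop(one_hot, palette=PALETTE):
    """Per-pixel version: argmax, flatten, look each class up in a dict."""
    batch, height, width, _ = one_hot.shape
    gray = np.argmax(one_hot, axis=-1).reshape(-1)   # one class index per pixel
    rgb = np.array([palette[int(c)] for c in gray])  # one lookup per pixel
    return rgb.reshape(batch, height, width, 3).astype(np.float32)


# Hypothetical input: random one-hot batch of 2 images, 4x4 pixels, 5 classes.
rng = np.random.default_rng(0)
classes = rng.integers(0, 5, size=(2, 4, 4))
one_hot = np.eye(5, dtype=np.float32)[classes]

rgb = one_hot_to_rgb_loop(one_hot)
print(rgb.shape)  # (2, 4, 4, 3)
```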

Solution 2 (fast)

Based on this answer, a much faster solution uses tf.gather:

import tensorflow as tf


def from_one_hot_to_rgb(incoming, palette=None):
    """Fast version: looks up the color of every pixel at once with tf.gather."""
    if palette is None:
        # one RGB triplet per class: row i is the color of class i
        palette = ((0, 0, 0),
                   (31, 12, 33),
                   (13, 26, 33),
                   (21, 76, 22),
                   (22, 54, 66))

    _, W, H, _ = incoming.get_shape().as_list()
    palette = tf.constant(palette, dtype=tf.uint8)  # shape [n_classes, 3]
    class_indexes = tf.argmax(incoming, axis=-1)     # shape [batch, W, H]

    # flatten, pick one palette row per pixel, then restore the image layout
    class_indexes = tf.reshape(class_indexes, [-1])
    color_image = tf.gather(palette, class_indexes)  # shape [batch * W * H, 3]
    color_image = tf.reshape(color_image, [-1, W, H, 3])

    return tf.cast(color_image, dtype=tf.float32)
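The lookup that tf.gather performs in Solution 2 can be sketched in NumPy with integer-array indexing (palette[class_indexes]), which selects one palette row per pixel; the final assert compares it against a naive per-pixel loop on random, made-up data. In this sketch the flatten/reshape round trip is skipped, because indexing with a [batch, H, W] index array directly yields a [batch, H, W, 3] result (tf.gather with its default axis=0 produces the same shape for multi-dimensional indices).

```python
import numpy as np

# Row i is the RGB color of class i (same colors as the answer above).
PALETTE = np.array([(0, 0, 0),
                    (31, 12, 33),
                    (13, 26, 33),
                    (21, 76, 22),
                    (22, 54, 66)], dtype=np.uint8)


def one_hot_to_rgb_gather(one_hot, palette=PALETTE):
    """Vectorized version: a single indexing lookup for all pixels at once."""
    class_indexes = np.argmax(one_hot, axis=-1)       # shape [batch, H, W]
    return palette[class_indexes].astype(np.float32)  # shape [batch, H, W, 3]


# Hypothetical input: random one-hot batch of 2 images, 128x128 pixels, 5 classes.
rng = np.random.default_rng(0)
classes = rng.integers(0, 5, size=(2, 128, 128))
one_hot = np.eye(5, dtype=np.float32)[classes]

rgb = one_hot_to_rgb_gather(one_hot)
print(rgb.shape)  # (2, 128, 128, 3)

# Sanity check against a naive per-pixel loop.
expected = np.array([PALETTE[c] for c in classes.reshape(-1)]).reshape(2, 128, 128, 3)
assert np.array_equal(rgb, expected.astype(np.float32))
```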
gab