2

I need to pass YUV images to my TF model. tf.image.yuv_to_rgb seems like the way to go, however it expects the YUV input to be of size [H,W,3] with Y normalized to [0,1] and UV to [-0.5,0.5] https://www.tensorflow.org/api_docs/python/tf/image/yuv_to_rgb.

I'm new to the YUV format but this seems strange & non-conventional to me based on other YUV-RGB conversion scripts I've found on the web:

My question is: what is the correct usage of tf.image.yuv_to_rgb? Is there a better documentation or a pre-processing method provided by TF? I've tried modifying the above solutions to fit the range expected by tf.image.yuv_to_rgb but was getting distorted outputs. Below is my code (ipynb) and attached are my outputs:

outputs

import numpy as np
import tensorflow as tf
import imageio
import os
import argparse
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
from time import perf_counter

w,h = 520,390
px = w*h

'''
Seems like there are 2 conventions for YUV conversion: 
- ITU-R BT.601 version (SDTV)
- ITU-R BT.709 version (HDTV)
For both of these I performed conversion using numpy mat-mul open source code. 
These look somewhat OK, but I'm clearly losing some chrominance data -- not exactly sure what I'm doing wrong.

These methods do NOT work for tf.image.yuv_to_rgb, as it expects Y:[0,1], UV:[-0.5,0.5] (while YUV is uint8). 
For the TF method I modified the pre-processing of the YUV data to match the range expected by tf.image.yuv_to_rgb.
I was able to visually match the outputs of the above NP methods -- but clearly there's some color loss going on.

One more interesting observation: NP approaches seem faster than TF. Not sure if this changes when tf.image.yuv_to_rgb
runs as part of the tflite model itself
'''

# http://maxsharabayko.blogspot.com/2016/01/fast-yuv-to-rgb-conversion-in-python-3.html
def YUV2RGB( yuv ):
    A = np.array([[1.164,  0.000,  1.793],
                  [1.164, -0.213, -0.533],
                  [1.164,  2.112,  0.000]])
    rgb = np.dot(yuv, A.T).clip(0, 255).astype('uint8')
    return rgb

# https://picamera.readthedocs.io/en/release-1.9/recipes2.html#unencoded-image-capture-yuv-format
def YUV2RGB_sdtv(yuv):
    #              Y       U       V
    M = np.array([[1.164,  0.000,  1.596],    # R
                  [1.164, -0.392, -0.813],    # G
                  [1.164,  2.017,  0.000]])   # B
    # Take the dot product with the matrix to produce RGB output, clamp the
    # results to byte range and convert to bytes
    RGB = YUV.dot(M.T).clip(0, 255).astype(np.uint8)
    return RGB


# https://stackoverflow.com/questions/53467655/import-yuv-as-a-byte-array
def IMG2YUV(img):
    # Read entire file into YUV
    YUV = np.fromfile(img,dtype='uint8')

    # Take first h x w samples and reshape as Y channel
    Y = YUV[0:w*h].reshape(h,w)

    # Take next px/4 samples as U
    U = YUV[px:(px*5)//4].reshape(h//2,w//2)

    # Take next px/4 samples as V
    V = YUV[(px*5)//4:(px*6)//4].reshape(h//2,w//2)

    # Undo subsampling of U and V by doubling h and w
    U = U.repeat(2, axis=0).repeat(2, axis=1)
    V = V.repeat(2, axis=0).repeat(2, axis=1)

    YUV = np.dstack((Y, U, V))[:h, :w, :].astype(np.float32)
    YUV[:,:, 0] = YUV[:,:, 0].clip(16, 235).astype(YUV.dtype) - 16
    YUV[:,:,1:] = YUV[:,:,1:].clip(16, 240).astype(YUV.dtype) - 128
    return YUV

# https://raspberrypi.stackexchange.com/questions/28033/reading-frames-of-uncompressed-yuv-video-file
def IMG2YUV_sdtv(img):
    Y = np.fromfile(img, dtype=np.uint8, count=w*h).\
        reshape((h, w))

    # Load the UV (chrominance) data from the stream, and double its size
    U = np.fromfile(img, dtype=np.uint8, count=(w//2)*(h//2)).\
            reshape((h//2, w//2)).\
            repeat(2, axis=0).repeat(2, axis=1)
    V = np.fromfile(img, dtype=np.uint8, count=(w//2)*(h//2)).\
            reshape((h//2, w//2)).\
            repeat(2, axis=0).repeat(2, axis=1)

    # Stack the YUV channels together, crop the actual resolution, convert to
    # floating point for later calculations, and apply the standard biases
    YUV = np.dstack((Y, U, V))[:h, :w, :].astype(np.float32)

    YUV[:, :, 0]  = YUV[:, :, 0]  - 16   # Offset Y by 16
    YUV[:, :, 1:] = YUV[:, :, 1:] - 128  # Offset UV by 128

    return YUV

# Best guess at normalization, based on https://www.tensorflow.org/versions/r1.15/api_docs/python/tf/image/yuv_to_rgb
def IMG2YUV_tf(img):
    # Read entire file into YUV
    YUV = np.fromfile(img,dtype='uint8')

    # Take first h x w samples and reshape as Y channel
    Y = YUV[0:w*h].reshape(h,w)

    # Take next px/4 samples as U
    U = YUV[px:(px*5)//4].reshape(h//2,w//2)

    # Take next px/4 samples as V
    V = YUV[(px*5)//4:(px*6)//4].reshape(h//2,w//2)

    # Undo subsampling of U and V by doubling h and w
    U = U.repeat(2, axis=0).repeat(2, axis=1)
    V = V.repeat(2, axis=0).repeat(2, axis=1)

    YUV = np.dstack((Y, U, V))[:h, :w, :].astype(np.float32)

    YUV[:, :, 0]  = YUV[:, :, 0]  / 255        # [0,1]
    YUV[:, :, 1:] = YUV[:, :, 1:] / 255 - 0.5  # [-0.5,0.5] 

    return YUV
imgs = ["colors"]

img = imgs[0]
print("\nImage: {}".format(img))
YUV = IMG2YUV(img+".yuv")
YUV_sdtv = IMG2YUV_sdtv(img+".yuv")
YUV_tf = IMG2YUV_tf(img+".yuv")
with tf.Session() as sess: 
    s_tf = perf_counter()
    RGB_tf = tf.image.yuv_to_rgb(YUV_tf)
    RGB_tf = RGB_tf.eval()
    e_tf = perf_counter()

    s_tf_sdtv = perf_counter()
    RGB_tf_sdtv = tf.image.yuv_to_rgb(YUV_sdtv)
    RGB_tf_sdtv = RGB_tf_sdtv.eval()
    e_tf_sdtv = perf_counter()

s_np = perf_counter()
RGB_np = YUV2RGB(YUV)
e_np = perf_counter()

s_np_sdtv = perf_counter()
RGB_np_sdtv = YUV2RGB_sdtv(YUV_sdtv)
e_np_sdtv = perf_counter()

PNG = np.float32(imageio.imread(DATA_DIR + img + '.png'))
PNG_norm = (2.0 / 255.0) * PNG - 1.0

print("TF (using SDTV):")
print("{} seconds".format(e_tf_sdtv - s_tf_sdtv))
plt.imshow(RGB_tf_sdtv)
plt.show()

print("TF (normalized to Y [0,1], UV [-0.5,0.5]:")
print("{} seconds".format(e_tf - s_tf))
plt.imshow(RGB_tf)
plt.show()

print("np:")
print("{} seconds".format(e_np - s_np))
plt.imshow(RGB_np)
plt.show()

print("np_sdtv:")
print("{} seconds".format(e_np_sdtv - s_np_sdtv))
plt.imshow(RGB_np_sdtv)
plt.show()

print("original:")
plt.imshow(PNG.astype('uint8'))
plt.show()
Alex
  • 301
  • 3
  • 7

0 Answers0