
I would like to backpropagate with the identity (a straight-through estimator) when I quantize a tensor in TF 2.0 eager mode. The forward pass works, but when I try to apply gradients to my model I get the error ValueError: No gradients provided for any variable.

I tried to use custom gradients. As a minimal example, consider the following code:

import tensorflow as tf
@tf.custom_gradient
def quantize(x):
    x = tf.cast(x, dtype=tf.uint8)
    def grad(dy):
        return dy
    return x, grad

@tf.custom_gradient
def dequantize(x):
    x = tf.cast(x, dtype=tf.float32)
    def grad(dy):
        return dy
    return x, grad

x = tf.ones([10,10])
with tf.GradientTape() as g:
  g.watch(x)
  y = dequantize(quantize(x))*2
dy_dx = g.gradient(y, x) 

print(dy_dx) # outputs None

The gradient is None, which means I can't backpropagate when I use this kind of function in a model. How should I fix this? Is there a better way to implement quantization, given that I need something other than tf.quantize()?
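One alternative to tf.quantize() that already has straight-through gradients is TensorFlow's fake-quantization op, which rounds and clamps values to a fixed number of levels but keeps the result in float32. This swaps in a different technique from the question's cast-based quantizer; the range [-6, 6], num_bits=8, and input values below are arbitrary choices for illustration:

```python
import tensorflow as tf

x = tf.constant([-7.0, -1.3, 0.0, 2.6, 7.0])

with tf.GradientTape() as tape:
    tape.watch(x)
    # Simulated 8-bit quantization over roughly [-6, 6]: values are snapped to
    # one of 256 levels and clamped, but the result stays float32.
    y = tf.quantization.fake_quant_with_min_max_args(
        x, min=-6.0, max=6.0, num_bits=8)
    loss = tf.reduce_sum(y)

# The gradient passes straight through (1.0) inside the clamping range and is
# 0.0 for inputs that were clamped (-7.0 and 7.0 here).
grad = tape.gradient(loss, x)
print(y.numpy())
print(grad.numpy())
```

Because the op never leaves floating point, it can sit anywhere in a model without cutting the gradient path.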

Also, I used the GradientTape g as shown in the official tutorials, but I don't understand why they call g.gradient() outside the with block (the dy_dx line in the code). Any clarification is appreciated.
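On the second question: the with block only controls which operations are recorded onto the tape. The tape object remains usable after the block exits, and calling g.gradient() there is the usual pattern, because computing gradients inside the block would also record the gradient computation itself (which is only useful for higher-order derivatives). A default, non-persistent tape can compute gradients once; persistent=True allows several calls. A small sketch with arbitrary values:

```python
import tensorflow as tf

x = tf.constant(3.0)

# Operations executed inside the `with` block are recorded on the tape.
with tf.GradientTape(persistent=True) as g:
    g.watch(x)
    y = x * x
    z = y * x

# Recording has stopped here, but the tape still holds the recorded
# operations, so gradients can be computed from it.
dy_dx = g.gradient(y, x)  # d(x^2)/dx = 2x = 6.0
dz_dx = g.gradient(z, x)  # d(x^3)/dx = 3x^2 = 27.0
del g  # release the resources held by a persistent tape
```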

Taw
    Did you solve the issue? If not, maybe qkeras quantizers can help you: https://github.com/google/qkeras/blob/master/qkeras/quantizers.py – Nicolò Ghielmetti Oct 14 '20 at 10:19
    I found the resolution in this post; maybe it will be helpful for you! https://stackoverflow.com/questions/36456436/how-can-i-define-only-the-gradient-for-a-tensorflow-subgraph – J. Fan Dec 22 '20 at 08:00
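A common way to override only the gradient of a subgraph, and an alternative to @tf.custom_gradient for straight-through estimators, is the tf.stop_gradient trick. A minimal sketch, assuming rounding as the quantizer (quantize_ste is a hypothetical name, and tf.round stands in for whatever quantization you need):

```python
import tensorflow as tf

def quantize_ste(x):
    # Forward value: x + (round(x) - x), i.e. round(x) up to float rounding.
    # Backward: the residual is wrapped in tf.stop_gradient, so only the
    # leading `x` contributes and the gradient is the identity.
    return x + tf.stop_gradient(tf.round(x) - x)

x = tf.constant([0.2, 1.7, -0.6])
with tf.GradientTape() as tape:
    tape.watch(x)
    y = quantize_ste(x) * 2.0

grad = tape.gradient(y, x)
print(y.numpy())     # quantized values times 2
print(grad.numpy())  # 2.0 for every element
```

This needs no separate backward function, and the quantized value stays float32, so the tape can follow it.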

1 Answer


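The problem is that you quantize by converting to an integer type. GradientTape does not propagate gradients through integer tensors, so the uint8 output of quantize breaks the chain and the gradient comes back as None, even though both functions define custom gradients. Use the floor function instead, which rounds the values down but keeps them in float32: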
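If you would rather keep the tf.cast-based quantization from the question, another option is to fuse quantize and dequantize into one custom-gradient function, so the uint8 tensor never leaves it and the tape only sees float32 inputs and outputs. A minimal sketch (fake_quantize_uint8 is a hypothetical name; the clip to [0, 255] is an added assumption, since tf.cast does not clamp out-of-range values):

```python
import tensorflow as tf

@tf.custom_gradient
def fake_quantize_uint8(x):
    # Forward: clamp to the uint8 range, cast to uint8 (truncating the
    # fractional part), then cast straight back to float32. The integer
    # tensor stays internal, so the returned value is differentiable.
    q = tf.cast(tf.cast(tf.clip_by_value(x, 0.0, 255.0), tf.uint8), tf.float32)

    def grad(dy):
        # Straight-through estimator: pass the upstream gradient unchanged.
        return dy

    return q, grad

x = tf.ones([10, 10])
with tf.GradientTape() as g:
    g.watch(x)
    y = fake_quantize_uint8(x) * 2

dy_dx = g.gradient(y, x)
print(dy_dx)  # 2.0 everywhere instead of None
```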

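import tensorflow as tf

@tf.custom_gradient
def quantize(x):
    # floor keeps the result in float32, so the tape can still track it
    x = tf.math.floor(x)
    def grad(dy):
        # straight-through estimator: pass the upstream gradient unchanged
        return dy
    return x, grad

x = tf.ones((10, 10))
with tf.GradientTape() as g:
    g.watch(x)
    y = quantize(1.1*x + .1)*2  # floor(1.2) = 1, so y is 2.0 everywhere

# output_gradients defaults to ones, so passing it explicitly is optional
dy_dx = g.gradient(y, x, output_gradients=tf.ones((10, 10)))

print(dy_dx)  # 2.2 everywhere: 1.1 (from 1.1*x) times 2, with floor treated as identity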
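To connect this back to the original ValueError: once the quantizer's output stays float32, a trainable variable upstream of it receives a real gradient. A minimal sketch of one training step, where the weight w, input x_in, target, and learning rate are made-up toy values, and a manual SGD update stands in for optimizer.apply_gradients (which would consume the same gradients):

```python
import tensorflow as tf

@tf.custom_gradient
def quantize(x):
    x = tf.math.floor(x)
    def grad(dy):
        return dy  # straight-through estimator
    return x, grad

# Hypothetical toy values for illustration only.
w = tf.Variable([0.5, 1.5, 2.5])
x_in = tf.constant([1.0, 1.0, 1.0])
target = tf.constant([1.0, 2.0, 3.0])
learning_rate = 0.1

with tf.GradientTape() as tape:
    # w * x_in is an ordinary tensor, so quantize never reads w directly;
    # trainable variables are watched by the tape automatically.
    pred = quantize(w * x_in)
    loss = tf.reduce_sum(tf.square(pred - target))

grads = tape.gradient(loss, [w])
print(grads[0])  # a real tensor, not None

# Plain SGD step: w <- w - learning_rate * grad
w.assign_sub(learning_rate * grads[0])
print(w.numpy())
```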
Taw