
Is there a way to modify this function (MyFunc) so that it gives the same result, but its gradient is not zero?

```python
from jax import grad
import jax.nn as nn
import numpy as np

def MyFunc(coefs):
    a = coefs[0]
    b = coefs[1]
    c = coefs[2]

    if a > b:
        return 30.0
    elif b > c:
        return 20.0
    else:
        return 10.0

myFuncDeriv = grad(MyFunc)

# prints [0. 0. 0.]
print(myFuncDeriv(np.random.sample(3)))
# prints [0. 0. 0.]
print(myFuncDeriv(np.array([1.0, 2.0, 3.0])))
```

EDIT: Here is a similar function that doesn't give a zero gradient, but it doesn't return 30/20/10:

```python
def MyFunc2(coefs):
    a = coefs[0]
    b = coefs[1]
    c = coefs[2]
    if a > b:
        return nn.sigmoid(a) * 30.0
    elif b > c:
        return nn.sigmoid(b) * 20.0
    else:
        return nn.sigmoid(c) * 10.0

myFunc2Deriv = grad(MyFunc2)

# prints [0.         0.         0.45176652]
print(myFunc2Deriv(np.array([1.0, 2.0, 3.0])))
# prints for example [6.1160526 0.        0.       ]
print(myFunc2Deriv(np.random.sample(3)))
```

pepazdepa
  • Please add an example with non-random arguments and the expected non-zero gradient for these values. – Michael Szczesny Aug 27 '22 at 08:59
  • @MichaelSzczesny OK, I have edited my original post. As for the expected non-zero gradient: I have no idea what it should be. I don't care what the gradients are, as long as they aren't zero. – pepazdepa Aug 27 '22 at 14:36

1 Answer


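The gradient of your function is zero because, as your function is defined, zero is the correct gradient: the function is piecewise constant, so away from the boundaries a == b and b == c small changes to the inputs never change the output. For more information on this phenomenon, see the JAX FAQ: Why are gradients zero for functions based on sort order?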
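Writing MyFunc with indicator functions makes this explicit. Here $\mathbf{1}[\cdot]$ is 1 when the condition holds and 0 otherwise; this is just a restatement of the `if`/`elif`/`else` in the question:

```latex
\begin{aligned}
f(a,b,c) &= 30\,\mathbf{1}[a > b] + 20\,\mathbf{1}[a \le b]\,\mathbf{1}[b > c] + 10\,\mathbf{1}[a \le b]\,\mathbf{1}[b \le c], \\
\frac{\partial f}{\partial a} &= \frac{\partial f}{\partial b} = \frac{\partial f}{\partial c} = 0 \quad \text{whenever } a \ne b \text{ and } b \ne c.
\end{aligned}
```

Each indicator is constant in a small neighborhood of any point with a ≠ b and b ≠ c, so f is too. On the sets a = b or b = c the output can jump, and there no derivative exists at all.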

If you want a sort-based function with non-zero gradients, you can achieve this by replacing your step-wise function with a smooth approximation. The sigmoid version you included in your question seems like a reasonable approach for this approximation.
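A minimal sketch of the sigmoid idea applied to MyFunc itself, assuming each hard comparison x > y is replaced by sigmoid((x - y) / temperature). The names `soft_my_func` and `temperature` are hypothetical, and this is one possible relaxation, not the only one. Unlike MyFunc2, it keeps the 30/20/10 branch structure, so for inputs much farther than `temperature` from the boundaries a == b and b == c, the output stays close to 30, 20, or 10:

```python
import jax
import jax.nn as nn
import jax.numpy as jnp


def soft_my_func(coefs, temperature=0.1):
    # Hypothetical smooth relaxation of MyFunc (not a JAX API).
    # Each hard comparison x > y becomes sigmoid((x - y) / temperature),
    # a soft indicator between 0 and 1.
    a, b, c = coefs[0], coefs[1], coefs[2]
    p_ab = nn.sigmoid((a - b) / temperature)  # soft version of (a > b)
    p_bc = nn.sigmoid((b - c) / temperature)  # soft version of (b > c)
    # Same branch structure as MyFunc: 30 if a > b, else 20 if b > c, else 10.
    return p_ab * 30.0 + (1.0 - p_ab) * (p_bc * 20.0 + (1.0 - p_bc) * 10.0)


x = jnp.array([1.0, 2.0, 3.0])
value = soft_my_func(x)              # close to 10, like MyFunc(x)
grads = jax.grad(soft_my_func)(x)    # small but nonzero
print(value, grads)
```

For coefs = [1.0, 2.0, 3.0] and the default temperature of 0.1, this returns a value just above 10 with small but nonzero gradients.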

But note that your exact question, how to make a function that produces the same output but has nonzero gradients, has no solution: a function that returns the same outputs as yours for all inputs is piecewise constant too, so its gradient is zero wherever it is defined.
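To make this concrete, here is a sketch that rewrites MyFunc exactly with `jnp.where` instead of a Python `if` (`my_func_where` is a hypothetical name). Because it returns the same 30/20/10 values, `jax.grad` still gives zeros; this illustrates the point rather than proving it:

```python
import jax
import jax.numpy as jnp


def my_func_where(coefs):
    # Hypothetical exact rewrite of MyFunc using jnp.where instead of Python if.
    a, b, c = coefs[0], coefs[1], coefs[2]
    return jnp.where(a > b, 30.0, jnp.where(b > c, 20.0, 10.0))


x = jnp.array([1.0, 2.0, 3.0])
value = my_func_where(x)              # 10.0, same as MyFunc(x)
grads = jax.grad(my_func_where)(x)    # [0. 0. 0.]: same outputs, zero gradient
print(value, grads)
```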

jakevdp
  • Thank you for your reply. If I understand correctly, the ideal approximation (better than sigmoid) would be a function that returns y=1 for any value of x, except for x very close to zero, e.g. 0.00000001 and -0.00000001, where it returns y=0.999999999? – pepazdepa Aug 29 '22 at 18:52
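The trade-off behind a very sharp approximation like this can be seen in a one-dimensional sketch. Assuming sigmoid(x / temperature) as the smooth stand-in for a step at x = 0 (`soft_step` and `temperature` are illustrative names, not from the thread), the code evaluates it at x = 1, away from the jump. As temperature shrinks, the value approaches the exact step value 1, but the gradient at that point shrinks toward zero, which brings back the original problem:

```python
import jax
import jax.nn as nn


def soft_step(x, temperature):
    # Hypothetical smooth stand-in for the step function (x > 0).
    return nn.sigmoid(x / temperature)


x = 1.0  # a point away from the jump at x == 0, where the exact step is 1
temperatures = [1.0, 0.3, 0.1, 0.03]
values = [float(soft_step(x, t)) for t in temperatures]
slopes = [float(jax.grad(soft_step)(x, t)) for t in temperatures]

for t, v, s in zip(temperatures, values, slopes):
    # Sharper (smaller temperature): value closer to 1, gradient closer to 0.
    print(f"temperature={t:<5} value={v:.6f} gradient={s:.3e}")
```

One reasonable way to read this: the temperature is a tuning knob, large enough that gradients still carry signal over the inputs you care about, and small enough that the output is close to the step.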