Is there a way to modify this function (MyFunc) so that it gives the same result, but its gradient is not zero?
from jax import grad
import jax.nn as nn
import numpy as np
def MyFunc(coefs):
    """Piecewise-constant score of three coefficients.

    Returns 30.0 when coefs[0] > coefs[1]; otherwise 20.0 when
    coefs[1] > coefs[2]; otherwise 10.0.  Because the output is a
    constant on each region, its gradient is zero everywhere it is
    defined.
    """
    a, b, c = coefs[0], coefs[1], coefs[2]
    if a > b:
        return 30.0
    # First branch not taken: decide between the remaining two constants.
    return 20.0 if b > c else 10.0
# Differentiate the piecewise-constant MyFunc.  The result is a zero
# vector for any input, because every branch returns a constant and the
# branch selection itself is not differentiable.
myFuncDeriv = grad(MyFunc)

# Both calls print [0. 0. 0.]
for point in (np.random.sample(3), np.array([1.0, 2.0, 3.0])):
    print(myFuncDeriv(point))
EDIT: A similar function that doesn't give a zero gradient — but it doesn't return 30/20/10
def MyFunc2(coefs):
    """Sigmoid-weighted variant of MyFunc with nonzero gradients.

    Uses the same branch structure as MyFunc, but each branch returns
    sigmoid(chosen coefficient) * constant instead of a bare constant,
    so within each region the output depends smoothly on one
    coefficient.  Note the branch choice itself is still
    non-differentiable, and the returned values are no longer exactly
    30/20/10.
    """
    a, b, c = coefs[0], coefs[1], coefs[2]
    if a > b:
        result = nn.sigmoid(a) * 30.0
    elif b > c:
        result = nn.sigmoid(b) * 20.0
    else:
        result = nn.sigmoid(c) * 10.0
    return result
# Differentiate MyFunc2: unlike MyFunc, each branch depends smoothly on
# one coefficient, so the gradient has one nonzero component.
myFunc2Deriv = grad(MyFunc2)

# At [1., 2., 3.] the else-branch (sigmoid(c) * 10.0) is taken, so only
# the third component of the gradient is nonzero.
# prints [0. 0. 0.45176652]
# BUG FIX: this line previously called myFuncDeriv, which prints
# [0. 0. 0.] and contradicts the comment above; the demonstration
# requires the gradient of MyFunc2.
print(myFunc2Deriv(np.array([1.0, 2.0, 3.0])))
# prints for example [6.1160526 0. 0. ]
print(myFunc2Deriv(np.random.sample(3)))