I followed this link to create a custom op called mask. The main body of the TensorFlow op is:
import numpy as np
import tensorflow as tf
from tensorflow.python.framework import ops

def tf_mask(x, labels, epoch_, name=None):  # add "labels" to the input
    with ops.name_scope(name, "Mask", [x, labels, epoch_]) as name:
        z = py_func(np_mask,
                    [x, labels, epoch_],  # add "labels, epoch_" to the input list
                    [tf.float32],
                    name=name,
                    grad=our_grad)
        z = z[0]
        z.set_shape(x.get_shape())
        return z
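(np_mask itself is an ordinary numpy function; its exact logic shouldn't matter for this error, but a trivial stand-in with the same signature and a float32 output would look like this:)

def np_mask(x, labels, epoch_):
    # placeholder forward pass: the real masking logic is omitted;
    # the only requirement here is a float32 output to match Tout
    return x.astype(np.float32)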
This pretty much follows the cited link. However, I run into this error:
ValueError: Num gradients 1 generated for op name: "mask/Mask"
op: "PyFunc"
input: "conv2/Relu"
input: "Placeholder_2"
input: "Placeholder_3"
attr {
  key: "Tin"
  value {
    list {
      type: DT_FLOAT
      type: DT_FLOAT
      type: DT_FLOAT
    }
  }
}
attr {
  key: "Tout"
  value {
    list {
      type: DT_FLOAT
    }
  }
}
attr {
  key: "_gradient_op_type"
  value {
    s: "PyFuncGrad302636"
  }
}
attr {
  key: "token"
  value {
    s: "pyfunc_0"
  }
}
do not match num inputs 3
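The error is raised when the gradients are built, i.e. when the optimizer's minimize (or tf.gradients) is called. Roughly like this (simplified, not my exact graph; shapes and names are stand-ins):

x = tf.placeholder(tf.float32, [None, 8, 8, 16])        # stand-in for conv2/Relu
labels = tf.placeholder(tf.float32, [None, 8, 8, 16])   # Placeholder_2
epoch_ = tf.placeholder(tf.float32, [])                 # Placeholder_3

masked = tf_mask(x, labels, epoch_, name="mask")
loss = tf.reduce_mean(masked)
# the ValueError above is thrown while this line builds the backward pass
train_op = tf.train.GradientDescentOptimizer(0.01).minimize(loss)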
In case it's needed, this is how I define the our_grad function that calculates the gradients:
def our_grad(cus_op, grad):
    """Compute gradients for our custom operation.

    Args:
        cus_op: our custom op tf_mask
        grad: the gradient flowing in from the previous operation

    Returns:
        the gradients to send down to the next layer in back propagation;
        it's an n-tuple, where n is the number of arguments of the operation
    """
    x = cus_op.inputs[0]
    labels = cus_op.inputs[1]
    epoch_ = cus_op.inputs[2]
    n_gr1 = tf_d_mask(x)
    n_gr2 = tf_gradient2(x, labels, epoch_)
    return tf.multiply(grad, n_gr1) + n_gr2
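To illustrate what I mean by that n-tuple (this is just an example, not my code): for a two-input op such as z = x * y, the registered gradient function would return one gradient per input, in the same order as op.inputs:

def mul_grad_example(op, grad):
    x = op.inputs[0]
    y = op.inputs[1]
    # one gradient per input (assuming x and y have the same shape)
    return grad * y, grad * x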
And the py_func function (the same as in the cited link):
def py_func(func, inp, tout, stateful=True, name=None, grad=None):
    """
    I omitted the description of parameters that are not of interest.
    :param func: a numpy function
    :param inp: input tensors
    :param grad: a tensorflow function to get the gradients (used in bprop, should be able to
                 receive previous gradients and send gradients down.)
    :return: a tensorflow op with a registered bprop method
    """
    # Need to generate a unique name to avoid duplicates:
    rnd_name = 'PyFuncGrad' + str(np.random.randint(0, 1000000))
    tf.RegisterGradient(rnd_name)(grad)
    g = tf.get_default_graph()
    with g.gradient_override_map({"PyFunc": rnd_name}):
        return tf.py_func(func, inp, tout, stateful=stateful, name=name)
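For reference, my understanding of how this wrapper is meant to be used, shown on a simple single-input op (illustration only; np_sinh and its gradient are made up for this sketch, not part of my model):

import numpy as np
import tensorflow as tf
from tensorflow.python.framework import ops

def np_sinh(x):
    return np.sinh(x).astype(np.float32)

def sinh_grad(op, grad):
    x = op.inputs[0]
    return grad * tf.cosh(x)   # single input -> single gradient

def tf_sinh(x, name=None):
    with ops.name_scope(name, "Sinh", [x]) as name:
        y = py_func(np_sinh, [x], [tf.float32], name=name, grad=sinh_grad)
        y = y[0]
        y.set_shape(x.get_shape())
        return y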
Really need the community's help!
Thanks!