I would just like to use ForwardDiff.jl to define a function and plot its gradient (evaluated using ForwardDiff.gradient). It does not seem to be working because the output of ForwardDiff.gradient is this weird Dual type thing, and it is not easily converted to the desired type (in my case, a 1-D array of Float32s).

using Plots
using ForwardDiff

my_func(x::Array{Float32,1}) = 1f0 ./ (1f0 .+ exp.(3f0 .* x)) # doesn't matter what this is, just a sigmoid function here

grad_f(x::Array{Float32,1}) = ForwardDiff.gradient(my_func, x)

x_values = collect(Float32,0:0.01:10)

plot(x_values,my_func(x_values)); # this works fine

plot!(x_values,grad_f(x_values)); # this throws an error

And this is the error I get:

ERROR: MethodError: no method matching Float64(::ForwardDiff.Dual{ForwardDiff.Tag{typeof(g),Float32},Float64,12})

When I inspect the type of grad_f(x_values), I get this:

Array{Array{ForwardDiff.Dual{ForwardDiff.Tag{typeof(g),Float32},Float32,12},1},1}

Why doesn't that happen in the example in the ForwardDiff documentation, for instance? See here: https://github.com/JuliaDiff/ForwardDiff.jl

Thanks in advance.

EDIT: After Kristoffer Carlsson's comments: I tried this but it still doesn't work. I don't understand what is so different about what I tried here versus what he suggested:

function g(x::Float32)
    return x / (1f0 + exp(10f0 * (x - 5f0)))
end

function ∂g∂x(x::Float32)
    return ForwardDiff.derivative(g, x)
end

x_vals = collect(Float32,0:0.01:10)
plot(x_vals,g.(x_vals))
plot!(x_vals,∂g∂x.(x_vals))

With the error now being:

no method matching g(::ForwardDiff.Dual{ForwardDiff.Tag{typeof(g),Float32},Float32,1})

And this error occurs as soon as I call ∂g∂x(x), whether or not I use the broadcasted version ∂g∂x.(x). I guess it has something to do with the function definition, but I don't see how the way I defined it differs from Kristoffer's version, other than that it's not defined in a single line... This is so confusing.

This should work because according to ForwardDiff's documentation, you just need the types of the inputs to be a sub-type of Real - and Float32 is a sub-type of Real.

EDIT: I realise now, having read the comments from others, that the functions must be written generically enough to accept any input of the abstract type Real (rather than being restricted to one concrete subtype such as Float32), which I didn't quite glean from the documentation. Apologies for the confusion.

Conor
  • "Which makes 0 sense to me, given what is said in the ForwardDiff.jl documentation about the lack of restrictions on which sub-type of Real you use". No, the docs at http://www.juliadiff.org/ForwardDiff.jl/stable/user/limitations/ explicitly say: "The target function must be written generically enough to accept numbers of type `T<:Real` as input" and if you restrict it to `Float32` then that is not true. – Kristoffer Carlsson Nov 13 '20 at 13:58
  • Ah I didn't realise that's what that meant -- I thought the notation `T<:Real` means that you can restrict to any type T that is a sub-type of `Real` and therefore they can be restricted in that way. My mistake, thanks for clarifying. – Conor Nov 13 '20 at 14:26

1 Answer

You are defining functions on arrays instead of scalars, and you are also restricting the input types too much. Also, for scalar functions you should use ForwardDiff.derivative rather than ForwardDiff.gradient. Try something like:

using Plots
using ForwardDiff

my_func(x::Real) =  1f0 / (1f0 + exp(3f0 * x))
my_func_derivative(x::Real) = ForwardDiff.derivative(my_func, x)

plot(my_func, xlimits = (0, 10))
plot!(my_func_derivative)

giving a plot of my_func and its derivative for x between 0 and 10.
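If you want the derivative values as a 1-D array of Float32s (as in the question) rather than passing the function straight to plot, one option is to broadcast the scalar derivative over the input vector. This is only a minimal sketch: the names sigmoid and dsigmoid are illustrative and not from the question, and the function is deliberately left without a type annotation so that ForwardDiff can call it with its Dual numbers.

```julia
using ForwardDiff

# Generic scalar function: no type annotation, so it accepts
# Float32 inputs as well as ForwardDiff's Dual numbers.
sigmoid(x) = 1 / (1 + exp(3 * x))

# Scalar derivative at a single point.
dsigmoid(x) = ForwardDiff.derivative(sigmoid, x)

x_values = collect(Float32, 0:0.01:10)

# Broadcasting evaluates the function and its derivative elementwise.
y_values  = sigmoid.(x_values)
dy_values = dsigmoid.(x_values)

# Both results are plain Float32 vectors, with no Dual numbers left over.
println(typeof(y_values))
println(typeof(dy_values))
```

The resulting vectors can then be passed to plot(x_values, y_values) and plot!(x_values, dy_values) in the usual way.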

Kristoffer Carlsson
  • Thanks for the tip - however, I tried it in a slightly different way and it didn't work. Do you know why? I will edit the original post to show what I tried. – Conor Nov 13 '20 at 12:43
  • You might have to read up on how ForwardDiff works -- it calls your function with [dual numbers](https://en.wikipedia.org/wiki/Dual_number) to track the derivative. If you restrict your function to inputs of type `Float32` this will not work. You need to relax it to `Real` since ForwardDiff's `Dual` type is a subtype of `Real`. – fredrikekre Nov 13 '20 at 13:53
  • I would add that unless you plan to define another method for that function (to behave differently with other types), then you should probably not enforce any input type. The code will be easier to read, just as performant, and will likely combine with other packages with zero effort. – Benoit Pasquier Nov 13 '20 at 13:56
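
The point made in these comments about Dual numbers can be checked directly. The sketch below assumes two hypothetical functions, `strict` and `generic`, that differ only in their type annotation. It shows that a `Dual` is a `Real` but not a `Float32`, which is why the restricted method cannot be differentiated.

```julia
using ForwardDiff

# A Dual carries a value (here 2) and a derivative seed (here 1).
d = ForwardDiff.Dual(2f0, 1f0)
println(d isa Real)      # true
println(d isa Float32)   # false

# Hypothetical functions for illustration: same body, different annotation.
strict(x::Float32) = x^2
generic(x::Real)   = x^2

# The generic method accepts the Dual that ForwardDiff passes in.
println(ForwardDiff.derivative(generic, 2f0))

# The Float32-only method has no matching method for a Dual argument,
# so the call fails with a MethodError like the one in the question.
threw_method_error = try
    ForwardDiff.derivative(strict, 2f0)
    false
catch err
    err isa MethodError
end
println(threw_method_error)
```

Dropping the annotation altogether (`generic(x) = x^2`) behaves the same way as `x::Real` here, which is in line with the last comment above.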