
Summary: Scroll down for a reproducible example, which should run from scratch in Julia if you have the packages listed in the using lines. (Note: the ODE has a complex, reusable structure, which is specified in a Gist that the script downloads and includes.)

Background: I have to repeatedly solve a large system of ODEs for different initial-condition vectors. In the example below there are 127 states/ODEs, but there could easily be 1000-2000. I will have to run these solves hundreds to thousands of times for inference, so speed is essential.

The Puzzle: In short, for the serial functions the @simd version is much faster than the "plain" (non-@simd) version. But for the parallel functions the @simd version is much slower, and on top of that the answer, sum_of_solutions, varies between runs and is wrong.

I start Julia with JULIA_NUM_THREADS=auto julia, which in my case gives 8 threads. I then make sure that I never have more than 8 jobs spawned at once.
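Julia's scheduler already multiplexes spawned tasks onto the available threads, so spawning more tasks than threads is allowed. If a hard cap on concurrently running jobs is still wanted, a counting semaphore is one way to express it without manual bookkeeping. A minimal sketch, assuming a hypothetical stand-in job `fake_job` in place of core_op_plain/core_op_simd; `Base.Semaphore`, `Base.acquire` and `Base.release` are Base (unexported) functions.

```julia
# Hypothetical stand-in for core_op_plain / core_op_simd: any function of a job index.
fake_job(i) = sum(sqrt(k * i) for k in 1:10_000)

function run_capped(job, njobs; max_running = Threads.nthreads())
    sem = Base.Semaphore(max_running)   # at most max_running jobs hold a permit
    running = Threads.Atomic{Int}(0)    # jobs currently holding a permit
    peak = Threads.Atomic{Int}(0)       # highest concurrency observed
    tasks = map(1:njobs) do i
        Threads.@spawn begin
            Base.acquire(sem)           # blocks (and yields) until a permit is free
            try
                n = Threads.atomic_add!(running, 1) + 1   # atomic_add! returns the old value
                Threads.atomic_max!(peak, n)
                job(i)
            finally
                Threads.atomic_sub!(running, 1)
                Base.release(sem)       # let the next waiting job start
            end
        end
    end
    results = fetch.(tasks)             # fetch waits for each task; no busy polling
    return results, peak[]
end

results, peak = run_capped(fake_job, 20; max_running = 4)
```

The semaphore keeps the throttle logic in one place, so the job loop itself only spawns and fetches.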

Calculation times for the different versions (runtime in seconds, then sum_of_solutions):

# Output is (runtime, sum_of_solutions)
serial_with_plain_v7(tspan, p_Ds_v7, solve_results1; number_of_solves=number_of_solves)
serial_with_plain_v7(tspan, p_Ds_v7, solve_results1; number_of_solves=number_of_solves)
serial_with_plain_v7(tspan, p_Ds_v7, solve_results1; number_of_solves=number_of_solves)
# (duration, sum_of_solutions)
# (1.1, 8.731365050398926)
# (0.878, 8.731365050398926)
# (0.898, 8.731365050398926)

serial_with_simd_v7(tspan, p_Ds_v7, solve_results1; number_of_solves=number_of_solves)
serial_with_simd_v7(tspan, p_Ds_v7, solve_results1; number_of_solves=number_of_solves)
serial_with_simd_v7(tspan, p_Ds_v7, solve_results1; number_of_solves=number_of_solves)
# (duration, sum_of_solutions)
# (0.046, 8.731365050398928)
# (0.042, 8.731365050398928)
# (0.046, 8.731365050398928)

parallel_with_plain_v5(tspan, p_Ds_v7, solve_results2; number_of_solves=number_of_solves)
# Faster than serial plain version
# (duration, sum_of_solutions)
# (0.351, 8.731365050398926)
# (0.343, 8.731365050398926)
# (0.366, 8.731365050398926)

parallel_with_simd_v7(tspan, p_Ds_v7, solve_results2; number_of_solves=number_of_solves)
# Dramatically slower than serial simd version, plus wrong sum_of_solutions
# (duration, sum_of_solutions)
# (136.966, 9.61313614002137)
# (141.843, 9.616688089683372)
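The two serial versions agree except in the last digit (8.731365050398926 vs 8.731365050398928). One plausible source of that tiny difference is that @simd permits the compiler to reassociate floating-point reductions, so sums may be accumulated in a different order and round differently. A minimal sketch with hypothetical function names, comparing a plain loop with a @simd loop:

```julia
# Plain left-to-right accumulation: the order of additions is fixed.
function sum_plain(x::Vector{Float64})
    s = 0.0
    for i in eachindex(x)
        s += x[i]
    end
    return s
end

# @simd allows the compiler to reorder the additions,
# e.g. keep several partial sums in vector lanes and combine them at the end.
function sum_simd(x::Vector{Float64})
    s = 0.0
    @inbounds @simd for i in eachindex(x)
        s += x[i]
    end
    return s
end

x = [1.0 / k for k in 1:10_000]
a = sum_plain(x)
b = sum_simd(x)
println(a, "  ", b, "  difference = ", a - b)
```

Differences of this size are rounding, not a bug. The large serial speed gap is probably not due to @simd alone, though: the @btime figures in the script show 15847 allocations per call for Ds_v5_tmp versus 3075 for Ds_v7_simd_sums.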

As you can see, serial @simd gets the calculation time down to 0.046 seconds, and parallel plain is about 2.5 times faster than serial plain. But when I combine parallelization with the @simd function, I get runtimes of about 140 seconds, with variable and wrong answers to boot! Literally the only difference between the two parallelizing functions is whether core_op_plain or core_op_simd is used for the core ODE-solving operation.

Could @simd and @spawn be conflicting somehow? The parallelizing function is set up never to run more tasks at once than there are threads available (8 on my machine).
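@simd is a compile-time hint that lets the compiler vectorize a single loop; it does not start threads or introduce shared state, so by itself it should not interact with @spawn. A toy check of that idea, using a hypothetical kernel `simd_dot` rather than the Gist's Ds_v7_simd_sums: spawn one task per input and compare with the serial results.

```julia
# Toy @simd kernel: a dot product with a reassociable reduction.
function simd_dot(a::Vector{Float64}, b::Vector{Float64})
    s = 0.0
    @inbounds @simd for i in eachindex(a, b)
        s += a[i] * b[i]
    end
    return s
end

n = 127
inputs = [Float64.(1:n) .* j for j in 1:16]   # one independent input vector per job
w = fill(0.5, n)                              # read-only data shared by all jobs

serial = [simd_dot(inputs[j], w) for j in 1:16]
tasks = [Threads.@spawn simd_dot(inputs[j], w) for j in 1:16]
parallel = fetch.(tasks)

println(serial == parallel)   # each task only reads shared data, so results match
```

If a check like this passes but the real code still gives varying answers, the cause more likely lies in state that the tasks share than in @simd itself.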

I am still learning Julia, so there is a chance that some smallish change could isolate the @simd calculations and prevent conflicts across threads (if that is what is happening). Any help is very much appreciated!
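If the problem is a data race, one common way to isolate the calculations is to give each task its own copy of any mutable scratch space. Whether Ds_v7_simd_sums actually writes into a field of p (for example the terms field of p_Ds_v7) is an assumption not checked here; `deriv!`, `solve_one` and `p_template` are hypothetical stand-ins for that pattern.

```julia
# Hypothetical derivative that uses p.terms as a preallocated scratch buffer.
function deriv!(du, u, p)
    terms = p.terms
    @inbounds @simd for i in eachindex(u)
        terms[i] = p.rate * u[i]      # write intermediate values into scratch
    end
    @inbounds for i in eachindex(du)
        du[i] = terms[i] - u[i]       # read them back
    end
    return du
end

p_template = (rate = 2.0, terms = zeros(127))

function solve_one(j, p)
    u = zeros(127)
    u[j] = 1.0
    du = similar(u)
    deriv!(du, u, p)
    return sum(du)
end

# Passing p_template itself to every task would let tasks overwrite each
# other's terms buffer. deepcopy gives each task its own terms vector instead.
tasks = [Threads.@spawn solve_one(j, deepcopy(p_template)) for j in 1:32]
results = fetch.(tasks)
```

Each result here should be 1.0 (the j-th entry gives 2.0 - 1.0, all others 0.0), regardless of how the tasks interleave.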

PS: Reproducible example. The code below should reproduce the problem in any Julia session started with multiple threads. My versioninfo() output and startup settings are included at the top:

versioninfo()
notes="""
My setup is:
Julia Version 1.7.3
Commit 742b9abb4d (2022-05-06 12:58 UTC)
Platform Info:
  OS: macOS (x86_64-apple-darwin21.4.0)
  CPU: Intel(R) Xeon(R) CPU E5-2697 v2 @ 2.70GHz
  WORD_SIZE: 64
  LIBM: libopenlibm
  LLVM: libLLVM-12.0.1 (ORCJIT, ivybridge)
"""


# Startup notes
notes="""
# "If \$JULIA_NUM_THREADS is set to auto, then the number of threads will be set to the number of CPU threads."
JULIA_NUM_THREADS=auto julia --startup-file=no
Threads.nthreads(): 8 # Number of CPU threads
"""


using LinearAlgebra          # for "I" in: Matrix{Float64}(I, 2, 2)
                             # https://www.reddit.com/r/Julia/comments/9cfosj/identity_matrix_in_julia_v10/
using Sundials               # for CVODE_BDF
using Statistics             # for mean() (max() is in Base)
using DataFrames             # for e.g. DataFrame()
using Dates                  # for e.g. DateTime, Dates.now()
using DifferentialEquations  # for ODEProblem
using BenchmarkTools         # for @benchmark
using Distributed            # for workers


# Check that you have multiple threads
numthreads = Base.Threads.nthreads()

# Download & include the pre-saved model structure/rates (all precalculated for speed; 1.8 MB)
#include("/GitHub/BioGeoJulia.jl/test/model_p_object.jl")
url = "https://gist.githubusercontent.com/nmatzke/ed99ab8f5047794eb25e1fdbd5c43b37/raw/b3e6ddff784bd3521d089642092ba1e3830699c0/model_p_object.jl"
download(url,  "model_p_object.jl")
include("model_p_object.jl")

# Load the ODE functions
url = "https://gist.githubusercontent.com/nmatzke/f116258c78bd43ab7a448f07c4290516/raw/24a210261fd2e090b8ed27bc64a59a1ff9ec62cd/simd_vs_spawn_setup_v2.jl"
download(url,  "simd_vs_spawn_setup_v2.jl")
include("simd_vs_spawn_setup_v2.jl")

#include("/GitHub/BioGeoJulia.jl/test/simd_vs_spawn_setup_v2.jl")
#include("/GitHub/BioGeoJulia.jl/test/simd_vs_spawn_setup_v3.jl")

# Load the pre-saved model structure/rates (all precalculated for speed; 1.8 MB)
p_Es_v5 = load_ps_127();



# Set up output object
numstates = 127
number_of_solves = 10

solve_results1 = zeros(Float64, number_of_solves, numstates)
solve_results2 = zeros(Float64, number_of_solves, numstates)
length(solve_results1)
length(solve_results1[1])
sum(sum.(solve_results1))


# Precalculate the Es for use in the Ds
Es_tspan = (0.0, 60.0)
prob_Es_v7 = DifferentialEquations.ODEProblem(Es_v7_simd_sums, p_Es_v5.uE, Es_tspan, p_Es_v5);
sol_Es_v7 = solve(prob_Es_v7, CVODE_BDF(linear_solver=:GMRES), save_everystep=true, 
abstol=1e-12, reltol=1e-9);

p_Ds_v7 = (n=p_Es_v5.n, params=p_Es_v5.params, p_indices=p_Es_v5.p_indices, p_TFs=p_Es_v5.p_TFs, uE=p_Es_v5.uE, terms=p_Es_v5.terms, sol_Es_v5=sol_Es_v7);


# Set up ODE inputs
u = zeros(numstates);
u[2] = 1.0
du = zeros(numstates)
p = p_Ds_v7;
t = 1.0

# ODE functions to integrate (single evaluation; the ODE solvers will call these many times)
@time Ds_v5_tmp(du,u,p,t)
@time Ds_v5_tmp(du,u,p,t)
@time Ds_v7_simd_sums(du,u,p,t)
@time Ds_v7_simd_sums(du,u,p,t)

#@btime Ds_v5_tmp(du,u,p,t)
# 7.819 ms (15847 allocations: 1.09 MiB)

#@btime Ds_v7_simd_sums(du,u,p,t)
# 155.858 μs (3075 allocations: 68.66 KiB)



tspan = (0.0, 1.0)
prob_Ds_v7 = DifferentialEquations.ODEProblem(Ds_v7_simd_sums, p_Ds_v7.uE, tspan, p_Ds_v7);

sol_Ds_v7 = solve(prob_Ds_v7, CVODE_BDF(linear_solver=:GMRES), save_everystep=false, abstol=1e-12, reltol=1e-9);

# This is the core operation; plain version (no @simd)
function core_op_plain(u, tspan, p_Ds_v7)
    prob_Ds_v5 = DifferentialEquations.ODEProblem(Ds_v5_tmp, u.+0.0, tspan, p_Ds_v7);

    sol_Ds_v5 = solve(prob_Ds_v5, CVODE_BDF(linear_solver=:GMRES), save_everystep=false, abstol=1e-12, reltol=1e-9);
    return sol_Ds_v5
end


# This is the core operation; @simd version
function core_op_simd(u, tspan, p_Ds_v7)
    prob_Ds_v7 = DifferentialEquations.ODEProblem(Ds_v7_simd_sums, u.+0.0, tspan, p_Ds_v7);

    sol_Ds_v7 = solve(prob_Ds_v7, CVODE_BDF(linear_solver=:GMRES), save_everystep=false, abstol=1e-12, reltol=1e-9);
    return sol_Ds_v7
end

@time core_op_plain(u, tspan, p_Ds_v7);
@time core_op_plain(u, tspan, p_Ds_v7);
@time core_op_simd(u, tspan, p_Ds_v7);
@time core_op_simd(u, tspan, p_Ds_v7);


function serial_with_plain_v7(tspan, p_Ds_v7, solve_results1; number_of_solves=10)
    start_time = Dates.now()
    for i in 1:number_of_solves
        # Temporary u
        solve_results1[i,:] .= 0.0
        
        # Change the ith state from 0.0 to 1.0
        solve_results1[i,i] = 1.0

        sol_Ds_v7 = core_op_plain(solve_results1[i,:], tspan, p_Ds_v7)
        solve_results1[i,:] .=  sol_Ds_v7.u[length(sol_Ds_v7.u)]
    #   print("\n")
    #   print(round.(sol_Ds_v7[length(sol_Ds_v7)], digits=3))
    end
    
    end_time = Dates.now()
    duration = (end_time - start_time).value / 1000.0
    sum_of_solutions = sum(sum.(solve_results1))
    return (duration, sum_of_solutions)
end


function serial_with_simd_v7(tspan, p_Ds_v7, solve_results1; number_of_solves=10)
    start_time = Dates.now()
    for i in 1:number_of_solves
        # Temporary u
        solve_results1[i,:] .= 0.0
        
        # Change the ith state from 0.0 to 1.0
        solve_results1[i,i] = 1.0

        sol_Ds_v7 = core_op_simd(solve_results1[i,:], tspan, p_Ds_v7)
        solve_results1[i,:] .=  sol_Ds_v7.u[length(sol_Ds_v7.u)]
    #   print("\n")
    #   print(round.(sol_Ds_v7[length(sol_Ds_v7)], digits=3))
    end
    
    end_time = Dates.now()
    duration = (end_time - start_time).value / 1000.0
    sum_of_solutions = sum(sum.(solve_results1))
    return (duration, sum_of_solutions)
end

# Output is (runtime, sum_of_solutions)
serial_with_plain_v7(tspan, p_Ds_v7, solve_results1; number_of_solves=number_of_solves)
serial_with_plain_v7(tspan, p_Ds_v7, solve_results1; number_of_solves=number_of_solves)
serial_with_plain_v7(tspan, p_Ds_v7, solve_results1; number_of_solves=number_of_solves)
# (duration, sum_of_solutions)
# (1.1, 8.731365050398926)
# (0.878, 8.731365050398926)
# (0.898, 8.731365050398926)

serial_with_simd_v7(tspan, p_Ds_v7, solve_results1; number_of_solves=number_of_solves)
serial_with_simd_v7(tspan, p_Ds_v7, solve_results1; number_of_solves=number_of_solves)
serial_with_simd_v7(tspan, p_Ds_v7, solve_results1; number_of_solves=number_of_solves)
# (duration, sum_of_solutions)
# (0.046, 8.731365050398928)
# (0.042, 8.731365050398928)
# (0.046, 8.731365050398928)

using Distributed

function parallel_with_plain_v5(tspan, p_Ds_v7, solve_results2; number_of_solves=10)
    start_time = Dates.now()
    number_of_threads = Base.Threads.nthreads()
    curr_numthreads = Base.Threads.nthreads()
        
    # Individual ODE solutions will occur over different timeperiods,
    # initial values, and parameters.  We'd just like to load up the 
    # cores for the first jobs in the list, then add jobs as earlier
    # jobs finish.
    tasks = Any[]
    tasks_started_TF = Bool[]
    tasks_fetched_TF = Bool[]
    task_numbers = Any[]
    task_inc = 0
    are_we_done = false
    current_running_tasks = Any[]
    
    # List the tasks
    for i in 1:number_of_solves
        # Temporary u
        solve_results2[i,:] .= 0.0
        
        # Change the ith state from 0.0 to 1.0
        solve_results2[i,i] = 1.0

        task_inc = task_inc + 1
        push!(tasks_started_TF, false) # Add a "false" to tasks_started_TF
        push!(tasks_fetched_TF, false) # Add a "false" to tasks_fetched_TF
        push!(task_numbers, task_inc)
    end
    
    # Total number of tasks
    num_tasks = length(tasks_fetched_TF)

    iteration_number = 0
    while(are_we_done == false)
        iteration_number = iteration_number+1
        
        # Launch tasks when thread (core) is available
        for j in 1:num_tasks
            if (tasks_fetched_TF[j] == false)
                if (tasks_started_TF[j] == false) && (curr_numthreads > 0)
                    # Start a task
                    push!(tasks, Base.Threads.@spawn core_op_plain(solve_results2[j,:], tspan, p_Ds_v7));
                    curr_numthreads = curr_numthreads-1;
                    tasks_started_TF[j] = true;
                    push!(current_running_tasks, task_numbers[j])
                end
            end
        end
        
        # Check for finished tasks
        tasks_to_check_TF = ((tasks_started_TF.==true) .+ (tasks_fetched_TF.==false)).==2
        if sum(tasks_to_check_TF .== true) > 0
            for k in 1:sum(tasks_to_check_TF)
                if (tasks_fetched_TF[current_running_tasks[k]] == false)
                    if (istaskstarted(tasks[k]) == true) && (istaskdone(tasks[k]) == true)
                        sol_Ds_v7 = fetch(tasks[k]);
                        solve_results2[current_running_tasks[k],:] .= sol_Ds_v7.u[length(sol_Ds_v7.u)].+0.0
                        tasks_fetched_TF[current_running_tasks[k]] = true
                        current_tasknum = current_running_tasks[k]
                        deleteat!(tasks, k)
                        deleteat!(current_running_tasks, k)
                        curr_numthreads = curr_numthreads+1;
                        print("\nFinished task #")
                        print(current_tasknum)
                        print(", current task k=")
                        print(k)
                        break # break out of this loop, since you have modified current_running_tasks
                    end
                end
            end
        end

        are_we_done = sum(tasks_fetched_TF) == length(tasks_fetched_TF)
        # Test for concluding the while loop
        are_we_done && break
    end # END while(are_we_done == false)

    end_time = Dates.now()
    duration = (end_time - start_time).value / 1000.0
    sum_of_solutions = sum(sum.(solve_results2))
    print("\n")
    return (duration, sum_of_solutions)
end


function parallel_with_simd_v7(tspan, p_Ds_v7, solve_results2; number_of_solves=10)
    start_time = Dates.now()
    number_of_threads = Base.Threads.nthreads()
    curr_numthreads = Base.Threads.nthreads()
        
    # Individual ODE solutions will occur over different timeperiods,
    # initial values, and parameters.  We'd just like to load up the 
    # cores for the first jobs in the list, then add jobs as earlier
    # jobs finish.
    tasks = Any[]
    tasks_started_TF = Bool[]
    tasks_fetched_TF = Bool[]
    task_numbers = Any[]
    task_inc = 0
    are_we_done = false
    current_running_tasks = Any[]
    
    # List the tasks
    for i in 1:number_of_solves
        # Temporary u
        solve_results2[i,:] .= 0.0
        
        # Change the ith state from 0.0 to 1.0
        solve_results2[i,i] = 1.0

        task_inc = task_inc + 1
        push!(tasks_started_TF, false) # Add a "false" to tasks_started_TF
        push!(tasks_fetched_TF, false) # Add a "false" to tasks_fetched_TF
        push!(task_numbers, task_inc)
    end
    
    # Total number of tasks
    num_tasks = length(tasks_fetched_TF)

    iteration_number = 0
    while(are_we_done == false)
        iteration_number = iteration_number+1
        
        # Launch tasks when thread (core) is available
        for j in 1:num_tasks
            if (tasks_fetched_TF[j] == false)
                if (tasks_started_TF[j] == false) && (curr_numthreads > 0)
                    # Start a task
                    push!(tasks, Base.Threads.@spawn core_op_simd(solve_results2[j,:], tspan, p_Ds_v7))
                    curr_numthreads = curr_numthreads-1;
                    tasks_started_TF[j] = true;
                    push!(current_running_tasks, task_numbers[j])
                end
            end
        end
        
        # Check for finished tasks
        tasks_to_check_TF = ((tasks_started_TF.==true) .+ (tasks_fetched_TF.==false)).==2
        if sum(tasks_to_check_TF .== true) > 0
            for k in 1:sum(tasks_to_check_TF)
                if (tasks_fetched_TF[current_running_tasks[k]] == false)
                    if (istaskstarted(tasks[k]) == true) && (istaskdone(tasks[k]) == true)
                        sol_Ds_v7 = fetch(tasks[k]);
                        solve_results2[current_running_tasks[k],:] .= sol_Ds_v7.u[length(sol_Ds_v7.u)].+0.0
                        tasks_fetched_TF[current_running_tasks[k]] = true
                        current_tasknum = current_running_tasks[k]
                        deleteat!(tasks, k)
                        deleteat!(current_running_tasks, k)
                        curr_numthreads = curr_numthreads+1;
                        print("\nFinished task #")
                        print(current_tasknum)
                        print(", current task k=")
                        print(k)
                        break # break out of this loop, since you have modified current_running_tasks
                    end
                end
            end
        end

        are_we_done = sum(tasks_fetched_TF) == length(tasks_fetched_TF)
        # Test for concluding the while loop
        are_we_done && break
    end # END while(are_we_done == false)

    end_time = Dates.now()
    duration = (end_time - start_time).value / 1000.0
    sum_of_solutions = sum(sum.(solve_results2))
    print("\n")
    return (duration, sum_of_solutions)
end

tspan = (0.0, 1.0)
parallel_with_plain_v5(tspan, p_Ds_v7, solve_results2; number_of_solves=number_of_solves)
# Faster than serial plain version
# (duration, sum_of_solutions)
# (0.351, 8.731365050398926)
# (0.343, 8.731365050398926)
# (0.366, 8.731365050398926)

parallel_with_simd_v7(tspan, p_Ds_v7, solve_results2; number_of_solves=number_of_solves)
# Dramatically slower than serial simd version
# (duration, sum_of_solutions)
# (136.966, 9.61313614002137)
# (141.843, 9.616688089683372)
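For comparison, the job bookkeeping in parallel_with_plain_v5 and parallel_with_simd_v7 could be written more compactly by spawning one task per initial-condition vector and then fetching the tasks in order. fetch blocks and yields instead of busy-polling, and the scheduler decides how many tasks run at once. This is a sketch only; `toy_core_op` and `parallel_fetch_all` are hypothetical, and the toy op returns a vector rather than an ODE solution.

```julia
# Hypothetical stand-in for core_op_*: returns the "final state" for one initial condition.
toy_core_op(u0) = u0 .* exp(-1.0)

function parallel_fetch_all(core_op, number_of_solves, numstates)
    solve_results = zeros(number_of_solves, numstates)
    tasks = map(1:number_of_solves) do i
        u0 = zeros(numstates)
        u0[i] = 1.0                       # ith unit initial condition
        Threads.@spawn core_op(u0)        # u0 is local to this iteration
    end
    for (i, t) in enumerate(tasks)
        solve_results[i, :] .= fetch(t)  # waits (and yields) until task i finishes
    end
    return solve_results
end

res = parallel_fetch_all(toy_core_op, 10, 127)
```

With the real core_op_plain / core_op_simd, the fetched value is an ODE solution, so the row assignment would take its final state (fetch(t).u[end]), as the original code does with sol_Ds_v7.u[length(sol_Ds_v7.u)].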



Thanks again, Nick

  • Times like 0.046 seconds are often a sign that your benchmark was optimized away. Do those times still scale as expected (e.g. linearly) with problem size or repeat count? Do all your other times scale properly when you change your problem size? That's a good way to check whether the compiler is defeating your benchmark, e.g. by optimizing across repeat-loop iterations, or by removing work that doesn't lead to something you print (or assign to a `volatile` variable, if Julia has that). See also [Idiomatic way of performance evaluation?](https://stackoverflow.com/q/60291987) – Peter Cordes Jun 23 '22 at 13:08
  • The code is very long but does not actually contain a `@simd` call. Could you make a toy example specific to your problem? There are many reasons why something can go wrong in parallel code. – Przemyslaw Szufel Jun 23 '22 at 13:25
  • Thanks very much for the replies! The @simd loops are inside `Ds_v7_simd_sums`; I put that function in the Gist in an attempt to shorten the main-text code: https://gist.githubusercontent.com/nmatzke/f116258c78bd43ab7a448f07c4290516/raw/24a210261fd2e090b8ed27bc64a59a1ff9ec62cd/simd_vs_spawn_setup_v2.jl I gathered that optimized inner loops should be in separate functions. – Nicholas J. Matzke Jun 23 '22 at 20:09
  • "Times like 0.046 seconds are often a sign that your benchmark optimized away. Do those times still scale as expected (e.g. linearly) with problem size or repeat count?" Yes, all of these calculation times scale with problem size -- I posted one where the core operation (core_op) was big enough/slow enough to show the benefit of parallelization, at least for non-simd code. I expect that for the simd-code, it will require a bigger problem, e.g. 511 states instead of 127. – Nicholas J. Matzke Jun 23 '22 at 20:12
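Following the first comment's suggestion, one way to check that a timing is not an artifact is to confirm that it grows roughly linearly with the repeat count. A minimal sketch using @elapsed; `work` is a hypothetical stand-in for something like serial_with_simd_v7, and returning the accumulated value keeps the computation observable so it cannot simply be discarded.

```julia
# Hypothetical workload; the returned accumulator keeps the work observable.
function work(nrepeats)
    acc = 0.0
    for r in 1:nrepeats
        v = [sin(r * k) for k in 1:1_000]   # allocate and compute something per repeat
        acc += sum(v)
    end
    return acc
end

work(1)                                     # run once so the timings exclude compilation
counts = [100, 200, 400, 800]
times = [@elapsed(work(n)) for n in counts]
for (n, t) in zip(counts, times)
    println(n, " repeats: ", round(t * 1e3, digits = 3), " ms (",
            round(t / n * 1e6, digits = 3), " µs/repeat)")
end
```

If the per-repeat time stays roughly constant as the count doubles, the measurement is tracking real work.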
