0

I want to make RK4 with the numba for speed-up. I'm a beginner using the numba. Why can't the numba understand my code?

The simple code is following

in swing.py

@numba.jit(nopython=True)
def RK4(func, t_end, X0, dt):
    t = np.arange(0,t_end, dt, dtype=np.float64)
    X  = np.zeros((t.shape[0], X0.shape[0]))
    X[0] = X0
    hdt = dt*.5
    for i in range(t.shape[0]-1):
        t1 = t[i]
        x1 = X[i]
        k1 = func(t[i], X[i])
        
        t2 = t[i] + hdt
        x2 = X[i] + hdt * k1
        k2 = func(t2, x2)
        
        t3 = t[i] + hdt
        x3 = X[i] + hdt * k2
        k3 = func(t3, x3)
        
        t4 = t[i] + dt
        x4 = X[i] + dt * k3
        k4 = func(t4, x4)
        X[i+1] = X[i] + dt / 6. * (k1 + 2. * k2 + 2. * k3 + k4)
    return X

# dyummy function for test
@numba.jit(nopython=True)
def fff(t, X):
    t = 1
    X = 3
    res = [0]
    res.append(t*X)
    return res

The main code for running.

import numpy as np
import numba

swing.RK4(swing.fff, 10, np.array([0,1]), 0.1)

The error message following: But I can not understand what isn't correct in this simple code.

---------------------------------------------------------------------------
TypingError                               Traceback (most recent call last)
Input In [2], in <cell line: 1>()
----> 1 swing.RK4(swing.fff, 10, np.array([0,1]), 0.1)

File ~/miniconda3/lib/python3.9/site-packages/numba/core/dispatcher.py:468, in _DispatcherBase._compile_for_args(self, *args, **kws)
    464         msg = (f"{str(e).rstrip()} \n\nThis error may have been caused "
    465                f"by the following argument(s):\n{args_str}\n")
    466         e.patch_message(msg)
--> 468     error_rewrite(e, 'typing')
    469 except errors.UnsupportedError as e:
    470     # Something unsupported is present in the user code, add help info
    471     error_rewrite(e, 'unsupported_error')

File ~/miniconda3/lib/python3.9/site-packages/numba/core/dispatcher.py:409, in _DispatcherBase._compile_for_args.<locals>.error_rewrite(e, issue_type)
    407     raise e
    408 else:
--> 409     raise e.with_traceback(None)

TypingError: Failed in nopython mode pipeline (step: nopython frontend)
No implementation of function Function(<built-in function mul>) found for signature:
 
 >>> mul(float64, list(int64)<iv=[0]>)
 
There are 14 candidate implementations:
  - Of which 12 did not match due to:
  Overload of function 'mul': File: <numerous>: Line N/A.
    With argument(s): '(float64, list(int64)<iv=None>)':
   No match.
  - Of which 2 did not match due to:
  Operator Overload in function 'mul': File: unknown: Line unknown.
    With argument(s): '(float64, list(int64)<iv=None>)':
   No match for registered cases:
    * (int64, int64) -> int64
    * (int64, uint64) -> int64
    * (uint64, int64) -> int64
    * (uint64, uint64) -> uint64
    * (float32, float32) -> float32
    * (float64, float64) -> float64
    * (complex64, complex64) -> complex64
    * (complex128, complex128) -> complex128

During: typing of intrinsic-call at /disk/disk2/youngjin/workspace/workspace/DS/Inference/MCMC/Swing/swing.py (36)

File "swing.py", line 36:
def RK4(func, t_end, X0, dt):
    <source elided>
        t2 = t[i] + hdt
        x2 = X[i] + hdt * k1
        ^

Do you find the reason and solution

  • Your fff function returns a list, and that of the wrong size. In the RK4 step you expect a type with vector arithmetic of the same dimension as the x state vectors. So these do not fit together even without numba. Try first to get everything running without numba, the error messages will be more stringent. (note that you have to read the error messages back-to-front.) – Lutz Lehmann Sep 22 '22 at 19:39
  • Thanks for finding my stupid things! I am testing my code with the exact function! – Young Jin Kim Sep 23 '22 at 02:05
  • @LutzLehmann My original code has a problem with JIT in class. So, I can understand why can not use a function in a class with the numba. Thanks for your help!! XD – Young Jin Kim Sep 23 '22 at 04:19
  • See also [this previous discussion](https://stackoverflow.com/questions/42838103/how-can-i-use-cython-well-to-solve-a-differential-equation-faster/42840933#42840933) on how to speed up RK4 or other integrators in python. – Lutz Lehmann Sep 23 '22 at 04:27
  • 1
    @LutzLehmann The numbalsoda is the best solution in python for boosting! ;) Thanks – Young Jin Kim Sep 23 '22 at 05:40

1 Answers1

0

The solution

in mycode.py
import numpy as np
from scipy import integrate
from typing import Union, List
import numba

def AdjMtoAdjL(adjM: np.ndarray) -> list:
    return [np.argwhere(adjM[:,i] > 0).flatten() for i in range(len(adjM))]
def AdjMtoEdgL(adjM: np.ndarray) -> np.ndarray:
    return np.argwhere(adjM > 0)

@numba.jit(nopython=True)
# def swing(t, y, model_param, model):
def swing(t, y, phi, m, gamma, P, K, model):
    if model == "swing":
        T, O = y
        T = np.array([T])
        O = np.array([O])
    else:
        T = y

    # Get Interaction
    Interaction = K*np.sin(T-phi)
    """
    \dot{\theta} &= \omega \\
    \dot{\omega} &= \frac{1}{m}(P-\gamma\omega+\Sigma K\sin(\theta-\phi))
    """
    if model == "swing":
        dT = O
        dO = 1/m*(P - gamma*O - Interaction)
        dydt = np.concatenate((dT, dO))#, dtype=np.float64)
    else:
        dydt = P + Interaction
    return dydt

@numba.jit(nopython=True)
def RK4(func, t_end, X0, dt, phi, m, gamma, P, K, model):
    t = np.arange(0,t_end, dt, dtype=np.float64)
    X  = np.zeros((t.shape[0], X0.shape[0]))
    X[0] = X0
    hdt = dt*.5
    for i in range(t.shape[0]-1):
        t1 = t[i]
        x1 = X[i]
        k1 = func(t[i], X[i], phi, m, gamma, P, K, model)
        
        t2 = t[i] + hdt
        x2 = X[i] + hdt * k1
        k2 = func(t2, x2, phi, m, gamma, P, K, model)
        
        t3 = t[i] + hdt
        x3 = X[i] + hdt * k2
        k3 = func(t3, x3, phi, m, gamma, P, K, model)
        
        t4 = t[i] + dt
        x4 = X[i] + dt * k3
        k4 = func(t4, x4, phi, m, gamma, P, K, model)
        X[i+1] = X[i] + dt / 6. * (k1 + 2. * k2 + 2. * k3 + k4)
    return X
maincode.ipynb
import networkx as nx
import os
import multiprocessing as mp
from multiprocessing import Pool
import time
import numpy as np
import swing

def multiprocess(Ngrid=101, t_end=30., omega_lim=30, dt=.001, n_cpu=19):
    start = int(time.time())
    
    T_range = np.linspace(0, 2*np.pi, Ngrid)
    O_range = np.linspace(-omega_lim, omega_lim, Ngrid)
    
    paramss = []
    for theta in T_range:
        for omega in O_range:
            y0 = np.hstack((
                theta,  # Theta
                omega,  # Omega
            ))
            params = {}
            params['sparam'] = Swing_Parameters
            params['t_end'] = t_end
            params['init'] = y0
            params['dt'] = dt
            paramss.append(params)
            del([[params]])

    p = Pool(processes=n_cpu)
    result = p.map(solve_func, paramss)
    
    end = int(time.time())
    print("***run time(sec) : ", end-start)
    print("Number of Core : " + str(n_cpu))
    return result

def solve_func(params):
    Swing_Parameters = params['sparam']
    t_end = params['t_end']
    y0 = params['init']
    dt = params['dt']
    
    # model = swing.SwingSingle(**Swing_Parameters)
    t_eval = np.arange(0,t_end, dt)
    # solution = integrate.solve_ivp(model, [0,t_end], y0, dense_output=False, 
                       # t_eval=t_eval, vectorized=True, method="LSODA")
    phi = Swing_Parameters["phi"]
    m = Swing_Parameters["m"]
    gamma = Swing_Parameters["gamma"]
    P = Swing_Parameters["P"]
    K = Swing_Parameters["K"]
    _model = Swing_Parameters["model"]

    solution = swing.RK4(swing.swing, t_end, y0, dt, phi, m, gamma, P, K, _model)
    return solution
    
Ngrid = 301
t_end = 24.
omega_lim = 30
dt = .05

Ngrid = 301
t_end = 24.
omega_lim = 30
dt = .05

Swing_Parameters = {
    "phi": np.pi,
    "m": 1.,
    "gamma": 0.3,
    "P": 2.,
    "K": 8.,
    "model": "swing"
}

model = swing.SwingSingle(**Swing_Parameters)

res = multiprocess(Ngrid=Ngrid, t_end=t_end, omega_lim=omega_lim, dt=dt, n_cpu=19)
  • As it’s currently written, your answer is unclear. Please [edit] to add additional details that will help others understand how this addresses the question asked. You can find more information on how to write good answers [in the help center](/help/how-to-answer). – Community Sep 28 '22 at 02:35