
I have a piecewise function with three parts that I'm trying to write in Python using Numba's @njit decorator. The function is calculated over an array. The function is defined by:

@njit(parallel=True)
def f(x_vec):
    N=len(x_vec)
    y_vec=np.zeros(N)
    for i in prange(N):
        x=x_vec[i]
        if x<=2000:
            y=64/x
        elif x>=4000:
            y=np.log(x)
        else:
            y=np.log(1.2*x)
        y_vec[i]=y
    return y_vec

I'm using Numba to make this code very fast and run it on all 8 threads of my CPU.

Now, my question is: if I wanted to define each part of the function separately as f1, f2 and f3, and call those inside the if statements (while still benefiting from Numba's speed), how could I do that? The reason is that the subfunctions can be more complicated, and I don't want to make my code hard to read. I want it to be as fast as this version (or only slightly slower, but not by a lot).

In order to test the function, we can use this array:

Np=10000000
x_vec=100*np.power(1e8/100,np.random.rand(Np))
%timeit f(x_vec)  # 0.06 s on an Intel Core i7 3610

For completeness, the following libraries are imported:

import numpy as np
from numba import njit, prange

So in this case, the functions would be:

def f1(x):
    return 64/x
def f2(x):
    return np.log(x)
def f3(x):
    return np.log(1.2*x)

The actual functions are the following, which compute the smooth-pipe friction factor in the laminar, transition and turbulent regimes:

@njit
def f1(x):
    return 64/x

@njit
def f2(x):
    #x is the Reynolds number(Re), y is the Darcy friction(f)
    #for transition, we can assume Re=4000 (max possible friction)
    y=0.02
    y=(-2/np.log(10))*np.log(2.51/(4000*np.sqrt(y)))
    return 1/(y*y)

@njit
def f3(x): #colebrook-white approximation
    #x is the Reynolds number(Re), y is the Darcy friction(f)
    y=0.02
    y=(-2/np.log(10))*np.log(2.51/(x*np.sqrt(y)))
    return 1/(y*y)

Thanks for contributions from everyone. This is the numpy solution (the last three lines are slow for some reason, but it doesn't need any warmup):

y = np.empty_like(x_vec)

a1 = x_vec <= 2000
a3 = x_vec >= 4000
a2 = ~(a1 | a3)

y[a1] = f1(x_vec[a1])
y[a2] = f2(x_vec[a2])
y[a3] = f3(x_vec[a3])

The fastest Numba solution, which allows passing the function names and takes advantage of prange (but is hindered by the jit warmup), is this; it can be as fast as the first solution at the top of the question:

@njit(parallel=True)
def f(x_vec,f1,f2,f3):
    N = len(x_vec)
    y_vec = np.zeros(N)
    for i in prange(N):
        x=x_vec[i]
        if x<=2000:
            y=f1(x)
        elif x>=4000:
            y=f3(x)
        else:
            y=f2(x)
        y_vec[i]=y
    return y_vec
dani
  • If your subfunctions are (and can be) also njitted, they will still be fast. You might also want to use `numba.prange` instead of `range`. – Jan Christoph Terasa Oct 28 '20 at 16:11
  • To be able to answer this, we might have to see your intended subfunctions `f1`, `f2` and `f3` as well. – Jan Christoph Terasa Oct 28 '20 at 16:37
  • @JanChristophTerasa the prange already made it a lot faster. The f1, f2 and f3 are the same as the ones defined in the question. I added the subfunctions to the question. (I can post the actual functions but they can be long but they are basically made of a bunch of multiplications and log calls). – dani Oct 28 '20 at 16:42
  • You can probably get a tad faster if you use `math.log()` in Numba for scalar arguments. – norok2 Oct 28 '20 at 19:58
  • @norok2 jit makes it ridiculously fast but I just realized the 1-second warmup might not be worth it since I'm only doing 10 iterations. I might have to go back to pure numpy. Anyway, in jit I'll change to math.log and see what happens. Thank you – dani Oct 28 '20 at 20:00
  • Also consider the options inline='always' to always inline these small functions, and error_model='numpy' to disable division-by-zero checking, which also has some overhead. Additionally, initializing `y_vec` with zeros before overwriting every entry does not make much sense; it is enough to just allocate memory with `np.empty`. – max9111 Oct 29 '20 at 16:02

2 Answers


Is this too slow? This can be done in pure numpy, by avoiding loops and using masks for indexing:

def f(x):
    y = np.empty_like(x)
    
    mask = x <= 2000
    y[mask] = 64 / x[mask]
    
    mask = (x > 2000) & (x < 4000)
    y[mask] = np.log(1.2 * x[mask])
    
    mask = x >= 4000
    y[mask] = np.log(x[mask])

    return y

You can also handle the "else" case by first applying the middle expression to the whole array without any mask, then overwriting the two outer regions; it's probably a bit slower:

def f_else(x):
    y = np.log(1.2 * x)
    
    mask = x <= 2000
    y[mask] = 64 / x[mask]
    
    mask = x >= 4000
    y[mask] = np.log(x[mask])

    return y

With

Np=10000000
x_vec=100*np.power(1e8/100,np.random.rand(Np))

I get (laptop with an i7-8850H, 6 cores + 6 hyper-threads)

f: 1 loop, best of 5: 294 ms per loop
f_else: 1 loop, best of 5: 400 ms per loop

If your intended subfunctions are mainly numpy-operations this will still be fast.
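Following up on the masking approach, NumPy ufuncs also accept `out=` and `where=` arguments, which write results only where a mask is True and avoid materializing the `x[mask]` temporaries. A sketch (the names and structure here are mine, not part of the original answer):

```python
import numpy as np

def f_where(x):
    y = np.empty_like(x)
    m1 = x <= 2000
    m3 = x >= 4000
    m2 = ~(m1 | m3)
    # Each ufunc writes only where its mask is True; the other entries of y
    # are left untouched, so together the three calls fill the whole array.
    np.divide(64.0, x, out=y, where=m1)
    np.log(x, out=y, where=m3)
    np.log(1.2 * x, out=y, where=m2)  # the multiply still runs on the full array
    return y
```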

Jan Christoph Terasa
  • Thanks Jan, that's very helpful, actually a question I had in mind for a long time was answered by your answer, that with mask I can actually apply the function to a part of my array and totally avoid the if conditions (this is fancy indexing, right?). This is still slower than the jit implementation, but I might find a use for it especially if I try it with f1, f2 and f3 with jit. – dani Oct 28 '20 at 16:48
  • You can use the `where` argument to ufuncs instead of boolean indexing for more speed. Also, define mask1 and mask2 for two of the conditions; I'm pretty sure `~(mask1 | mask2)` is a faster way to get the else mask. Also can be done with a single buffer. – Mad Physicist Oct 28 '20 at 16:50
  • @dani. Boolean indexing. Fancy indexing is integer indices. – Mad Physicist Oct 28 '20 at 16:51
  • You can probably get a faster `f_else` by using the fastest computation instead (i.e. `64 / x`) as the base case. – norok2 Oct 28 '20 at 17:00
  • @MadPhysicist Thanks for the `where` tip, very helpful. @norok Probably. I didn't particularly optimize for speed; I just wanted to demonstrate that usually you won't need numba for problems which can be formulated in pure numpy. This is especially true if the "warm-up" time of the JIT dominates your run-time. – Jan Christoph Terasa Oct 28 '20 at 17:48
  • @JanChristophTerasa the JIT warmup is an absolute nightmare but since I have about 10 iterations at least, it's worth it. Thanks for the answer. – dani Oct 28 '20 at 18:33
  • @JanChristophTerasa is there a way to prewarmup the jit? Like compiling the code or something :))) – dani Oct 28 '20 at 20:07
  • @dani I think there is the [Ahead-of-Time compilation](https://numba.readthedocs.io/en/stable/reference/aot-compilation.html#aot-compilation) for that, but I've never used it. – Alexis Cllmb Sep 08 '22 at 17:45

You can write f() to accept function parameters, e.g.:

@njit
def f(arr, f1, f2, f3):
    N = len(arr)
    y_vec = np.zeros(N)
    for i in range(N):
        x = arr[i]
        if x <= 2000:
            y = f1(x)
        elif x >= 4000:
            y = f2(x)
        else:
            y = f3(x)
        y_vec[i] = y
    return y_vec

Make sure that the functions you pass are Numba-compatible (e.g. themselves decorated with `@njit`).

norok2
  • Also, you probably want to parametrize the thresholds as well. – norok2 Oct 28 '20 at 16:54
  • Thanks, it also worked without passing the function names. But I imagine passing the names would give me more flexibility later on as I will be able to pass different functions to it. – dani Oct 28 '20 at 17:23