
I have this piece of code, which uses Numba to speed up processing. Basically, particle_dtype is a structured dtype defined so that the code can run under Numba. However, a TypingError is reported, saying "Cannot determine Numba type of <class 'function'>". I cannot figure out where the problem is.

import numpy
from numba import njit

particle_dtype = numpy.dtype({'names':['x','y','z','m','phi'], 
                             'formats':[numpy.double, 
                                        numpy.double, 
                                        numpy.double, 
                                        numpy.double, 
                                        numpy.double]}) 


def create_n_random_particles(n, m, domain=1):
    parts = numpy.zeros((n), dtype=particle_dtype)
    parts['x'] = numpy.random.random(size=n) * domain
    parts['y'] = numpy.random.random(size=n) * domain
    parts['z'] = numpy.random.random(size=n) * domain
    parts['m'] = m
    parts['phi'] = 0.0

    return parts


def distance(se, other):
    return numpy.sqrt(numpy.square(se['x'] - other['x']) + 
                      numpy.square(se['y'] - other['y']) + 
                      numpy.square(se['z'] - other['z']))


parts = create_n_random_particles(10, .001, 1)


@njit
def direct_sum(particles):
    for i, target in enumerate(particles):
        for j in range(particles.shape[0]):
            if i == j:
                continue
            source = particles[j]
            r = distance(target, source)
            # target['phi'] += source['m'] / r
            target['phi'] = target['phi'] + source['m'] / r
            return(target['phi'])
            
print(direct_sum(parts) ) 

I guess it's because unsupported functions or operations are used somewhere, but I cannot find where. Thanks for your help.

Harry


`direct_sum`, which is a JIT-compiled function, cannot call `distance` because `distance` is not JIT-compiled (it is a pure-Python function). You need to decorate `distance` with `@njit` too.

Jérôme Richard
  • Thanks, it fixes my issue. I'm learning Numba and still have some questions. 1) Is using "@njit" on a function the only way to make it JIT-compiled? I used numpy.sqrt and numpy.square in "distance", and I thought it could be called as it was. 2) If other functions are called in "direct_sum", should "@njit" be used on all of them? 3) I thought "+=" was not supported by Numba, which is why it is commented out in "direct_sum". However, it seems to work fine if I use "+=". Thanks. – Harry Mar 09 '22 at 16:17
  • 1) Not really, but for Numba at least there are `njit` and `jit` (I advise you to use `njit`, or `jit` with the `nopython=True` flag). Note that NumPy functions are reimplemented in Numba, so you do not need to explicitly request them to be JIT-compiled. 2) Yes, at least for the functions defined in your code. 3) `+=` should be well supported. Besides this, I am not sure your data structure is stored efficiently in memory. Please consider reading [this post](https://stackoverflow.com/a/71102905/12939557). – Jérôme Richard Mar 09 '22 at 20:28