I'm learning Numba, and as an introductory exercise I wrote a simple orbit solver:
import numba as nb
import numpy as np
from timeit import default_timer as timer
# Numba jitclass field layout for CelestialBody.
# Scalar float64 state comes first: initial position/velocity, mass, and the
# current acceleration. It is followed by the 1-D float64 trajectory arrays.
spec = [
    (field, nb.types.float64)
    for field in ('x0', 'y0', 'vx0', 'vy0', 'mass', 'ax', 'ay')
] + [
    (field, nb.types.float64[:])
    for field in ('x', 'y', 'vx', 'vy')
]
# Numba jitclass holding one body's state; field types come from `spec`.
# Scalars: initial position (x0, y0), initial velocity (vx0, vy0), mass, and
# the acceleration (ax, ay) at the current integration step.
# The trajectory arrays x, y, vx and vy are NOT assigned here; `orbit`
# allocates and fills them, so they should not be read before `orbit` runs.
# NOTE(review): because this is a jitclass type, a jitted function taking
# instances of it failed to cache with cache=True in the reported setup
# (PicklingError on CelestialBody / KeyError in numba's caching.py).
@nb.jitclass(spec)
class CelestialBody():
    def __init__(self, x0, y0, vx0, vy0, mass):
        """Store the initial conditions and mass, and zero the acceleration."""
        self.x0 = x0
        self.y0 = y0
        self.vx0 = vx0
        self.vy0 = vy0
        self.mass = mass
        # Acceleration is recomputed by `orbit` at every step; start at zero.
        self.ax = 0.0
        self.ay = 0.0
@nb.jit(nopython=True)
def orbit(bodies, delta_t, nsteps):
    """Integrate the mutual gravitational motion of ``bodies`` in 2-D.

    Each body's ``x``, ``y``, ``vx`` and ``vy`` arrays are (re)allocated with
    length ``nsteps`` and filled with its trajectory. Index 0 holds the
    body's ``x0``/``y0``/``vx0``/``vy0`` initial conditions. The acceleration
    from body k on body j has magnitude ``mass_k / d**2``, with no explicit
    gravitational constant (units presumably chosen so that G = 1).

    Parameters
    ----------
    bodies : list of CelestialBody
        Bodies to integrate. They are mutated in place.
    delta_t : float
        Time step.
    nsteps : int
        Number of samples stored per body, including the initial one.

    Returns
    -------
    list of CelestialBody
        The same ``bodies`` list, returned for convenience.

    Raises
    ------
    ValueError
        If ``nsteps`` < 1, since there would be no slot for the initial state.

    Notes
    -----
    ``cache=True`` is intentionally not used. Caching a function whose
    signature contains a jitclass type fails with this Numba version
    (``PicklingError: Can't pickle <class '__main__.CelestialBody'>`` and a
    ``KeyError`` in numba/caching.py).
    """
    if nsteps < 1:
        raise ValueError("nsteps must be at least 1")

    # Allocate trajectory storage and seed it with the initial conditions.
    for body in bodies:
        body.x = np.zeros(nsteps, dtype=np.float64)
        body.y = np.zeros(nsteps, dtype=np.float64)
        body.vx = np.zeros(nsteps, dtype=np.float64)
        body.vy = np.zeros(nsteps, dtype=np.float64)
        body.x[0] = body.x0
        body.y[0] = body.y0
        body.vx[0] = body.vx0
        body.vy[0] = body.vy0

    nbodies = len(bodies)
    dt2 = delta_t * delta_t
    # Iteration i computes the state at index i + 1 from the state at index i.
    # Index 0 is already populated from the initial conditions.
    for i in range(nsteps - 1):
        # Pass 1: accelerations at time i for every body. Only positions at
        # index i are read, so the order in which bodies are visited does
        # not matter.
        for j in range(nbodies):
            body = bodies[j]
            ax = 0.0
            ay = 0.0
            for k in range(nbodies):
                if k == j:
                    continue
                other = bodies[k]
                dx = body.x[i] - other.x[i]
                dy = body.y[i] - other.y[i]
                d = np.sqrt(dx * dx + dy * dy)
                # Magnitude m / d**2 along the unit vector -(dx, dy) / d.
                # This equals a*cos(theta), a*sin(theta) with
                # theta = arctan2(dy, dx), without three trig calls per pair.
                scale = -other.mass / (d * d * d)
                ax += scale * dx
                ay += scale * dy
            body.ax = ax
            body.ay = ay

        # Pass 2: advance velocity and position with constant-acceleration
        # kinematics over one step, using the velocity at time i.
        for body in bodies:
            body.vx[i + 1] = body.vx[i] + body.ax * delta_t
            body.vy[i + 1] = body.vy[i] + body.ay * delta_t
            body.x[i + 1] = body.x[i] + body.vx[i] * delta_t + 0.5 * body.ax * dt2
            body.y[i + 1] = body.y[i] + body.vy[i] * delta_t + 0.5 * body.ay * dt2
    return bodies
# Benchmark driver: build a two-body Sun/Earth system and time `orbit` ten
# times. `orbit` is jitted lazily, so the first call also includes Numba's
# compilation time and is expected to be much slower than the rest.
for i in range(10):
    # Set up celestial bodies (fresh instances each run).
    sun = CelestialBody(0., 0., 0., 0., 1.)
    earth = CelestialBody(1., 0., 0., 6.33, 3.00e-6)
    bodies = [sun, earth]
    # Set up time info: 365 steps spanning tf.
    tf = 100.
    delta_t = tf / 365.
    # round() rather than bare int(): tf / delta_t is only approximately an
    # integer in floating point, and truncating a value such as
    # 364.99999999999994 would silently drop a step.
    nsteps = int(round(tf / delta_t))
    # Orbit
    start = timer()
    bodies = orbit(bodies, delta_t, nsteps)
    end = timer()
    print('Time to run: %f' % (end - start))
The code works without Numba. When I add Numba, I can jit both my class and my function, and it runs fine with a good speedup. However, when I try to cache the jitted function using cache=True, I get a KeyError:
File "/usr/local/lib/python3.6/dist-packages/numba/caching.py", line 482, in save
data_name = overloads[key]
KeyError: ((reflected list(instance.jitclass.CelestialBody#2cef1b8<x0:float64,
y0:float64,vx0:float64,vy0:float64,mass:float64,ax:float64,ay:float64,
x:array(float64, 1d, A),y:array(float64, 1d, A),vx:array(float64, 1d, A),
vy:array(float64, 1d, A)>), float64, int64), ('x86_64-unknown-linux-gnu',
'skylake', '+adx,+aes,+avx,+avx2,-avx512bitalg,-avx512bw,-avx512cd,-avx512dq,
-avx512er,-avx512f,-avx512ifma,-avx512pf,-avx512vbmi,-avx512vbmi2,-avx512vl,
-avx512vnni,-avx512vpopcntdq,+bmi,+bmi2,-cldemote,+clflushopt,-clwb,-clzero,+cmov,
+cx16,+f16c,+fma,-fma4,+fsgsbase,-gfni,+invpcid,-lwp,+lzcnt,+mmx,+movbe,-movdir64b,
-movdiri,-mwaitx,+pclmul,-pconfig,-pku,+popcnt,-prefetchwt1,+prfchw,-ptwrite,
-rdpid,+rdrnd,+rdseed,-rtm,+sahf,+sgx,-sha,-shstk,+sse,+sse2,+sse3,+sse4.1,
+sse4.2,-sse4a,+ssse3,-tbm,-vaes,-vpclmulqdq,-waitpkg,-wbnoinvd,-xop,+xsave,
+xsavec,+xsaveopt,+xsaves'))
I realize most of the above is the target CPU's feature-flag list and probably unnecessary, but I wasn't sure, so I included it.
There's also a pickle error:
_pickle.PicklingError: Can't pickle <class '__main__.CelestialBody'>: it's not the same object as __main__.CelestialBody
I've looked at a related question about this error, but as far as I can tell there's no import error, and I haven't modified any of the modules I'm importing. I'm also not running in a Jupyter notebook, just a terminal. My guess is that it has something to do with the class "signature" before and after it's compiled: pickle finds a different object under the name `__main__.CelestialBody` than the one it was given. Caching does work when no class is involved.
I'm using Python 3.6.7, NumPy 1.15.4, and Numba 0.42.1.
So, my question is what is causing this pickle error that is preventing caching? Thank you!