0

I'm currently trying to build an N-body simulation but I'm having a little trouble with plotting the results the way I'd like.

In the code below (with some example data for a few points in an orbit) I'm importing the position and time data and organizing it into a pandas dataframe. To create the 3D animation I use matplotlib's animation class, which works perfectly.

However, the usual way to set up an animation is limited in that you can't customize the points in each frame individually (please let me know if I'm wrong here :p). Since my animation is showing orbiting bodies I would like to vary their sizes and colors. To do that I essentially create a graph for each body and set it's color etc. When it gets to the update_graph function, I iterate over the n bodies, retrieve their individual (x,y,z) coordinates, and update their graphs.

import numpy as np
import matplotlib.pyplot as plt
from matplotlib import cm
from mpl_toolkits.mplot3d.axes3d import get_test_data
from mpl_toolkits.mplot3d import Axes3D
import matplotlib.animation as animation
import pandas as pd

nbodies = 2

x = np.array([[1.50000000e-10, 0.00000000e+00, 0.00000000e+00],
              [9.99950000e-01, 1.00000000e-02, 0.00000000e+00],
              [4.28093585e-06, 3.22964816e-06, 0.00000000e+00],
              [-4.16142210e-01, 9.09335149e-01, 0.00000000e+00],
              [5.10376489e-06, 1.42204430e-05, 0.00000000e+00],
              [-6.53770813e-01, -7.56722445e-01, 0.00000000e+00]])
t = np.array([0.01, 0.01, 2.0, 2.0, 4.0, 4.0])
tt = np.array([0.01, 2.0, 4.0])

x = x.reshape((len(tt), nbodies, 3))

x_coords = x[:, :, 0].flatten()
y_coords = x[:, :, 1].flatten()
z_coords = x[:, :, 2].flatten()

df = pd.DataFrame({"time": t[:] ,"x" : x_coords, "y" : y_coords, "z" : z_coords})
print(df)

def update_graph(num):
    data=df[df['time']==tt[num]] # x,y,z of all bodies at current time
    for n in range(nbodies): # update graphs
        data_n = data[data['x']==x_coords[int(num * nbodies) + n]] # x,y,z of body n
        graph = graphs[n]
        graph.set_data(data_n.x, data_n.y)
        graph.set_3d_properties(data_n.z)
        graphs[n] = graph
    return graphs

plt.style.use('dark_background')
fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
ax.set_xlabel('x (AU)')
ax.set_ylabel('y (AU)')
ax.set_zlabel('z (AU)')

plt.xlim(-1.5,1.5)
plt.ylim(-1.5,1.5)

# initialize
data=df[df['time']==0]
ms_list = [5, 1]
c_list = ['yellow', 'blue']
graphs = []
for n in range(nbodies):
    graphs.append(ax.plot([], [], [], linestyle="", marker=".", 
                            markersize=ms_list[n], color=c_list[n])[0])

ani = animation.FuncAnimation(fig, update_graph, len(tt), 
                               interval=400, blit=True, repeat=True)

plt.show()

However, doing this gives me the following error:

Traceback (most recent call last):
  File "/home/kris/anaconda3/lib/python3.7/site-packages/matplotlib/backend_bases.py", line 1194, in _on_timer
    ret = func(*args, **kwargs)
  File "/home/kris/anaconda3/lib/python3.7/site-packages/matplotlib/animation.py", line 1447, in _step
    still_going = Animation._step(self, *args)
  File "/home/kris/anaconda3/lib/python3.7/site-packages/matplotlib/animation.py", line 1173, in _step
    self._draw_next_frame(framedata, self._blit)
  File "/home/kris/anaconda3/lib/python3.7/site-packages/matplotlib/animation.py", line 1193, in _draw_next_frame
    self._post_draw(framedata, blit)
  File "/home/kris/anaconda3/lib/python3.7/site-packages/matplotlib/animation.py", line 1216, in _post_draw
    self._blit_draw(self._drawn_artists, self._blit_cache)
  File "/home/kris/anaconda3/lib/python3.7/site-packages/matplotlib/animation.py", line 1231, in _blit_draw
    a.axes.draw_artist(a)
  File "/home/kris/anaconda3/lib/python3.7/site-packages/matplotlib/axes/_base.py", line 2661, in draw_artist
    a.draw(self.figure._cachedRenderer)
  File "/home/kris/anaconda3/lib/python3.7/site-packages/matplotlib/artist.py", line 38, in draw_wrapper
    return draw(artist, renderer, *args, **kwargs)
  File "/home/kris/anaconda3/lib/python3.7/site-packages/mpl_toolkits/mplot3d/art3d.py", line 202, in draw
    xs, ys, zs = proj3d.proj_transform(xs3d, ys3d, zs3d, renderer.M)
  File "/home/kris/anaconda3/lib/python3.7/site-packages/mpl_toolkits/mplot3d/proj3d.py", line 201, in proj_transform
    vec = _vec_pad_ones(xs, ys, zs)
  File "/home/kris/anaconda3/lib/python3.7/site-packages/mpl_toolkits/mplot3d/proj3d.py", line 189, in _vec_pad_ones
    return np.array([xs, ys, zs, np.ones_like(xs)])
  File "/home/kris/anaconda3/lib/python3.7/site-packages/pandas/core/series.py", line 871, in __getitem__
    result = self.index.get_value(self, key)
  File "/home/kris/anaconda3/lib/python3.7/site-packages/pandas/core/indexes/base.py", line 4405, in get_value
    return self._engine.get_value(s, k, tz=getattr(series.dtype, "tz", None))
  File "pandas/_libs/index.pyx", line 80, in pandas._libs.index.IndexEngine.get_value
  File "pandas/_libs/index.pyx", line 90, in pandas._libs.index.IndexEngine.get_value
  File "pandas/_libs/index.pyx", line 138, in pandas._libs.index.IndexEngine.get_loc
  File "pandas/_libs/hashtable_class_helper.pxi", line 997, in pandas._libs.hashtable.Int64HashTable.get_item
  File "pandas/_libs/hashtable_class_helper.pxi", line 1004, in pandas._libs.hashtable.Int64HashTable.get_item
KeyError: 0
Aborted (core dumped)

I'm not sure what this really means, but I do know the problem is something to do with updating the graphs with only one row of coordinates rather than all three. Because if I instead have

def update_graph(num):
    data=df[df['time']==tt[num]] # x,y,z of all bodies at current time
    for n in range(nbodies): # update graphs
        #data_n = data[data['x']==x_coords[int(num * nbodies) + n]] # x,y,z of body n
        graph = graphs[n]
        graph.set_data(data.x, data.y)  # using data rather than data_n here now
        graph.set_3d_properties(data.z)
        graphs[n] = graph
    return graphs

it actually works, and plots three copies of the bodies with varying colors and sizes on top of each other as you would expect.

Any help would be much appreciated. Thanks!

Kris Walker
  • 163
  • 6
  • Welcome to Stack Overflow! Please take a moment to read [How do I ask a good question?](https://stackoverflow.com/help/how-to-ask). You need to provide a [Minimal, Complete, and Verifiable example](https://stackoverflow.com/help/mcve) that includes a toy dataset (refer to [How to make good reproducible pandas examples](https://stackoverflow.com/questions/20109391/how-to-make-good-reproducible-pandas-examples)) – Diziet Asahi Nov 09 '20 at 09:55
  • @DizietAsahi Thanks Diziet. I've updated it, how does that look? – Kris Walker Nov 09 '20 at 10:56
  • Your code seem to work fine with the data provided? – Diziet Asahi Nov 09 '20 at 11:05
  • @DizietAsahi Ah silly me, I forgot to re-enter the broken bit after testing. Sorry about that, it should be correct now. – Kris Walker Nov 09 '20 at 11:09

1 Answers1

1

I don't understand why you are going through a pandas DataFrame, when you seem to already have all the data you need in your numpy array. I couldn't reproduce the initial problem, by I propose this solution that uses pure numpy arrays, which may fix the problem:

import numpy as np
import matplotlib.pyplot as plt
from matplotlib import cm
from mpl_toolkits.mplot3d.axes3d import get_test_data
from mpl_toolkits.mplot3d import Axes3D
import matplotlib.animation as animation
import pandas as pd

nbodies = 2

x = np.array([[1.50000000e-10, 0.00000000e+00, 0.00000000e+00],
              [9.99950000e-01, 1.00000000e-02, 0.00000000e+00],
              [4.28093585e-06, 3.22964816e-06, 0.00000000e+00],
              [-4.16142210e-01, 9.09335149e-01, 0.00000000e+00],
              [5.10376489e-06, 1.42204430e-05, 0.00000000e+00],
              [-6.53770813e-01, -7.56722445e-01, 0.00000000e+00]])
t = np.array([0.01, 0.01, 2.0, 2.0, 4.0, 4.0])
tt = np.array([0.01, 2.0, 4.0])
x = x.reshape((len(tt), nbodies, 3))


def update_graph(i):
    data = x[i, :, :]  # x,y,z of all bodies at current time
    for body, graph in zip(data, graphs):  # update graphs
        graph.set_data(body[0], body[1])
        graph.set_3d_properties(body[2])
    return graphs


plt.style.use('dark_background')
fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
ax.set_xlabel('x (AU)')
ax.set_ylabel('y (AU)')
ax.set_zlabel('z (AU)')

plt.xlim(-1.5, 1.5)
plt.ylim(-1.5, 1.5)

# initialize
ms_list = [50, 10]
c_list = ['yellow', 'blue']
graphs = []
for n in range(nbodies):
    graphs.append(ax.plot([], [], [], linestyle="", marker=".",
                          markersize=ms_list[n], color=c_list[n])[0])

ani = animation.FuncAnimation(fig, func=update_graph, frames=len(tt),
                              interval=400, blit=True, repeat=True)

plt.show()

enter image description here

Diziet Asahi
  • 38,379
  • 7
  • 60
  • 75