I'd rather not use quiver because it doesn't handle float128 dtypes properly for its input arguments X, Y, Z, U, V, and W. In fact, it silently converts these inputs to float, which is usually float64 on our systems. As a result, float128 inputs that fall outside the float64 range overflow!
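To illustrate the kind of overflow I mean, here is a minimal sketch that uses plain NumPy rather than quiver itself. The helper name overflows_as_float64 is made up for this example, and it only shows an effect where np.longdouble (what float128 maps to) is actually wider than float64; on some platforms, such as Windows builds of NumPy, it is just float64.

```python
import numpy as np


def overflows_as_float64(values):
    """Mask of finite extended-precision values that become inf as float64.

    Hypothetical helper for illustration; it is not part of Matplotlib.
    """
    values = np.asarray(values, dtype=np.longdouble)
    # Narrowing cast, similar in spirit to a silent conversion to float.
    with np.errstate(over="ignore"):
        narrowed = values.astype(np.float64)
    return np.isfinite(values) & ~np.isfinite(narrowed)


# Largest finite value the extended type can hold on this platform.
big = np.full(1, np.finfo(np.longdouble).max, dtype=np.longdouble)
print(overflows_as_float64(big))
```

If the printed mask is True, any coordinate of that size would turn into inf once it is narrowed to float64.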
Instead, I'd like to use CT Zhu's brief class Arrow3D in this wonderful answer. It works flawlessly with float128 coordinates and provides various arrow styles.
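The class itself isn't reproduced in this answer, so here is a rough sketch of what such a class usually looks like. This is my assumption of its general shape, not CT Zhu's exact code: it subclasses FancyArrowPatch and projects the 3D endpoints to 2D at draw time through do_3d_projection, which is the hook recent Matplotlib versions call on patches in 3D axes. Saving it as Arrow3D.py would match the import used further down.

```python
import numpy as np
from matplotlib.patches import FancyArrowPatch
from mpl_toolkits.mplot3d import proj3d


class Arrow3D(FancyArrowPatch):
    """A 2D arrow patch whose endpoints are given in 3D data coordinates."""

    def __init__(self, xs, ys, zs, *args, **kwargs):
        # Placeholder 2D positions; the real ones are set at draw time.
        super().__init__((0, 0), (0, 0), *args, **kwargs)
        self._verts3d = xs, ys, zs

    def do_3d_projection(self, renderer=None):
        # Project both endpoints with the axes' current projection matrix.
        xs3d, ys3d, zs3d = self._verts3d
        xs, ys, zs = proj3d.proj_transform(xs3d, ys3d, zs3d, self.axes.M)
        self.set_positions((xs[0], ys[0]), (xs[1], ys[1]))
        # The return value is used by the 3D axes for depth sorting.
        return np.min(zs)
```

Any extra keyword arguments (mutation_scale, arrowstyle, color, and so on) are passed straight through to FancyArrowPatch, which is where the various arrow styles come from.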
Building on that class, I wrote this function to draw the x, y, and z axes through the center of the plot:
import numpy as np
import matplotlib.pyplot as plt
from Arrow3D import Arrow3D

def draw_xyz_axes_at_center(mpl_ax):
    # Compute max_lim based on plotted data
    x_lim = abs(max(mpl_ax.get_xlim(), key=abs))
    y_lim = abs(max(mpl_ax.get_ylim(), key=abs))
    z_lim = abs(max(mpl_ax.get_zlim(), key=abs))
    max_lim = max(x_lim, y_lim, z_lim)
    # Position xyz axes at the center
    mpl_ax.set_xlim(-max_lim, max_lim)
    mpl_ax.set_ylim(-max_lim, max_lim)
    mpl_ax.set_zlim(-max_lim, max_lim)
    # Draw xyz axes
    axes = ['x', 'y', 'z']
    for i, axis in enumerate(axes):
        start_end_pts = np.zeros((3, 2))
        start_end_pts[i] = [-max_lim, max_lim]
        # Draw axis
        xs, ys, zs = start_end_pts[0], start_end_pts[1], start_end_pts[2]
        a = Arrow3D(xs, ys, zs,
                    mutation_scale=20, arrowstyle='-|>', color='black')
        mpl_ax.add_artist(a)
        # Add label
        end_pt_with_padding = start_end_pts[:, 1] * 1.1
        mpl_ax.text(*end_pt_with_padding,
                    axis,
                    horizontalalignment='center',
                    verticalalignment='center',
                    color='black')
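The limit computation in that function is easier to follow in isolation. The sketch below pulls it out into a standalone helper (symmetric_limit is a name I made up for illustration): for each (low, high) pair it picks the bound that is farthest from zero, then takes the largest of those, so that (-max_lim, max_lim) covers everything on all three axes.

```python
def symmetric_limit(*axis_limits):
    """Return the largest absolute bound among (low, high) limit pairs."""
    # max(..., key=abs) picks the bound farthest from zero; abs() drops its sign.
    return max(abs(max(limits, key=abs)) for limits in axis_limits)


print(symmetric_limit((-1, 4), (-6, 2), (0, 3)))  # 6
```

Using the same limit on every axis is what puts the origin, and therefore the drawn axes, in the middle of the plot.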
To draw a vector:
def draw_vector(mpl_ax, v):
    xs = [0, v[0]]
    ys = [0, v[1]]
    zs = [0, v[2]]
    a = Arrow3D(xs, ys, zs,
                mutation_scale=20, arrowstyle='->', color='#1f77b4')
    mpl_ax.add_artist(a)
    # Axes limits automatically include the coordinates of all plotted data
    # but not Arrow3D artists. That's actually why this point is plotted.
    mpl_ax.plot(*v, '.', color='#1f77b4')
Let's use them:
ax = plt.figure(figsize=(7, 7)).add_subplot(projection='3d')
draw_vector(ax, np.array([2, 3, 5]))
draw_xyz_axes_at_center(ax)
ax.set_xlabel('x axis')
ax.set_ylabel('y axis')
ax.set_zlabel('z axis')
plt.show()
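Output: [image: 3D plot of the vector (2, 3, 5) drawn as a blue arrow, with black x, y, and z axes through the center]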
I used Python 3 by the way and didn't test it on Python 2.