Ahoy!

I am trying to plot vectors in 3D using the matplotlib quiver function. To help visualize them, I also would like to plot orthogonal axes centered at the origin.

Ideally, I would like to move the so-called spines, but according to this SO post, there is no easy fix for that.

I ended up plotting the axes as three vectors along x, y & z (see my code below), but I can't help but think this is a terrible solution... Any input will be greatly appreciated.

Here's the code:

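import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D  # registers the '3d' projection (only needed on Matplotlib < 3.2)

# Unit vectors along x, y and z, used as makeshift axes.
# Each row is (x0, y0, z0, u, v, w): the arrow's tail followed by its direction.
soa = np.array([[0, 0, 0, 1, 0, 0], [0, 0, 0, 0, 1, 0], [0, 0, 0, 0, 0, 1]])
X, Y, Z, U, V, W = zip(*soa)

# The vector to visualize
soa2 = np.array([[0, 0, 0, np.sqrt(2)/2, np.sqrt(2)/2, np.sqrt(2)/2]])
I, J, K, F, G, H = zip(*soa2)

# One figure holding a single 3D axes
fig = plt.figure()
ax = fig.add_subplot(projection='3d')
ax.quiver(X, Y, Z, U, V, W, color='black')
ax.quiver(I, J, K, F, G, H)
ax.set_xlim([-1, 1])
ax.set_ylim([-1, 1])
ax.set_zlim([-1, 1])
plt.show()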

And here's the image returned by this script:

[image: plot produced by the script above]

Mad Physicist
Sheldon

1 Answer

I'd rather not use quiver because it doesn't handle float128 dtypes properly for its input arguments X, Y, Z, U, V, and W: it silently converts them to float, which is float64 on most systems. As a result, float128 values outside the float64 range overflow!
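To see the effect, the sketch below mimics that kind of silent conversion with np.asarray(..., dtype=float) rather than calling quiver itself. It assumes a platform where np.longdouble (exposed as float128 on x86-64 Linux) has a wider range than float64; on platforms where long double is plain double, such as Windows, the value is already inf before the cast.

```python
import numpy as np

# 10**400 is far beyond float64's maximum (about 1.8e308) but fits in x86
# extended precision; where long double is just double, it is inf already.
with np.errstate(over="ignore"):
    coords = np.array([np.longdouble(10) ** 400, 1.0], dtype=np.longdouble)
    # Silent coercion to the builtin float dtype (float64), as described above
    converted = np.asarray(coords, dtype=float)

print(coords.dtype, np.isfinite(coords[0]))     # finite wherever long double has the range
print(converted.dtype, np.isinf(converted[0]))  # the float64 copy has overflowed to inf
```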

Instead, I prefer CT Zhu's short Arrow3D class from this wonderful answer. It works flawlessly with float128 coordinates and provides various arrow styles.
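The functions below import Arrow3D from a local Arrow3D.py module whose contents aren't reproduced here. For reference, here is a minimal sketch of such a class, assuming Matplotlib 3.5 or later (where Axes3D calls do_3d_projection() on its patches at draw time and keeps its projection matrix in ax.M). It follows the common FancyArrowPatch-subclass pattern but is not necessarily identical to CT Zhu's original code.

```python
# Arrow3D.py: a sketch assuming Matplotlib >= 3.5, not necessarily CT Zhu's exact code
import numpy as np
from matplotlib.patches import FancyArrowPatch
from mpl_toolkits.mplot3d import proj3d


class Arrow3D(FancyArrowPatch):
    """A FancyArrowPatch whose two endpoints are given in 3D data coordinates."""

    def __init__(self, xs, ys, zs, **kwargs):
        # Dummy 2D endpoints for now; the real ones are computed at draw time.
        super().__init__((0, 0), (0, 0), **kwargs)
        self._verts3d = xs, ys, zs

    def do_3d_projection(self, renderer=None):
        # Project the 3D endpoints onto the 2D plane of the axes...
        xs3d, ys3d, zs3d = self._verts3d
        xs, ys, zs = proj3d.proj_transform(xs3d, ys3d, zs3d, self.axes.M)
        # ...and use the projected points as the arrow's 2D endpoints.
        self.set_positions((xs[0], ys[0]), (xs[1], ys[1]))
        # Axes3D uses the returned depth to decide drawing order.
        return np.min(zs)
```

Saved as Arrow3D.py next to the script, it satisfies `from Arrow3D import Arrow3D`, and keyword arguments such as mutation_scale, arrowstyle and color are passed straight through to FancyArrowPatch.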

With that class, I wrote this function to draw the x, y, and z axes through the center of the plot:

import numpy as np
import matplotlib.pyplot as plt

from Arrow3D import Arrow3D  # CT Zhu's Arrow3D class, saved locally as Arrow3D.py


def draw_xyz_axes_at_center(mpl_ax):

    # Compute max_lim based on plotted data
    x_lim = abs(max(mpl_ax.get_xlim(), key=abs))
    y_lim = abs(max(mpl_ax.get_ylim(), key=abs))
    z_lim = abs(max(mpl_ax.get_zlim(), key=abs))
    max_lim = max(x_lim, y_lim, z_lim)

    # Position xyz axes at the center
    mpl_ax.set_xlim(-max_lim, max_lim)
    mpl_ax.set_ylim(-max_lim, max_lim)
    mpl_ax.set_zlim(-max_lim, max_lim)

    # Draw xyz axes
    axes = ['x', 'y', 'z']
    for i, axis in enumerate(axes):
        start_end_pts = np.zeros((3, 2))
        start_end_pts[i] = [-max_lim, max_lim]

        # Draw axis
        xs, ys, zs = start_end_pts[0], start_end_pts[1], start_end_pts[2]

        a = Arrow3D(xs, ys, zs, 
                    mutation_scale=20, arrowstyle='-|>', color='black')
        mpl_ax.add_artist(a)

        # Add label
        end_pt_with_padding = start_end_pts[:, 1] * 1.1

        mpl_ax.text(*end_pt_with_padding,
                    axis,
                    horizontalalignment='center',
                    verticalalignment='center',
                    color='black')
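The limit logic at the top of draw_xyz_axes_at_center boils down to taking the largest absolute endpoint over all three axes and using it as one symmetric range. Isolated as a pure function (symmetric_limit is a hypothetical name, for illustration only) and fed made-up limits, it looks like this:

```python
def symmetric_limit(*limits):
    """Largest absolute endpoint among (low, high) limit pairs.

    Same result as max(abs(max(lim, key=abs)) for lim in limits), which is
    what draw_xyz_axes_at_center computes axis by axis.
    """
    return max(abs(v) for lim in limits for v in lim)


# Made-up limits, of the kind an autoscaled 3D plot might report
max_lim = symmetric_limit((-0.1, 2.1), (-0.15, 3.15), (-0.25, 5.25))
print(max_lim)               # 5.25
print((-max_lim, max_lim))   # (-5.25, 5.25), applied to x, y and z alike
```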

To draw a vector:

def draw_vector(mpl_ax, v):
    xs = [0, v[0]]
    ys = [0, v[1]]
    zs = [0, v[2]]

    a = Arrow3D(xs, ys, zs, 
                mutation_scale=20, arrowstyle='->', color='#1f77b4')
    mpl_ax.add_artist(a)

    # Axes limits automatically include the coordinates of all plotted data
    # but not Arrow3D artists. That's actually why this point is plotted.
    mpl_ax.plot(*v, '.', color='#1f77b4')
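The comment above relies on Axes3D.plot extending the autoscaled limits so they cover the plotted point, which is what lets draw_xyz_axes_at_center read sensible limits afterwards. A quick check of that behavior (the Agg backend is chosen only so the sketch runs without a display):

```python
import matplotlib
matplotlib.use("Agg")  # off-screen backend, only so this check runs headless
import matplotlib.pyplot as plt

fig = plt.figure()
ax = fig.add_subplot(projection="3d")

v = (2, 3, 5)
ax.plot(*v, ".")  # one marker at the tip of the vector, as draw_vector does

# After plotting, each autoscaled range contains the corresponding coordinate.
limits = (ax.get_xlim(), ax.get_ylim(), ax.get_zlim())
inside = [lo <= c <= hi for c, (lo, hi) in zip(v, limits)]
print(inside)
```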

Let's use them:

ax = plt.figure(figsize=(7, 7)).add_subplot(projection='3d')

draw_vector(ax, np.array([2, 3, 5]))

draw_xyz_axes_at_center(ax)

ax.set_xlabel('x axis')
ax.set_ylabel('y axis')
ax.set_zlabel('z axis')

plt.show()

Output: [image: resulting plot]

I used Python 3 by the way and didn't test it on Python 2.

Shahrokh Bah