
I plotted the eigenvectors of some 3D data and was wondering whether there is currently a way to put arrowheads on the lines. It would be awesome if someone has a tip for me.

import numpy as np
from matplotlib import pyplot as plt
from mpl_toolkits.mplot3d import Axes3D

####################################################
# This part is just for reference if
# you are interested where the data is
# coming from
# The plot is at the bottom
#####################################################

# Generate some example data
mu_vec1 = np.array([0,0,0])
cov_mat1 = np.array([[1,0,0],[0,1,0],[0,0,1]])
class1_sample = np.random.multivariate_normal(mu_vec1, cov_mat1, 20)

mu_vec2 = np.array([1,1,1])
cov_mat2 = np.array([[1,0,0],[0,1,0],[0,0,1]])
class2_sample = np.random.multivariate_normal(mu_vec2, cov_mat2, 20)

# concatenate data for PCA
samples = np.concatenate((class1_sample, class2_sample), axis=0)

# mean values
mean_x = np.mean(samples[:,0])
mean_y = np.mean(samples[:,1])
mean_z = np.mean(samples[:,2])

#eigenvectors and eigenvalues
eig_val, eig_vec = np.linalg.eig(cov_mat1)

################################
#plotting eigenvectors
################################    

fig = plt.figure(figsize=(15,15))
ax = fig.add_subplot(111, projection='3d')

ax.plot(samples[:,0], samples[:,1], samples[:,2], 'o', markersize=10, color='green', alpha=0.2)
ax.plot([mean_x], [mean_y], [mean_z], 'o', markersize=10, color='red', alpha=0.5)
for v in eig_vec:
    ax.plot([mean_x, v[0]], [mean_y, v[1]], [mean_z, v[2]], color='red', alpha=0.8, lw=3)
ax.set_xlabel('x_values')
ax.set_ylabel('y_values')
ax.set_zlabel('z_values')

plt.title('Eigenvectors')

plt.draw()
plt.show()
Trenton McKinney

3 Answers


To add arrow patches to a 3D plot, the simple solution is to use the FancyArrowPatch class defined in /matplotlib/patches.py. However, it only works for 2D plots (at the time of writing), as its posA and posB are supposed to be tuples of length 2.

Therefore we create a new arrow patch class, named Arrow3D, which inherits from FancyArrowPatch. The only thing we need to override is its posA and posB. To do that, we initialize Arrow3D with posA and posB set to (0,0). The 3D coordinates xs, ys, zs are then projected from 3D to 2D using proj3d.proj_transform(), and the resulting 2D coordinates are assigned to posA and posB using the .set_positions() method, replacing the (0,0)s. This way we get the 3D arrow to work.

The projection step goes into the .draw method, which overrides the .draw method of the FancyArrowPatch object.

This might look like a hack. However, mplot3d currently provides only simple 3D plotting capability: it supplies 3D-to-2D projections and essentially does all the plotting in 2D, which is not truly 3D.
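To make the projection step concrete, here is a minimal sketch of what proj3d.proj_transform() does with the two endpoints of a segment. It assumes the projection matrix can be taken from ax.get_proj() outside of a draw call and uses the headless Agg backend so that it runs without a display; the endpoint values and the names start_2d and end_2d are made up for illustration.

```python
import numpy as np
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend, so no window is needed
from matplotlib import pyplot as plt
from mpl_toolkits.mplot3d import proj3d

fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')

# 4x4 projection matrix for the current view of the 3D axes
M = ax.get_proj()

# Project the two 3D endpoints (0,0,0) and (1,1,1) of a segment;
# zs carries the projected depth and is not needed for a 2D patch position
xs, ys, zs = proj3d.proj_transform([0, 1], [0, 1], [0, 1], M)

# These are the 2D tuples that would be handed to set_positions()
start_2d = (xs[0], ys[0])
end_2d = (xs[1], ys[1])
print(start_2d, end_2d)
```

The projected coordinates change whenever the view is rotated, which is why the projection has to be redone on every draw instead of once at construction time.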

import numpy as np
from numpy import *
from matplotlib import pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from matplotlib.patches import FancyArrowPatch
from mpl_toolkits.mplot3d import proj3d

class Arrow3D(FancyArrowPatch):
    def __init__(self, xs, ys, zs, *args, **kwargs):
        FancyArrowPatch.__init__(self, (0,0), (0,0), *args, **kwargs)
        self._verts3d = xs, ys, zs

    def draw(self, renderer):
        xs3d, ys3d, zs3d = self._verts3d
        xs, ys, zs = proj3d.proj_transform(xs3d, ys3d, zs3d, renderer.M)
        self.set_positions((xs[0],ys[0]),(xs[1],ys[1]))
        FancyArrowPatch.draw(self, renderer)

####################################################
# This part is just for reference if
# you are interested where the data is
# coming from
# The plot is at the bottom
#####################################################

# Generate some example data
mu_vec1 = np.array([0,0,0])
cov_mat1 = np.array([[1,0,0],[0,1,0],[0,0,1]])
class1_sample = np.random.multivariate_normal(mu_vec1, cov_mat1, 20)

mu_vec2 = np.array([1,1,1])
cov_mat2 = np.array([[1,0,0],[0,1,0],[0,0,1]])
class2_sample = np.random.multivariate_normal(mu_vec2, cov_mat2, 20)

Actual drawing. Note that we only need to change one line of your code, which adds a new arrow artist:

# concatenate data for PCA
samples = np.concatenate((class1_sample, class2_sample), axis=0)

# mean values
mean_x = np.mean(samples[:,0])
mean_y = np.mean(samples[:,1])
mean_z = np.mean(samples[:,2])

#eigenvectors and eigenvalues
eig_val, eig_vec = np.linalg.eig(cov_mat1)

################################
#plotting eigenvectors
################################    

fig = plt.figure(figsize=(15,15))
ax = fig.add_subplot(111, projection='3d')

ax.plot(samples[:,0], samples[:,1], samples[:,2], 'o', markersize=10, color='g', alpha=0.2)
ax.plot([mean_x], [mean_y], [mean_z], 'o', markersize=10, color='red', alpha=0.5)
for v in eig_vec:
    #ax.plot([mean_x,v[0]], [mean_y,v[1]], [mean_z,v[2]], color='red', alpha=0.8, lw=3)
    #I will replace this line with:
    a = Arrow3D([mean_x, v[0]], [mean_y, v[1]], 
                [mean_z, v[2]], mutation_scale=20, 
                lw=3, arrowstyle="-|>", color="r")
    ax.add_artist(a)
ax.set_xlabel('x_values')
ax.set_ylabel('y_values')
ax.set_zlabel('z_values')

plt.title('Eigenvectors')

plt.draw()
plt.show()


Please check this post, which inspired this question, for further details.

Ciro Santilli OurBigBook.com
CT Zhu
  • This code works in `matplotlib 2.0` without `plt.draw()`. Is that line of code necessary? – Seanny123 Feb 28 '17 at 09:48
  • @Seanny123, optional, `.show()` code could also be optional depends on how the environment is setup. Just for clarity sake I suppose. – CT Zhu Mar 02 '17 at 03:57
  • Fantastic answer. It could be improved if there was a way to control the depth-positioning of the arrows drawn. In my case it is undesirable that the arrows are visible over the data point, but I will look into this. EDIT: `arrow.set_zorder(-1)` does the trick, easy as pie. – KeithWM Oct 26 '18 at 09:00
  • Sadly the edit queue is currently full; to add compatibility for matplotlib>=3.5 add a `do_3d_projection = draw` alias into the class definition. – mueslo Nov 16 '21 at 18:46
  • Oh, and the draw/do_3d_projection function will also need a return value (for z order stacking) – mueslo Nov 16 '21 at 19:20
  • @mueslo. Followed your advice, still getting some None-to-None comparison errors in 3.5.1. – Mad Physicist Apr 20 '22 at 02:31
  • @mueslo I try to make this code run in a newer version of matplotlib. So where exactly do I have to add this line of code? Is it a class or instance attribute? And does it refer to self.draw() or to the plt.draw()? – Pickniclas Oct 05 '22 at 21:15

Another option: you can also use the ax.quiver method of the 3D axes, which lets you draw arrow vectors pretty easily without any extra imports or classes.

To replicate your example, you would replace:

for v in eig_vec:
    ax.plot([mean_x, v[0]], [mean_y, v[1]], [mean_z, v[2]], color='red', alpha=0.8, lw=3)

with:

for v in eig_vec:
    ax.quiver(
        mean_x, mean_y, mean_z, # <-- starting point of vector
        v[0] - mean_x, v[1] - mean_y, v[2] - mean_z, # <-- directions of vector
        color = 'red', alpha = .8, lw = 3,
    )
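For a version that runs on its own, the following is a minimal sketch and not the original poster's exact plot. It assumes the eigenvectors come from np.cov of the samples (the question leaves cov_mat undefined), iterates over the columns of eig_vec because np.linalg.eig returns eigenvectors as columns, and draws each arrow from the mean along the eigenvector direction. The seed, the Agg backend, and the arrow_length_ratio value are arbitrary choices for illustration.

```python
import numpy as np
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend, so no window is needed
from matplotlib import pyplot as plt

# Example data (fixed seed is an assumption, only for reproducibility)
rng = np.random.default_rng(0)
samples = rng.multivariate_normal([0, 0, 0], np.eye(3), 40)
mean_vec = samples.mean(axis=0)

# Assumption: eigen-decomposition of the sample covariance matrix
eig_val, eig_vec = np.linalg.eig(np.cov(samples.T))

fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
ax.plot(samples[:, 0], samples[:, 1], samples[:, 2], 'o', color='green', alpha=0.2)

# Eigenvectors are the columns of eig_vec, so iterate over the transpose
for v in eig_vec.T:
    ax.quiver(
        mean_vec[0], mean_vec[1], mean_vec[2],  # starting point of the arrow
        v[0], v[1], v[2],                       # direction components (U, V, W)
        color='red', alpha=0.8, lw=3,
        arrow_length_ratio=0.15,                # head length relative to the arrow length
    )
```

As the comments below point out, the head size in 3D is tied to the arrow length through arrow_length_ratio, so this approach does not give a fixed-size arrowhead.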
Matt
  • Although using the built-in `quiver` sounds simpler than adding a custom class, it does not support the dtype `float128` as its first six arguments: `X`, `Y`, `Z`, `U`, `V`, and `W`. Because it silently converts the arguments into `float`, they are converted into `float64` in our systems. As a result, if we give it `float128` numbers, they _overflow!_ – Shahrokh Bah Apr 06 '21 at 09:23
  • Beware, if you are using 3D arrows in matplotlib 3.1.2 or later, you get the [3D matplotlib object](https://matplotlib.org/stable/api/_as_gen/mpl_toolkits.mplot3d.axes3d.Axes3D.html) and not the 2D quiver @matt links. cf [this answer](https://stackoverflow.com/a/58490404/1391376) – chiffa Sep 30 '21 at 13:06
  • @ShahrokhBah can you please share your use-case where float32 is inacceptable for plots? Thanks – Gaston Mar 23 '22 at 15:40
  • Nice, and strange that the 3D option is not mentioned in the [documentation](https://matplotlib.org/stable/api/_as_gen/matplotlib.axes.Axes.quiver.html#matplotlib.axes.Axes.quiver). Unfortunately, many of the style options, like `scale`, `headwidth` and `headaxislength` don't seem to work in 3D... – AstroFloyd Apr 25 '22 at 15:45
  • The problem of `quiver` is its arrow size and length scale with the shaft, meaning it is impossible to have a fix sized arrow. – cyfex Sep 23 '22 at 01:07

Newer versions of matplotlib throw AttributeError: 'Arrow3D' object has no attribute 'do_3d_projection' with the old definition of Arrow3D. Several comments here asked about this, and it still remained somewhat unclear. You have to add a do_3d_projection() method, while draw() is no longer needed. The current code looks like this:

import numpy as np
from matplotlib import pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from matplotlib.patches import FancyArrowPatch
from mpl_toolkits.mplot3d import proj3d

class Arrow3D(FancyArrowPatch):
    def __init__(self, xs, ys, zs, *args, **kwargs):
        super().__init__((0,0), (0,0), *args, **kwargs)
        self._verts3d = xs, ys, zs

    def do_3d_projection(self, renderer=None):
        xs3d, ys3d, zs3d = self._verts3d
        xs, ys, zs = proj3d.proj_transform(xs3d, ys3d, zs3d, self.axes.M)
        self.set_positions((xs[0],ys[0]),(xs[1],ys[1]))

        return np.min(zs)

fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
arrow_prop_dict = dict(mutation_scale=20, arrowstyle='-|>', color='k', shrinkA=0, shrinkB=0)
a = Arrow3D([0, 10], [0, 0], [0, 0], **arrow_prop_dict)
ax.add_artist(a)

plt.show()

Help came from GitHub.
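Since the comments on the older answer ask where exactly the new method belongs, here is a hedged sketch of a class that keeps both methods, so the same definition might serve older and newer matplotlib versions. It assumes that self.axes.M has been set by the 3D axes before either method is called, which is the case during a normal draw; the draw() override is only a fallback for versions that never call do_3d_projection(). It is rendered once with the Agg backend purely to exercise the code path.

```python
import numpy as np
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend, so no window is needed
from matplotlib import pyplot as plt
from matplotlib.patches import FancyArrowPatch
from mpl_toolkits.mplot3d import proj3d

class Arrow3D(FancyArrowPatch):
    def __init__(self, xs, ys, zs, *args, **kwargs):
        super().__init__((0, 0), (0, 0), *args, **kwargs)
        self._verts3d = xs, ys, zs

    def do_3d_projection(self, renderer=None):
        # Project the 3D endpoints to 2D and update the patch position
        xs3d, ys3d, zs3d = self._verts3d
        xs, ys, zs = proj3d.proj_transform(xs3d, ys3d, zs3d, self.axes.M)
        self.set_positions((xs[0], ys[0]), (xs[1], ys[1]))
        # The return value is used by the 3D axes for depth sorting
        return np.min(zs)

    def draw(self, renderer):
        # Fallback for versions that call draw() without do_3d_projection();
        # on newer versions this just repeats the projection harmlessly
        self.do_3d_projection(renderer)
        super().draw(renderer)

fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
a = Arrow3D([0, 1], [0, 1], [0, 1],
            mutation_scale=20, arrowstyle='-|>', color='r', lw=3)
ax.add_artist(a)

# Render once off-screen; this raises if the class is missing a required method
fig.canvas.draw()
```

Both methods are ordinary instance methods defined inside the class body, and the draw() here is the artist's own method, not plt.draw().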

Ruli