
[Image: example of a scatterplot matrix]

Is there such a function in matplotlib.pyplot?

hello_there_andy
hatmatrix

5 Answers


For those who do not want to define their own functions, there is a great data analysis library in Python called pandas, which provides the scatter_matrix() function:

import numpy as np
import pandas as pd
from pandas.plotting import scatter_matrix
df = pd.DataFrame(np.random.randn(1000, 4), columns=['a', 'b', 'c', 'd'])
scatter_matrix(df, alpha=0.2, figsize=(6, 6), diagonal='kde')
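A small sketch building on the snippet above (the loop and the grid choice are my own additions, and it uses the default histogram diagonal so that it does not need SciPy for the KDE): scatter_matrix returns a NumPy array of the matplotlib Axes it created, so every subplot can be adjusted afterwards, for example to make the grid consistent across all panels.

```python
import numpy as np
import pandas as pd
from pandas.plotting import scatter_matrix

df = pd.DataFrame(np.random.randn(1000, 4), columns=['a', 'b', 'c', 'd'])
axes = scatter_matrix(df, alpha=0.2, figsize=(6, 6))

# axes is a 4 x 4 NumPy array of matplotlib Axes (one per variable pair),
# so any per-subplot setting can be applied in a loop afterwards.
for ax in axes.flat:
    ax.grid(False)  # or ax.grid(True) to show the grid in every panel
```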


visitor
Roman Pekar
  • Hi, how come only part of the subplots have a grid in them? Can that be modified (either all or none)? Thanks – user2808117 Jan 22 '14 at 09:03
  • +1 That'll teach me to go searching for a Python feature before looking to see if it's already in pandas. Step 1: Always ask, does it already exist in pandas? `pd.scatter_matrix(df); plt.show()`. Incredible. – Jarad Dec 02 '15 at 20:59
  • Placing a kde in the matplotlib scatterplot matrix is extreme sport. I love pandas. – Lorinc Nyitrai Sep 18 '16 at 22:24
  • Does anyone know where the actual API documentation for `pd.tools.plotting.scatter_matrix` is? Everywhere that I look I can only find that one example - I can't find the optional arguments for the life of me... – Owen Nov 24 '16 at 10:58
  • As of [pandas 0.20](http://pandas.pydata.org/pandas-docs/version/0.20/whatsnew.html#deprecate-plotting), `scatter_matrix` has been moved to `pandas.plotting.scatter_matrix`. – joelostblom May 10 '17 at 15:58

Generally speaking, matplotlib doesn't contain plotting functions that operate on more than one axes object (subplot, in this case). The expectation is that you'd write a simple function to string things together however you'd like.

I'm not quite sure what your data looks like, but it's quite simple to build a function to do this from scratch. If you're always going to be working with structured or record arrays, you can simplify it a bit (i.e., there's always a name associated with each data series, so you can omit the names argument).
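As a sketch of that simplification (the helper name `rows_and_names` is hypothetical, not part of NumPy or matplotlib): with a structured array, the field names come from `dtype.names`, and each field becomes one row of the (numvars, numdata) array that the `scatterplot_matrix` function in this answer expects.

```python
import numpy as np

def rows_and_names(rec):
    """Hypothetical helper: turn a 1-D structured array into a
    (numvars, numdata) float array plus the list of field names."""
    names = list(rec.dtype.names)
    data = np.vstack([rec[name].astype(float) for name in names])
    return data, names

# Example structured array with three named fields and five records.
rec = np.zeros(5, dtype=[('mpg', 'f8'), ('disp', 'f8'), ('wt', 'i4')])
rec['mpg'] = np.arange(5)
data, names = rows_and_names(rec)
print(data.shape, names)  # (3, 5) ['mpg', 'disp', 'wt']
```

The result can then be passed straight on, e.g. `scatterplot_matrix(data, names)`.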

As an example:

import itertools
import numpy as np
import matplotlib.pyplot as plt

def main():
    np.random.seed(1977)
    numvars, numdata = 4, 10
    data = 10 * np.random.random((numvars, numdata))
    fig = scatterplot_matrix(data, ['mpg', 'disp', 'drat', 'wt'],
            linestyle='none', marker='o', color='black', mfc='none')
    fig.suptitle('Simple Scatterplot Matrix')
    plt.show()

def scatterplot_matrix(data, names, **kwargs):
    """Plots a scatterplot matrix of subplots.  Each row of "data" is plotted
    against other rows, resulting in a nrows by nrows grid of subplots with the
    diagonal subplots labeled with "names".  Additional keyword arguments are
    passed on to matplotlib's "plot" command. Returns the matplotlib figure
    object containing the subplot grid."""
    numvars, numdata = data.shape
    fig, axes = plt.subplots(nrows=numvars, ncols=numvars, figsize=(8,8))
    fig.subplots_adjust(hspace=0.05, wspace=0.05)

    for ax in axes.flat:
        # Hide all ticks and labels
        ax.xaxis.set_visible(False)
        ax.yaxis.set_visible(False)

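        # Set up ticks only on one side for the "edge" subplots...
        # (Axes.is_first_col() and friends were deprecated in matplotlib 3.4
        # and later removed; the SubplotSpec methods are the replacement.)
        spec = ax.get_subplotspec()
        if spec.is_first_col():
            ax.yaxis.set_ticks_position('left')
        if spec.is_last_col():
            ax.yaxis.set_ticks_position('right')
        if spec.is_first_row():
            ax.xaxis.set_ticks_position('top')
        if spec.is_last_row():
            ax.xaxis.set_ticks_position('bottom')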

    # Plot the data.
    for i, j in zip(*np.triu_indices_from(axes, k=1)):
        for x, y in [(i,j), (j,i)]:
            axes[x,y].plot(data[x], data[y], **kwargs)

    # Label the diagonal subplots...
    for i, label in enumerate(names):
        axes[i,i].annotate(label, (0.5, 0.5), xycoords='axes fraction',
                ha='center', va='center')

    # Turn on the proper x or y axes ticks.
    for i, j in zip(range(numvars), itertools.cycle((-1, 0))):
        axes[j,i].xaxis.set_visible(True)
        axes[i,j].yaxis.set_visible(True)

    return fig

main()


Joe Kington
  • Wow, many new functions! Yes, not too difficult when you have mastery of the module... but not as simple as calling `pairs` as in R. :) – hatmatrix Oct 29 '11 at 20:59
  • True! R has a lot more specialized functions, in my (limited!) experience with it. Matplotlib has a slightly more DIY approach. (Or certainly a lot fewer specialized statistical plotting functions, at any rate.) – Joe Kington Oct 29 '11 at 21:01
  • Certainly I feel this way. I'm sticking with the Python trio (for now) in hopes though that it offers other advantages... – hatmatrix Oct 29 '11 at 21:37
  • In my opinion, the big advantage is python's flexibility. R is a fantastic domain specific language, and if you're just wanting to do statistical analysis, it's unmatched. Python is a nice general programming language, and you'll really start to see the benefits with larger programs. Once you begin to want a program with an interactive gui that grabs data from the web, parses some random binary file format, does your analysis, and plots it all up, a general programming language begins to show a lot of advantages. Of course, that's true for a lot of languages, but I prefer python. :) – Joe Kington Oct 29 '11 at 21:46
  • The funny thing is, R can also do many of those things, though perhaps not as well as python. I'm working with many files, large data sets, and so on, and I've been using python (without numpy) for some time now for shell scripting/text processing, to feed into R. I thought to bring everything (including analysis) under the python umbrella; python seems to show efficiency gains (not trivial) in that arrays are not copied over and over. But it can be very verbose when it comes to numpy or matplotlib... – hatmatrix Oct 30 '11 at 00:58
  • @Joe Kington, firstly, thanks for this example (I use it regularly) and all your other mpl examples! A couple of points: 1. For those wishing to match R, the x and y values are backwards: change plot `axes[x,y]` to `axes[y,x]`. 2. set `sharex='col', sharey='row'` in subplots() 3. diagonal affects the tick limits, so either set the limits or plot `axes[i,i].plot(data[i], data[i], linestyle='None')` 4. if data is in row, col format, then input must be transposed, `data.T` – CNK Jun 26 '13 at 21:37
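A sketch that folds points 1 to 3 of the last comment into a standalone function (the name `scatterplot_matrix_shared` is hypothetical): subplot (i, j) puts data[j] on the x axis and data[i] on the y axis, `sharex='col'` and `sharey='row'` keep the ticks aligned, and an invisible self-plot on each diagonal panel gives it the same limits as its row and column.

```python
import numpy as np
import matplotlib.pyplot as plt

def scatterplot_matrix_shared(data, names, **kwargs):
    """Sketch: data has shape (numvars, numdata). Subplot (i, j) plots
    data[j] (x axis) against data[i] (y axis), like R's pairs()."""
    numvars = data.shape[0]
    # Shared axes keep limits aligned; plt.subplots also hides the tick
    # labels of inner subplots when axes are shared this way.
    fig, axes = plt.subplots(numvars, numvars, sharex='col', sharey='row',
                             figsize=(8, 8), squeeze=False)
    fig.subplots_adjust(hspace=0.05, wspace=0.05)
    for i in range(numvars):
        for j in range(numvars):
            ax = axes[i, j]
            if i == j:
                # Nothing is drawn (no line, no marker), but the data still
                # updates the autoscale limits of this row and column.
                ax.plot(data[i], data[i], linestyle='none')
                ax.annotate(names[i], (0.5, 0.5), xycoords='axes fraction',
                            ha='center', va='center')
            else:
                ax.plot(data[j], data[i], **kwargs)
    return fig, axes

np.random.seed(1977)
data = 10 * np.random.random((3, 10))
fig, axes = scatterplot_matrix_shared(data, ['mpg', 'disp', 'drat'],
                                      linestyle='none', marker='o',
                                      color='black', mfc='none')
```

Call `plt.show()` afterwards to display it. If your data is laid out as one variable per column instead, pass `data.T` (point 4 of the comment).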

You can also use Seaborn's pairplot function:

import matplotlib.pyplot as plt
import seaborn as sns

sns.set()
df = sns.load_dataset("iris")
sns.pairplot(df, hue="species")
plt.show()
SiHa
sushmit
  • The annoying part about seaborn is that it's centered around pandas DataFrames. If you have a NumPy array, this workaround feels clumsy, and if you already have a pandas DataFrame, why not just use pandas' built-in scatter_matrix function? – Feb 17 '18 at 01:23
  • Unfortunately, it does not allow scatterplot matrices formed by two distinct groups of variables. It just gives vars vs vars plot. This complicates analysis for medium-sized and large datasets. – KKS May 06 '19 at 09:54
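Regarding the two comments above, a sketch assuming a reasonably recent seaborn release: a NumPy array with one variable per row (the layout used by the matplotlib answers here) needs only a transpose and column names to become a DataFrame, and `pairplot` accepts `x_vars` and `y_vars` to plot one group of variables against another instead of every variable against every other. The column names are illustrative.

```python
import numpy as np
import pandas as pd
import seaborn as sns

# Same layout as the matplotlib answers: one variable per row.
rng = np.random.default_rng(1977)
arr = 10 * rng.random((4, 50))

# Rows become DataFrame columns; the names here are only illustrative.
df = pd.DataFrame(arr.T, columns=['mpg', 'disp', 'drat', 'wt'])

# Plot two variables (y axis, one per grid row) against two others
# (x axis, one per grid column) rather than all 4 x 4 combinations.
g = sns.pairplot(df, x_vars=['mpg', 'disp'], y_vars=['drat', 'wt'])
```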

Thanks for sharing your code! You figured out all the hard stuff for us. As I was working with it, I noticed a few little things that didn't look quite right.

  1. [FIX #1] The axis ticks weren't lining up as I would expect. In your example above, you should be able to draw a vertical and a horizontal line through any point across all plots, and those lines should pass through the corresponding point in the other plots; as the code stands, this doesn't happen.

  2. [FIX #2] If you plot an odd number of variables, the bottom-right axes don't get the correct x or y ticks; they are left at the default 0 to 1 range.

  3. Not a fix, but I made the names argument optional; if it is omitted, a default label xi (x0, x1, ...) is placed in the diagonal position for variable i.
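To see why FIX #2 only shows up for an odd number of variables, this sketch (the helper `visible_tick_axes` is hypothetical) replays the "Turn on the proper x or y axes ticks" loop from the answer above and lists which subplots get visible ticks, with negative indices normalised. With three variables, the bottom-right diagonal subplot gets both axes turned on, but that panel only holds a text label, so its limits stay at matplotlib's default 0 to 1; with four variables, no diagonal panel is selected.

```python
import itertools

def visible_tick_axes(numvars):
    """Return the (row, col) subplots whose x and y axes the original
    loop makes visible, with negative indices converted to positive."""
    xticks, yticks = [], []
    for i, j in zip(range(numvars), itertools.cycle((-1, 0))):
        xticks.append((j % numvars, i))  # axes[j, i].xaxis turned on
        yticks.append((i, j % numvars))  # axes[i, j].yaxis turned on
    return xticks, yticks

x3, y3 = visible_tick_axes(3)
print(x3)  # [(2, 0), (0, 1), (2, 2)]
print(y3)  # [(0, 2), (1, 0), (2, 2)]
```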

Below you'll find an updated version of your code that addresses these points while otherwise preserving the beauty of your code.

import itertools
import numpy as np
import matplotlib.pyplot as plt

def scatterplot_matrix(data, names=None, **kwargs):
    """
    Plots a scatterplot matrix of subplots.  Each row of "data" is plotted
    against other rows, resulting in a nrows by nrows grid of subplots with the
    diagonal subplots labeled with "names".  Additional keyword arguments are
    passed on to matplotlib's "plot" command. Returns the matplotlib figure
    object containing the subplot grid.
    """
    numvars, numdata = data.shape
    fig, axes = plt.subplots(nrows=numvars, ncols=numvars, figsize=(8,8))
    fig.subplots_adjust(hspace=0.0, wspace=0.0)

    for ax in axes.flat:
        # Hide all ticks and labels
        ax.xaxis.set_visible(False)
        ax.yaxis.set_visible(False)

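        # Set up ticks only on one side for the "edge" subplots...
        # (Axes.is_first_col() and friends were deprecated in matplotlib 3.4
        # and later removed; the SubplotSpec methods are the replacement.)
        spec = ax.get_subplotspec()
        if spec.is_first_col():
            ax.yaxis.set_ticks_position('left')
        if spec.is_last_col():
            ax.yaxis.set_ticks_position('right')
        if spec.is_first_row():
            ax.xaxis.set_ticks_position('top')
        if spec.is_last_row():
            ax.xaxis.set_ticks_position('bottom')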

    # Plot the data.
    for i, j in zip(*np.triu_indices_from(axes, k=1)):
        for x, y in [(i,j), (j,i)]:
            # FIX #1: this needed to be changed from ...(data[x], data[y],...)
            axes[x,y].plot(data[y], data[x], **kwargs)

    # Label the diagonal subplots...
    if not names:
        names = ['x'+str(i) for i in range(numvars)]

    for i, label in enumerate(names):
        axes[i,i].annotate(label, (0.5, 0.5), xycoords='axes fraction',
                ha='center', va='center')

    # Turn on the proper x or y axes ticks.
    for i, j in zip(range(numvars), itertools.cycle((-1, 0))):
        axes[j,i].xaxis.set_visible(True)
        axes[i,j].yaxis.set_visible(True)

    # FIX #2: if numvars is odd, the bottom right corner plot doesn't have the
    # correct axes limits, so we pull them from other axes
    if numvars%2:
        xlimits = axes[0,-1].get_xlim()
        ylimits = axes[-1,0].get_ylim()
        axes[-1,-1].set_xlim(xlimits)
        axes[-1,-1].set_ylim(ylimits)

    return fig

if __name__=='__main__':
    np.random.seed(1977)
    numvars, numdata = 4, 10
    data = 10 * np.random.random((numvars, numdata))
    fig = scatterplot_matrix(data, ['mpg', 'disp', 'drat', 'wt'],
            linestyle='none', marker='o', color='black', mfc='none')
    fig.suptitle('Simple Scatterplot Matrix')
    plt.show()

Thanks again for sharing this with us. I have used it many times! Oh, and I rearranged the main() part of the code so that it runs as an example when the file is executed as a script, but isn't called when the module is imported into another piece of code.

tisimst
  • Thanks, I was having the problems with @Joe Kington's code until I saw your answer. It saved me some debugging time :) – chutsu Apr 06 '14 at 15:57
  • Any idea how I can make this function faster? I need to generate a big scatterplot matrix of around 100 variables, and this method is very slow. – MARK Feb 29 '16 at 22:39

While reading the question I expected to see an answer including rpy. I think this is a nice option taking advantage of two beautiful languages. So here it is:

import rpy
import numpy as np

def main():
    np.random.seed(1977)
    numvars, numdata = 4, 10
    data = 10 * np.random.random((numvars, numdata))
    mpg = data[0,:]
    disp = data[1,:]
    drat = data[2,:]
    wt = data[3,:]
    rpy.set_default_mode(rpy.NO_CONVERSION)

    R_data = rpy.r.data_frame(mpg=mpg,disp=disp,drat=drat,wt=wt)

    # Figure saved as eps
    rpy.r.postscript('pairsPlot.eps')
    rpy.r.pairs(R_data,
       main="Simple Scatterplot Matrix Via RPy")
    rpy.r.dev_off()

    # Figure saved as png
    rpy.r.png('pairsPlot.png')
    rpy.r.pairs(R_data,
       main="Simple Scatterplot Matrix Via RPy")
    rpy.r.dev_off()

    rpy.set_default_mode(rpy.BASIC_CONVERSION)


if __name__ == '__main__': main()

I can't post an image to show the result :( sorry!

Slothworks
omun