
I am trying to plot a scatterplot matrix based on the code Joe Kington wrote in answer to the question "Is there a function to make scatterplot matrices in matplotlib?"

Some people have already helped me. Thank you again (especially J.K.).

I have one last problem: I cannot rotate the tick labels on the axes where the numbers overlap (bottom left of the figure).

I would like to make them vertical, but I cannot get it to work. Here is my code:

import itertools
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.ticker import MaxNLocator  # needed for set_major_locator below


def main():
    FigSize=8.89
    FontSize=8
    np.random.seed(1977)
    numvars, numdata = 4, 10
    data = 10 * np.random.random((numvars, numdata))
    fig = scatterplot_matrix(data, ['mpg', 'disp', 'drat', 'wt'], FigSize, FontSize,
        linestyle='none', marker='o', color='black', mfc='none', markersize=3,)
    fig.suptitle('Simple Scatterplot Matrix')
    plt.savefig('Plots/ScatterplotMatrix/ScatterplotMatrix2.pdf',format='pdf', dpi=1000, transparent=True, bbox_inches='tight')
    plt.show()


def scatterplot_matrix(data, names, FigSize, FontSize, **kwargs):
    """Plots a scatterplot matrix of subplots.  Each row of "data" is plotted
    against other rows, resulting in a nrows by nrows grid of subplots with the
    diagonal subplots labeled with "names".  Additional keyword arguments are
    passed on to matplotlib's "plot" command. Returns the matplotlib figure
    object containing the subplot grid."""

    legend=['(kPa)','\%','\%','\%']
    numvars, numdata = data.shape
    fig, axes = plt.subplots(nrows=numvars, ncols=numvars, figsize=(FigSize/2.54,FigSize/2.54))
    fig.subplots_adjust(hspace=0.05, wspace=0.05)

    sub_labelx_top=[2,4]
    sub_labelx_bottom=[13,15]
    sub_labely_left=[5,13]
    sub_labely_right=[4,12]

    for i, ax in enumerate(axes.flat, start=1):
        # Hide all ticks and labels
        ax.xaxis.set_visible(False)
        ax.yaxis.set_visible(False)
        ax.xaxis.set_major_locator(MaxNLocator(prune='both',nbins=4))
        ax.yaxis.set_major_locator(MaxNLocator(prune='both',nbins=4)) #http://matplotlib.org/api/ticker_api.html#matplotlib.ticker.MaxNLocator


        # Set up ticks only on one side for the "edge" subplots...
        if ax.is_first_col():
            ax.yaxis.set_ticks_position('left')
            ax.tick_params(direction='out')
            ax.yaxis.set_tick_params(labelsize=0.75*FontSize)
            if i in sub_labely_left:
                ax.yaxis.set_label_position('left')
                ax.set_ylabel('(\%)',fontsize=0.75*FontSize)

        if ax.is_last_col():
            ax.yaxis.set_ticks_position('right')
            ax.tick_params(direction='out')
            ax.yaxis.set_tick_params(labelsize=0.75*FontSize)
            if i in sub_labely_right:
                ax.yaxis.set_label_position('right')
                if i==4:
                    ax.set_ylabel('(kPa)',fontsize=0.75*FontSize)
                else:
                    ax.set_ylabel('(\%)',fontsize=0.75*FontSize)

        if ax.is_first_row():
            ax.xaxis.set_ticks_position('top')
            ax.tick_params(direction='out')
            ax.xaxis.set_tick_params(labelsize=0.75*FontSize)
            if i in sub_labelx_top:
                ax.xaxis.set_label_position('top')
                ax.set_xlabel('(\%)',fontsize=0.75*FontSize)

        if ax.is_last_row():
            ax.xaxis.set_ticks_position('bottom')
            ax.tick_params(direction='out')
            ax.xaxis.set_tick_params(labelsize=0.75*FontSize)

            if i in sub_labelx_bottom:
                ax.xaxis.set_label_position('bottom')

                if i==13:
                    ax.set_xlabel('(kPa)',fontsize=0.75*FontSize)
                else:
                    ax.set_xlabel('(\%)',fontsize=0.75*FontSize)

    # Plot the data.
    for i, j in zip(*np.triu_indices_from(axes, k=1)):
        for x, y in [(i,j), (j,i)]:
            axes[x,y].plot(data[y], data[x], **kwargs)   



    # Label the diagonal subplots...
    for i, label in enumerate(names):
        axes[i,i].annotate(label, (0.5, 0.5), xycoords='axes fraction',
            ha='center', va='center',fontsize=FontSize)

    # Turn on the proper x or y axes ticks.
    for i, j in zip(range(numvars), itertools.cycle((-1, 0))):
        axes[j,i].xaxis.set_visible(True)
        axes[i,j].yaxis.set_visible(True)

    return fig

main()

My second question is more for fun: how can I make the subplots perfectly square?

I apologize to Joe Kington; I know my code is far less elegant than his. I only started a few weeks ago. If you have any suggestions to improve it, for example to make it more dynamic, I am very interested.

Viktor

1 Answer


You can rotate the xtick labels using setp.

from matplotlib.artist import setp

Then, after you set the x tick positions for the top row and left column of subplots, call:

setp(ax.get_xticklabels(), rotation=90)
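As a minimal sketch of where that call goes, assuming a small 2x2 grid instead of the 4x4 matrix from the question: the Agg backend line and the `rotations` list are not part of the original code and are only there so the example runs without a display and can be checked.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: non-interactive backend so no window is needed
import matplotlib.pyplot as plt
from matplotlib.artist import setp

fig, axes = plt.subplots(nrows=2, ncols=2)

# Bottom row: put the ticks at the bottom, then turn their labels vertical.
for ax in axes[-1, :]:
    ax.xaxis.set_ticks_position('bottom')
    setp(ax.get_xticklabels(), rotation=90)

# Read the rotation back from one bottom-row subplot.
rotations = [label.get_rotation() for label in axes[-1, 0].get_xticklabels()]
```

In the question's loop, the same setp line would sit inside the `if ax.is_last_row():` block, right after `set_ticks_position('bottom')`.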

To make the subplots equal in size, you can use fig.subplots_adjust to set the area covered by all the subplots to a square. Something like this:

gridSize = 0.6
leftBound = 0.5 - gridSize/2
bottomBound = 0.1
rightBound = leftBound + gridSize
topBound = bottomBound + gridSize
fig.subplots_adjust(hspace=0.05, wspace=0.05, left=leftBound,
                        bottom=bottomBound, right=rightBound, top=topBound)
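One way to check that these bounds really give square subplots is to read an axes position back and convert it to inches. This is only a sketch: the 4x4 grid, the 3.5 inch square figure (8.89 cm from the question) and the Agg backend are assumptions made so the snippet is self-contained.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend for a runnable sketch
import matplotlib.pyplot as plt

fig, axes = plt.subplots(nrows=4, ncols=4, figsize=(3.5, 3.5))

grid_size = 0.6
left = 0.5 - grid_size / 2
bottom = 0.1
fig.subplots_adjust(hspace=0.05, wspace=0.05, left=left, bottom=bottom,
                    right=left + grid_size, top=bottom + grid_size)

# Axes positions are stored in figure fractions; scale by the figure size
# to get the physical width and height of one subplot in inches.
fig_w, fig_h = fig.get_size_inches()
box = axes[0, 0].get_position()
width_in = box.width * fig_w
height_in = box.height * fig_h
```

With a square figure and equal hspace and wspace, width_in and height_in should come out the same.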

If the figure size isn't square, you'll need to change the shape of the grid accordingly. Alternatively, you could add each subplot axes individually with fig.add_axes. That lets you set the size directly, but you'll also have to set the location.
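For the non-square case, the adjustment amounts to choosing the grid side in inches and converting it to a different fraction along each axis. The helper below is a hypothetical sketch (`square_grid_bounds` and its parameters are names invented here, not matplotlib API); it assumes the grid should be centred horizontally and start at a fixed bottom margin, as in the snippet above.

```python
def square_grid_bounds(fig_width, fig_height, grid_frac=0.6, bottom=0.1):
    """Return subplots_adjust bounds for a grid that is square in inches.

    fig_width, fig_height: figure size in inches.
    grid_frac: grid side as a fraction of the shorter figure side.
    bottom: bottom margin as a fraction of the figure height.
    """
    side = grid_frac * min(fig_width, fig_height)  # grid side in inches
    width_frac = side / fig_width                  # same side as a fraction of the width
    height_frac = side / fig_height                # and as a fraction of the height
    left = 0.5 - width_frac / 2                    # centre the grid horizontally
    return {
        "left": left,
        "bottom": bottom,
        "right": left + width_frac,
        "top": bottom + height_frac,
    }


# Example for an 8 x 4 inch figure.
bounds = square_grid_bounds(8.0, 4.0)
# Intended use: fig.subplots_adjust(hspace=0.05, wspace=0.05, **bounds)
```

As long as hspace and wspace are equal and the grid has as many rows as columns, a grid that is square in inches gives square subplots.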

Don't use bbox_inches='tight' to save the figure, or you'll lose the title with these settings. You can save like this:

plt.savefig('ScatterplotMatrix.pdf',format='pdf', dpi=1000, transparent=True)

The resulting graph looks like this:

[Figure: scatter plot matrix]

Molly