
I love this correlation matrix from the PerformanceAnalytics R package's chart.Correlation function:

[Image: PerformanceAnalytics chart.Correlation result]

How can I create this in Python? The correlation matrix plots I've seen are primarily heatmaps, such as this seaborn example.

joelostblom
Max Ghenis
  • Seaborn pairplot is a good start: http://seaborn.pydata.org/generated/seaborn.pairplot.html and see here for how to add correlation coefficients: https://stackoverflow.com/questions/30942577/seaborn-correlation-coefficient-on-pairgrid – Karl Anka Jan 07 '18 at 19:16

3 Answers


An alternative solution is to show each correlation in the upper triangle as a dot whose size and color reflect r:

import matplotlib.pyplot as plt
import seaborn as sns

def corrdot(*args, **kwargs):
    corr_r = args[0].corr(args[1], 'pearson')
    corr_text = f"{corr_r:2.2f}".replace("0.", ".")
    ax = plt.gca()
    ax.set_axis_off()
    marker_size = abs(corr_r) * 10000
    ax.scatter([.5], [.5], marker_size, [corr_r], alpha=0.6, cmap="coolwarm",
               vmin=-1, vmax=1, transform=ax.transAxes)
    font_size = abs(corr_r) * 40 + 5
    ax.annotate(corr_text, [.5, .5,],  xycoords="axes fraction",
                ha='center', va='center', fontsize=font_size)

sns.set(style='white', font_scale=1.6)
iris = sns.load_dataset('iris')
g = sns.PairGrid(iris, aspect=1.4, diag_sharey=False)
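# lowess=True requires the statsmodels package
g.map_lower(sns.regplot, lowess=True, ci=False, line_kws={'color': 'black'})
# sns.distplot is deprecated: draw a density histogram, then a black KDE on top
g.map_diag(sns.histplot, stat='density')
g.map_diag(sns.kdeplot, color='black')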
g.map_upper(corrdot)

[Image: output of the code above]
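In corrdot, |r| drives three things: the two-decimal label (with the leading zero dropped), the marker size, and the font size. Note that the third positional argument of matplotlib's `ax.scatter` (`s`) is a marker area in points², so it is the circle's area, not its radius, that grows linearly with |r|. Pulled out into plain functions, the mapping might look like the sketch below; the names `corr_label`, `marker_area` and `label_font_size` are hypothetical, and the constants are simply copied from corrdot.

```python
def corr_label(r):
    """Two-decimal label with the leading zero dropped, e.g. 0.87 -> '.87'."""
    return f"{r:.2f}".replace("0.", ".")


def marker_area(r, scale=10000):
    """Marker area in points**2 for ax.scatter's `s`; proportional to |r|."""
    return abs(r) * scale


def label_font_size(r, scale=40, base=5):
    """Font size in points: grows linearly with |r| from a floor of `base`."""
    return abs(r) * scale + base


# Print the label, marker area and font size for a few example correlations
for r in (0.96, -0.42, 0.05):
    print(corr_label(r), marker_area(r), label_font_size(r))
```

Because the color is mapped with `vmin=-1, vmax=1` on a diverging colormap, the sign of r is carried by color alone, while size and font encode only its magnitude.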


Now, if you really want to imitate the look of that R plot, you can combine the above with some of the solutions you provided:

import matplotlib.pyplot as plt
from scipy import stats
import seaborn as sns
import numpy as np

def corrdot(*args, **kwargs):
    corr_r = args[0].corr(args[1], 'pearson')
    corr_text = f"{corr_r:.2f}"  # always show two decimals, as a string
    ax = plt.gca()
    font_size = abs(corr_r) * 80 + 5
    ax.annotate(corr_text, [.5, .5,],  xycoords="axes fraction",
                ha='center', va='center', fontsize=font_size)

def corrfunc(x, y, **kws):
    r, p = stats.pearsonr(x, y)
    p_stars = ''
    if p <= 0.05:
        p_stars = '*'
    if p <= 0.01:
        p_stars = '**'
    if p <= 0.001:
        p_stars = '***'
    ax = plt.gca()
    ax.annotate(p_stars, xy=(0.65, 0.6), xycoords=ax.transAxes,
                color='red', fontsize=70)

sns.set(style='white', font_scale=1.6)
iris = sns.load_dataset('iris')
g = sns.PairGrid(iris, aspect=1.5, diag_sharey=False, despine=False)
# lowess=True requires the statsmodels package
g.map_lower(sns.regplot, lowess=True, ci=False,
            line_kws={'color': 'red', 'lw': 1},
            scatter_kws={'color': 'black', 's': 20})
# sns.distplot is deprecated: draw the grey histogram and the red KDE separately
g.map_diag(sns.histplot, stat='density', color='grey', edgecolor='k', linewidth=2)
g.map_diag(sns.kdeplot, color='red', cut=0.7, linewidth=1)
g.map_diag(sns.rugplot, color='black')
g.map_upper(corrdot)
g.map_upper(corrfunc)
g.fig.subplots_adjust(wspace=0, hspace=0)

# Remove axis labels
for ax in g.axes.flatten():
    ax.set_ylabel('')
    ax.set_xlabel('')

# Add titles to the diagonal axes/subplots
for ax, col in zip(np.diag(g.axes), iris.columns):
    ax.set_title(col, y=0.82, fontsize=26)

[Image: output of the code above]

This is very close to how chart.Correlation() plots the iris data set in R:

library(PerformanceAnalytics)
chart.Correlation(data.matrix(iris[, -5]), histogram = TRUE, pch=20)

[Image: chart.Correlation() output for the iris data set]
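The cascading `if` statements in corrfunc assign the strictest significance level that p meets, so p = 0.004 ends up with '**' rather than '*'. A table-driven version of the same thresholds makes the cutoffs easier to read and change. This is a minimal sketch; the names `STAR_LEVELS` and `p_to_stars` are hypothetical.

```python
# Cutoffs ordered from strictest to loosest, mirroring corrfunc's if-cascade
STAR_LEVELS = ((0.001, "***"), (0.01, "**"), (0.05, "*"))


def p_to_stars(p, levels=STAR_LEVELS):
    """Return the stars for the strictest cutoff that p meets, or '' if none."""
    for cutoff, stars in levels:
        if p <= cutoff:
            return stars
    return ""


# Show the stars assigned to a few example p-values
for p in (0.0004, 0.004, 0.04, 0.4):
    print(p, repr(p_to_stars(p)))
```

Inside corrfunc, the six lines of `if` statements would then collapse to `p_stars = p_to_stars(p)`. Ordering the levels from strictest to loosest and returning on the first match gives the same result as the original cascade, where each later `if` overwrote the previous assignment.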

ImportanceOfBeingErnest
joelostblom
  • I am getting a "'numpy.ndarray' object has no attribute 'name'" error at the line "ax.annotate(x.name, xy=(0.05, 0.9), xycoords=ax.transAxes, fontweight='bold')". Everything works in Python 2, but not in Python 3. Do you know how to fix it? – Helena Nov 05 '18 at 19:23
  • @HelenaGoldfarb Thanks for pointing that out! This is due to a [change in seaborn](https://github.com/mwaskom/seaborn/issues/1562#issuecomment-436073767). I updated the code to work with seaborn 0.9.0. I also changed the regression to lowess and added a rugplot to make it more similar to `chart.Correlation()`. – joelostblom Nov 05 '18 at 23:39
  • @ImportanceOfBeingErnest you did awesome. Could you please modify the code a little further? When I include the "hue" argument in sns.PairGrid(), the correlations in the upper triangle get overwritten. See the image and code at "https://i.stack.imgur.com/af9He.png" and "https://i.stack.imgur.com/G1DCY.png" – Girish Kumar Chandora Aug 17 '19 at 05:00
  • Is it possible to use a `hue` but still get the correlation dot for the entire dataset? – BND Aug 06 '20 at 09:40
  • Really cool plot! Definitely going to start using this for my own analysis – Novice Mar 08 '23 at 10:59

The cor_matrix function below does this, plus adds a bivariate kernel density plot. Thanks to @karl-anka's comment for getting me started.

import matplotlib.pyplot as plt
import seaborn as sns
from scipy import stats

sns.set(style='white')
iris = sns.load_dataset('iris')

def corrfunc(x, y, **kws):
    r, p = stats.pearsonr(x, y)
    p_stars = ''
    if p <= 0.05:
        p_stars = '*'
    if p <= 0.01:
        p_stars = '**'
    if p <= 0.001:
        p_stars = '***'
    ax = plt.gca()
    ax.annotate('r = {:.2f} '.format(r) + p_stars,
                xy=(0.05, 0.9), xycoords=ax.transAxes)

def annotate_colname(x, **kws):
    ax = plt.gca()
    # x.name needs x to be a pandas Series (see the seaborn 0.9.0 comments below)
    ax.annotate(x.name, xy=(0.05, 0.9), xycoords=ax.transAxes,
                fontweight='bold')

def cor_matrix(df):
    g = sns.PairGrid(df, palette=['red'])
    # Use normal regplot as `lowess=True` doesn't provide CIs.
    g.map_upper(sns.regplot, scatter_kws={'s':10})
    g.map_diag(sns.histplot, kde=True, kde_kws=dict(cut=3), alpha=.4, edgecolor=(1, 1, 1, .4))
    g.map_diag(annotate_colname)
    g.map_lower(sns.kdeplot, cmap='Blues_d')
    g.map_lower(corrfunc)
    # Remove axis labels, as they're in the diagonals.
    for ax in g.axes.flatten():
        ax.set_ylabel('')
        ax.set_xlabel('')
    return g

cor_matrix(iris)

[Image: output of cor_matrix(iris)]
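If you want the numbers from the lower triangle without the plot, for example to double-check them, the same Pearson r and two-sided p-values can be computed pair by pair with `scipy.stats.pearsonr`. A minimal sketch, assuming a pandas DataFrame; `corr_table` is a hypothetical helper, and it drops rows with a missing value in either column separately for each pair.

```python
import pandas as pd
from scipy import stats


def corr_table(df):
    """Pearson r and two-sided p-value for every pair of numeric columns."""
    cols = df.select_dtypes("number").columns
    rows = []
    for i, x_col in enumerate(cols):
        for y_col in cols[i + 1:]:
            # Drop missing values for this pair only
            pair = df[[x_col, y_col]].dropna()
            r, p = stats.pearsonr(pair[x_col], pair[y_col])
            rows.append({"x": x_col, "y": y_col, "r": r, "p": p})
    return pd.DataFrame(rows, columns=["x", "y", "r", "p"])


df = pd.DataFrame({
    "a": [1.0, 2.0, 3.0, 4.0, 5.0],
    "b": [2.1, 3.9, 6.2, 8.0, 9.8],
    "c": [5.0, 3.0, 4.0, 1.0, 2.0],
    "label": list("vwxyz"),  # not numeric, so it is skipped
})
print(corr_table(df).round(3))
```

Calling `corr_table(iris)` on the data used above would list the same six pairs that the lower triangle of cor_matrix annotates.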

Reuel Ribeiro
Max Ghenis
  • I am getting a "'numpy.ndarray' object has no attribute 'name'" error at the line "ax.annotate(x.name, xy=(0.05, 0.9), xycoords=ax.transAxes, fontweight='bold')". Have you seen this before? – Helena Oct 17 '18 at 15:43
  • @Helena this helped me: https://github.com/mwaskom/seaborn/issues/1562 together with https://datascience.stackexchange.com/questions/57673/how-to-put-the-variable-names-of-pandas-data-frame-on-diagonal-of-seaborn-pairgr – oski86 Nov 21 '19 at 17:29
  • The suggested-edit queue is full; otherwise I would have liked to convert the 2-space indents to 4-space indents, which are much more common. – Dec 07 '21 at 20:36
  • @Max Ghenis, is there any way to adjust the axis ranges to the maximum or minimum values of each plot? I am getting panels where all the curves are flattened due to the linear fit curve (the red shadow expands too much and moves the scale up to limits that are not desirable). Thanks!!! – GEBRU Feb 25 '22 at 17:38
  • Newer seaborn versions have distplot deprecated. Rather use ```g.map_diag(sns.histplot, kde=True, kde_kws=dict(cut=3), alpha=.4, edgecolor=(1, 1, 1, .4))``` – Daniel Böckenhoff Dec 12 '22 at 16:23

To fix the "'numpy.ndarray' object has no attribute 'name'" error at the line "ax.annotate(x.name, xy=(0.05, 0.9), xycoords=ax.transAxes, fontweight='bold')" while keeping the function general, create an iterator over the column labels inside cor_matrix and move annotate_colname into cor_matrix so that it can call that iterator, as follows:

import matplotlib.pyplot as plt
import seaborn as sns
from scipy import stats

def corrfunc(x, y, **kws):
    r, p = stats.pearsonr(x, y)
    p_stars = ''
    if p <= 0.05:
        p_stars = '*'
    if p <= 0.01:
        p_stars = '**'
    if p <= 0.001:
        p_stars = '***'
    ax = plt.gca()
    # The keyword is xycoords; matplotlib rejects the misspelled "ycoords"
    ax.annotate('r = {:.2f} '.format(r) + p_stars, xy=(0.05, 0.9), xycoords=ax.transAxes)


def cor_matrix(df, save=False):
    # ======= NEW ITERATION FUNCTION ====
    # Each call returns the next column label of df. This assumes the columns
    # appear in the same order as PairGrid's diagonal (non-numeric columns last).
    label_iter = iter(df).__next__
    # ====================================
    def annotate_colname(x, **kws):
        ax = plt.gca()
        # ===== CHANGE: x.name replaced by label_iter() ======
        ax.annotate(label_iter(), xy=(0.05, 0.9), xycoords=ax.transAxes, fontweight='bold')

    g = sns.PairGrid(df, palette=['red'])

    # Use normal regplot as `lowess=True` doesn't provide CIs.
    g.map_upper(sns.regplot, scatter_kws={'s':10}, line_kws={"color": "red"})
    g.map_diag(sns.histplot, kde=True)  # histplot replaces the deprecated distplot
    g.map_diag(annotate_colname)
    g.map_lower(sns.kdeplot, cmap='Blues_d')
    g.map_lower(corrfunc)

    # Remove axis labels, as they're in the diagonals.
    for ax in g.axes.flatten():
        ax.set_ylabel('')
        ax.set_xlabel('')
    if save:
        plt.savefig('corr_mat.png')
    return g
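The `iter(df).__next__` trick works because iterating over a DataFrame yields its column labels; binding the iterator's `__next__` method gives a zero-argument function that returns the next label on each call. It is stateful, though, so it goes out of sync if the function is called more than once per diagonal panel (for example, once per level when a hue variable is used). A slightly safer variant, sketched below, iterates only over the numeric columns, which are the ones PairGrid puts on the diagonal by default. The helper `make_label_source` is hypothetical.

```python
import pandas as pd


def make_label_source(df):
    """Return a zero-argument callable yielding df's numeric column labels in order.

    Hypothetical helper: PairGrid only puts numeric columns on its diagonal by
    default, so skipping the others keeps the labels aligned with the panels.
    """
    return iter(df.select_dtypes("number").columns).__next__


df = pd.DataFrame({
    "species": ["setosa", "virginica"],  # non-numeric: not drawn by PairGrid
    "sepal_length": [5.1, 6.3],
    "sepal_width": [3.5, 2.9],
})

next_label = make_label_source(df)
print(next_label())  # sepal_length
print(next_label())  # sepal_width
# A further call raises StopIteration because no labels are left.
```

Inside cor_matrix, `label_iter = make_label_source(df)` would replace `label_iter = iter(df).__next__`, so a non-numeric column placed first no longer shifts the labels.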