
The post "Get data points from Seaborn distplot" describes how you can get data elements using sns.distplot(x).get_lines()[0].get_data(), sns.distplot(x).patches and [h.get_height() for h in sns.distplot(x).patches].

But how can you do this if you've used multiple layers by plotting the data in a loop, such as:

Snippet 1

for var in list(df):
    print(var)
    distplot = sns.distplot(df[var])

Plot

[Image: the resulting plot, with a histogram and a density line for each variable]

Is there a way to retrieve the X and Y values for both the lines and the bars?


Here's the whole setup for an easy copy&paste:

#%%
# imports
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)
plt.rcParams['figure.figsize'] = (8, 4)
import seaborn as sns
from collections import OrderedDict

# Function to build synthetic data
def sample(rSeed, periodLength, colNames):

    np.random.seed(rSeed)
    date = pd.to_datetime("1st of Dec, 1999")   
    cols = OrderedDict()

    for col in colNames:
        cols[col] = np.random.normal(loc=0.0, scale=1.0, size=periodLength)
    dates = date+pd.to_timedelta(np.arange(periodLength), 'D')

    df = pd.DataFrame(cols, index = dates)
    return(df)

# Dataframe with synthetic data
df = sample(rSeed = 123, colNames = ['X1', 'X2'], periodLength = 50)

# sns.distplot with multiple layers
for var in list(df):
    myPlot = sns.distplot(df[var])

Here's what I've tried:

Y-values for histogram:

If I run:

barX = [h.get_height() for h in myPlot.patches]

Then I get the following list of length 11:

[0.046234272703757885,
 0.1387028181112736,
 0.346757045278184,
 0.25428849987066837,
 0.2542884998706682,
 0.11558568175939472,
 0.11875881712519201,
 0.3087729245254993,
 0.3087729245254993,
 0.28502116110046083,
 0.1662623439752689]

This seems reasonable, since there seem to be 6 values for the blue bars and 5 values for the red bars. But how do I tell which values belong to which variable?

Y-values for line:

This seems a bit easier than the histogram part, since you can use myPlot.get_lines()[0].get_data() and myPlot.get_lines()[1].get_data() to get:

Out[678]: 
(array([-4.54448949, -4.47612134, -4.40775319, -4.33938504, -4.27101689,
         ...
         3.65968859,  3.72805675,  3.7964249 ,  3.86479305,  3.9331612 ,
         4.00152935,  4.0698975 ,  4.13826565]),
 array([0.00042479, 0.00042363, 0.000473  , 0.00057404, 0.00073097,
        0.00095075, 0.00124272, 0.00161819, 0.00208994, 0.00267162,
        ...
        0.0033384 , 0.00252219, 0.00188591, 0.00139919, 0.00103544,
        0.00077219, 0.00059125, 0.00047871]))

myPlot.get_lines()[1].get_data()

Out[679]: 
(array([-3.68337423, -3.6256517 , -3.56792917, -3.51020664, -3.4524841 ,
        -3.39476157, -3.33703904, -3.27931651, -3.22159398, -3.16387145,
         ...
         3.24332952,  3.30105205,  3.35877458,  3.41649711,  3.47421965,
         3.53194218,  3.58966471,  3.64738724]),
 array([0.00035842, 0.00038018, 0.00044152, 0.00054508, 0.00069579,
        0.00090076, 0.00116922, 0.00151242, 0.0019436 , 0.00247792,
        ...
        0.00215912, 0.00163627, 0.00123281, 0.00092711, 0.00070127,
        0.00054097, 0.00043517, 0.00037599]))

But the whole thing still seems a bit cumbersome. So does anyone know of a more direct approach to perhaps retrieve all data to a dictionary or dataframe?

vestland
  • The approach here is working, right? It's just a little cumbersome. But that's expected for such detour. You wouldn't usually first plot something to get the data out, but rather the inverse, first get the data, then plot it. – ImportanceOfBeingErnest Sep 26 '18 at 12:00
  • @ImportanceOfBeingErnest, it's working for the most part, but I've had little success with the x values for the histograms. Normally I would say that you are 100% right about the "data first, then plot" approach. But seaborn is able to produce a bunch of very informative charts, and to me at least it would be fantastic to grab all the data in an efficient way. Why? Because I often find myself asked to reproduce python / matplotlib / seaborn plots with other visualization tools. And if you ask me why, I wish I had a good answer to that... – vestland Sep 26 '18 at 12:17
  • Since the bars have different colors, you can use this information to find out which bar belongs to which set of data (`h.get_facecolor()`). – ImportanceOfBeingErnest Sep 26 '18 at 12:39
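The face-color suggestion in the last comment could be sketched roughly as follows. This is only an illustration under stated assumptions: `bars_by_color` is a hypothetical helper name, and two plain matplotlib histograms stand in for the two distplot layers so the sketch runs without seaborn. In the question's setting, `ax` would be the `myPlot` Axes returned by `sns.distplot`. Reading `get_x()` and `get_width()` alongside `get_height()` also recovers the x values of the bars.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np
from collections import defaultdict


def bars_by_color(ax):
    """Group (x, width, height) of every bar on `ax` by its face color.

    Hypothetical helper, not part of seaborn or matplotlib.
    """
    groups = defaultdict(list)
    for patch in ax.patches:
        # get_facecolor() returns an RGBA tuple; round it so it is a stable key
        color = tuple(round(c, 3) for c in patch.get_facecolor())
        groups[color].append((patch.get_x(), patch.get_width(), patch.get_height()))
    return dict(groups)


# Stand-in for the two distplot layers (assumption: one color per variable)
rng = np.random.RandomState(123)
fig, ax = plt.subplots()
ax.hist(rng.normal(size=50), bins=6, density=True, color="tab:blue", alpha=0.4)
ax.hist(rng.normal(size=50), bins=5, density=True, color="tab:red", alpha=0.4)

groups = bars_by_color(ax)
for color, bars in groups.items():
    print(color, len(bars), "bars")
```

This only works when each variable is drawn in its own color, which is the case when the layers pick successive colors from the default color cycle.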

1 Answer


I had the same need to retrieve data from a seaborn distribution plot. What worked for me was to call the .findobj() method on the Axes returned in each iteration. The matplotlib.lines.Line2D objects it finds have a get_data() method, which is the same method you used above with myPlot.get_lines()[1].get_data().

Following your example code

from matplotlib.lines import Line2D

data = []
for idx, var in enumerate(list(df)):
    myPlot = sns.distplot(df[var])

    # Find the Line2D objects currently on the Axes
    lines2D = [obj for obj in myPlot.findobj() if isinstance(obj, Line2D)]

    # Retrieve the x, y data of the line drawn in this iteration
    x, y = lines2D[idx].get_data()

    # Store as a dataframe
    data.append(pd.DataFrame({'x': x, 'y': y}))

Notice that the data for the first sns.distplot call is stored at the first index of lines2D and the data for the second call at the second index. I'm not really sure why this happens, though it is most likely because every call draws onto the same Axes, so the lines accumulate in plotting order. If you have more than two plots, you can access the data of each sns.distplot call by indexing lines2D at its respective position.
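For the more direct "everything into a dictionary" route asked about in the question, one possible sketch is to record how many lines and patches the Axes holds before each plotting call and slice out whatever was added afterwards. The names `collect_layers` and `fake_distplot` are hypothetical, and a matplotlib histogram plus a normal curve stands in for `sns.distplot` so the sketch runs without seaborn.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np


def collect_layers(ax, draw_calls):
    """Run each drawing callable on `ax` and record the artists it added.

    `draw_calls` maps a variable name to a callable taking the Axes.
    Hypothetical helper, not part of seaborn or matplotlib.
    """
    collected = {}
    for name, draw in draw_calls.items():
        n_lines, n_patches = len(ax.lines), len(ax.patches)
        draw(ax)
        new_lines = list(ax.lines)[n_lines:]    # lines added by this call
        new_bars = list(ax.patches)[n_patches:]  # bars added by this call
        collected[name] = {
            "lines": [line.get_data() for line in new_lines],
            "bar_x": np.array([b.get_x() for b in new_bars]),
            "bar_width": np.array([b.get_width() for b in new_bars]),
            "bar_height": np.array([b.get_height() for b in new_bars]),
        }
    return collected


def fake_distplot(values, bins):
    """Stand-in for sns.distplot: a density histogram plus a normal curve."""
    def draw(ax):
        ax.hist(values, bins=bins, density=True, alpha=0.4)
        grid = np.linspace(values.min() - 1, values.max() + 1, 128)
        mu, sd = values.mean(), values.std()
        ax.plot(grid, np.exp(-0.5 * ((grid - mu) / sd) ** 2) / (sd * np.sqrt(2 * np.pi)))
    return draw


rng = np.random.RandomState(123)
samples = {"X1": rng.normal(size=50), "X2": rng.normal(size=50)}

fig, ax = plt.subplots()
layers = collect_layers(
    ax,
    {"X1": fake_distplot(samples["X1"], 6), "X2": fake_distplot(samples["X2"], 5)},
)
```

With seaborn, each callable would presumably be something like `lambda ax, s=df[var]: sns.distplot(s, ax=ax)`. Each entry of `layers` can then be passed to `pd.DataFrame` piece by piece. The bar arrays and the line arrays have different lengths, so they need separate frames.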

Finally, to verify one can plot each distplot

plt.plot(data[0].x, data[0].y)


Miguel Trejo