The post "Get data points from Seaborn distplot" describes how you can get data elements using sns.distplot(x).get_lines()[0].get_data(), sns.distplot(x).patches and [h.get_height() for h in sns.distplot(x).patches].
But how can you do this if you've used multiple layers by plotting the data in a loop, such as:
Snippet 1
for var in list(df):
    print(var)
    distplot = sns.distplot(df[var])
Plot
Is there a way to retrieve the X and Y values for both the line charts and the bars?
Here's the whole setup for an easy copy&paste:
#%%
# imports
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)
import pylab
pylab.rcParams['figure.figsize'] = (8, 4)
import seaborn as sns
from collections import OrderedDict
# Function to build synthetic data
def sample(rSeed, periodLength, colNames):
    np.random.seed(rSeed)
    date = pd.to_datetime("1st of Dec, 1999")
    cols = OrderedDict()
    for col in colNames:
        cols[col] = np.random.normal(loc=0.0, scale=1.0, size=periodLength)
    dates = date + pd.to_timedelta(np.arange(periodLength), 'D')
    df = pd.DataFrame(cols, index=dates)
    return df
# Dataframe with synthetic data
df = sample(rSeed = 123, colNames = ['X1', 'X2'], periodLength = 50)
# sns.distplot with multiple layers
for var in list(df):
    myPlot = sns.distplot(df[var])
Here's what I've tried:
Y-values for histogram:
If I run:
barX = [h.get_height() for h in myPlot.patches]
Then I get the following list of length 11:
[0.046234272703757885,
0.1387028181112736,
0.346757045278184,
0.25428849987066837,
0.2542884998706682,
0.11558568175939472,
0.11875881712519201,
0.3087729245254993,
0.3087729245254993,
0.28502116110046083,
0.1662623439752689]
And this seems reasonable, since there seem to be 6 values for the blue bars and 5 values for the red bars. But how do I tell which values belong to which variable?
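One idea I have not verified against distplot itself: every call in the loop appends its bars to the same Axes, so recording the number of patches before each call should give one slice of patches per variable. Below is a minimal sketch of that bookkeeping. It uses plain ax.hist as a stand-in for sns.distplot (an assumption, chosen so the sketch runs with only numpy and matplotlib), and the names data and bars are hypothetical.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np

# Synthetic data standing in for the two DataFrame columns (hypothetical values)
rng = np.random.RandomState(123)
data = {"X1": rng.normal(size=50), "X2": rng.normal(size=50)}

fig, ax = plt.subplots()
bars = {}
for name, values in data.items():
    start = len(ax.patches)  # number of patches already on the Axes
    # Stand-in for sns.distplot(df[var]); assumed to add its bars to the same Axes
    ax.hist(values, bins=6, density=True, alpha=0.4)
    new_patches = ax.patches[start:]  # only the bars added by this call
    # Keep left edge, width and height so X and Y stay together per variable
    bars[name] = [(p.get_x(), p.get_width(), p.get_height()) for p in new_patches]
```

With this, bars["X1"] and bars["X2"] each hold only their own rectangles, so the heights no longer end up in one flat list. The same before-and-after counting could presumably be applied to ax.lines for the curves.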
Y-values for line:
This seems a bit easier than the histogram part, since you can use myPlot.get_lines()[0].get_data() and myPlot.get_lines()[1].get_data() to get:
myPlot.get_lines()[0].get_data()
Out[678]:
(array([-4.54448949, -4.47612134, -4.40775319, -4.33938504, -4.27101689,
...
3.65968859, 3.72805675, 3.7964249 , 3.86479305, 3.9331612 ,
4.00152935, 4.0698975 , 4.13826565]),
array([0.00042479, 0.00042363, 0.000473 , 0.00057404, 0.00073097,
0.00095075, 0.00124272, 0.00161819, 0.00208994, 0.00267162,
...
0.0033384 , 0.00252219, 0.00188591, 0.00139919, 0.00103544,
0.00077219, 0.00059125, 0.00047871]))
myPlot.get_lines()[1].get_data()
Out[679]:
(array([-3.68337423, -3.6256517 , -3.56792917, -3.51020664, -3.4524841 ,
-3.39476157, -3.33703904, -3.27931651, -3.22159398, -3.16387145,
...
3.24332952, 3.30105205, 3.35877458, 3.41649711, 3.47421965,
3.53194218, 3.58966471, 3.64738724]),
array([0.00035842, 0.00038018, 0.00044152, 0.00054508, 0.00069579,
0.00090076, 0.00116922, 0.00151242, 0.0019436 , 0.00247792,
...
0.00215912, 0.00163627, 0.00123281, 0.00092711, 0.00070127,
0.00054097, 0.00043517, 0.00037599]))
But the whole thing still seems a bit cumbersome. So does anyone know of a more direct approach to perhaps retrieve all data to a dictionary or dataframe?
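One direction I have considered is to skip the plot objects and compute the values directly into a dictionary. The sketch below is only an approximation of that idea: it assumes a density-normalised histogram with Freedman-Diaconis bins and a Gaussian kernel density estimate with a Scott's-rule bandwidth, which may not reproduce exactly what distplot draws. The helper names hist_values and kde_values, and the padding of three bandwidths around the data, are hypothetical choices of mine.

```python
import numpy as np

def hist_values(values, bins="fd"):
    """Bar geometry of a density-normalised histogram (bin rule is an assumption)."""
    heights, edges = np.histogram(values, bins=bins, density=True)
    return {"left": edges[:-1], "width": np.diff(edges), "height": heights}

def kde_values(values, gridsize=128):
    """Gaussian KDE on a regular grid, bandwidth from Scott's rule (an assumption)."""
    values = np.asarray(values, dtype=float)
    h = values.std(ddof=1) * len(values) ** (-1 / 5)  # Scott's rule for 1-D data
    grid = np.linspace(values.min() - 3 * h, values.max() + 3 * h, gridsize)
    z = (grid[:, None] - values[None, :]) / h  # scaled distance to each sample
    density = np.exp(-0.5 * z ** 2).sum(axis=1) / (len(values) * h * np.sqrt(2 * np.pi))
    return {"x": grid, "y": density}

# Synthetic data standing in for the two DataFrame columns (hypothetical values)
rng = np.random.RandomState(123)
data = {"X1": rng.normal(size=50), "X2": rng.normal(size=50)}

# One entry per variable, holding both the bar data and the curve data
result = {name: {"hist": hist_values(v), "kde": kde_values(v)}
          for name, v in data.items()}
```

Here result["X1"]["hist"]["height"] would play the role of the bar heights and result["X1"]["kde"] the role of get_data() for the first line, but I would still prefer a way to read the exact plotted values back from the figure.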