
I have the code below with randomly generated dataframes, and I would like to extract the x and y values of both plotted lines. These line plots show the Price on the x-axis and are volume-weighted.

For some reason, the line values of the second distribution plot cannot be stored in the variables "df_2_x", "df_2_y". Instead, those variables receive the same values as "df_1_x", "df_1_y": both print statements return True, so the arrays are completely equal.

If I put them in separate cells in a notebook, it does work.

I also looked at this solution: How to retrieve all data from seaborn distribution plot with mutliple distributions?

But this does not work for weighted distplots.

import pandas as pd
import random
import seaborn as sns
import matplotlib.pyplot as plt

Price_1 = [round(random.uniform(2, 12), 2) for i in range(30)]
Volume_1 = [round(random.uniform(100, 3000)) for i in range(30)]
Price_2 = [round(random.uniform(0, 10), 2) for i in range(30)]
Volume_2 = [round(random.uniform(100, 1500)) for i in range(30)]

df_1 = pd.DataFrame({'Price_1': Price_1,
                     'Volume_1': Volume_1})
df_2 = pd.DataFrame({'Price_2': Price_2,
                     'Volume_2': Volume_2})

df_1_x, df_1_y = sns.distplot(df_1.Price_1, hist_kws={"weights":list(df_1.Volume_1)}).get_lines()[0].get_data()
df_2_x, df_2_y = sns.distplot(df_2.Price_2, hist_kws={"weights":list(df_2.Volume_2)}).get_lines()[0].get_data()

print((df_1_x == df_2_x).all())
print((df_1_y == df_2_y).all())

Why does this happen, and how can I fix this?

JohanC

1 Answer


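Whether or not weights are used makes no difference here. The principal problem is that you are extracting the first curve again in df_2_x, df_2_y = sns.distplot(df_2....).get_lines()[0].get_data(). Both sns.distplot calls draw on the same (current) Axes, so by the time the second call returns, that Axes holds both KDE curves and index 0 is still the curve from df_1. You want the second curve instead: df_2_x, df_2_y = sns.distplot(df_2....).get_lines()[1].get_data().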
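The effect can be reproduced without seaborn. This is a minimal sketch assuming only matplotlib and NumPy: plot_curve is a hypothetical stand-in for an axes-level function like sns.distplot (it draws on the current Axes and returns that Axes rather than the line), and the two bell-shaped curves are made-up data standing in for the KDE lines.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend, so no display is needed
import matplotlib.pyplot as plt
import numpy as np


def plot_curve(x, y, ax=None):
    """Hypothetical stand-in for an axes-level function such as sns.distplot:
    it draws on `ax` (default: the current Axes) and returns the Axes, not the line."""
    if ax is None:
        ax = plt.gca()
    ax.plot(x, y)
    return ax


x = np.linspace(0, 10, 100)
y1 = np.exp(-((x - 4) ** 2))  # made-up curve standing in for the first KDE
y2 = np.exp(-((x - 6) ** 2))  # made-up curve standing in for the second KDE

# Chained as in the question: both calls return the same Axes, so [0] is the first curve twice
plt.figure()  # fresh figure, so the current Axes starts empty
a_x, a_y = plot_curve(x, y1).get_lines()[0].get_data()
b_x, b_y = plot_curve(x, y2).get_lines()[0].get_data()
print(np.array_equal(a_y, b_y))  # True

# Chained as in the fix: the second call's curve is the second line on the shared Axes
plt.figure()
c_x, c_y = plot_curve(x, y1).get_lines()[0].get_data()
d_x, d_y = plot_curve(x, y2).get_lines()[1].get_data()
print(np.array_equal(c_y, d_y))  # False
```

The first comparison mirrors the question's True/True output; after switching to index 1 the two extracted curves differ.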

Note that seaborn isn't really meant for chaining commands like this. Sometimes it works, but it usually adds a lot of confusion. For example, sns.distplot returns an ax (the Axes object that represents a subplot), not the line it just drew. Graphical elements such as lines are added to that ax, so they accumulate across every call that draws on it.
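Instead of chaining, one option is to keep an explicit reference to the Axes and record which lines each call adds. A minimal sketch, assuming matplotlib and NumPy: lines_added_by is a hypothetical helper (not part of seaborn or matplotlib), and the sine and cosine curves are placeholders for the KDE lines.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend, so no display is needed
import matplotlib.pyplot as plt
import numpy as np


def lines_added_by(plot_call, ax):
    """Hypothetical helper: run plot_call(ax) and return only the Line2D
    objects that this call added to ax (earlier curves are skipped)."""
    n_before = len(ax.get_lines())  # snapshot the line count before drawing
    plot_call(ax)
    return ax.get_lines()[n_before:]


fig, ax = plt.subplots()  # explicit Axes reference instead of chained calls
x = np.linspace(0, 10, 100)

# each call receives ax explicitly; placeholder curves stand in for the KDEs
first = lines_added_by(lambda a: a.plot(x, np.sin(x)), ax)
second = lines_added_by(lambda a: a.plot(x, np.cos(x)), ax)

df_1_x, df_1_y = first[0].get_data()
df_2_x, df_2_y = second[0].get_data()
```

With seaborn, the call passed in could be, for example, lambda a: sns.kdeplot(x=df_1.Price_1, weights=df_1.Volume_1, ax=a); any axes-level function that accepts ax= fits the same pattern, and the result no longer depends on how many lines were already on the Axes.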

Also note that sns.distplot has been deprecated since seaborn 0.11 and will be removed in a future version. Its replacements are sns.histplot and sns.kdeplot (or the figure-level sns.displot).

Here is how the code could look:

import pandas as pd
import random
import seaborn as sns
import matplotlib.pyplot as plt

Price_1 = [round(random.uniform(2, 12), 2) for i in range(30)]
Volume_1 = [round(random.uniform(100, 3000)) for i in range(30)]
Price_2 = [round(random.uniform(0, 10), 2) for i in range(30)]
Volume_2 = [round(random.uniform(100, 1500)) for i in range(30)]

df_1 = pd.DataFrame({'Price_1': Price_1,
                     'Volume_1': Volume_1})
df_2 = pd.DataFrame({'Price_2': Price_2,
                     'Volume_2': Volume_2})

# draw both weighted histograms, each with its KDE curve, on the same Axes
ax = sns.histplot(x=df_1.Price_1, weights=list(df_1.Volume_1), bins=10, kde=True, kde_kws={'cut': 3})
sns.histplot(x=df_2.Price_2, weights=list(df_2.Volume_2), bins=10, kde=True, kde_kws={'cut': 3}, ax=ax)

# each kde=True call added one line to ax, in drawing order
df_1_x, df_1_y = ax.lines[0].get_data()
df_2_x, df_2_y = ax.lines[1].get_data()

# use fill_between to demonstrate where the extracted curves lie
ax.fill_between(df_1_x, 0, df_1_y, color='b', alpha=0.2)
ax.fill_between(df_2_x, 0, df_2_y, color='r', alpha=0.2)
plt.show()

[Figure: extracting curves from sns.histplot]

JohanC