How to scale marginal kdeplot of seaborn jointplot?

Suppose we have 1000 data points of kind 'a', 100 of kind 'b', and 100 of kind 'c'.

In this case, the marginal KDE curves are drawn at very different scales because the categories differ so much in size.

How can I make the marginal curves the same scale?

I made a toy script, shown below:

import seaborn as sns
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

ax, ay = 1 * np.random.randn(1000) + 2, 1 * np.random.randn(1000) + 2
bx, by = 1 * np.random.randn(100) + 3, 1 * np.random.randn(100) + 3
cx, cy = 1 * np.random.randn(100) + 4, 1 * np.random.randn(100) + 4

a = [{'x': x, 'y': y, 'kind': 'a'} for x, y in zip(ax, ay)]
b = [{'x': x, 'y': y, 'kind': 'b'} for x, y in zip(bx, by)]
c = [{'x': x, 'y': y, 'kind': 'c'} for x, y in zip(cx, cy)]

df = pd.concat([pd.DataFrame.from_dict(a), pd.DataFrame.from_dict(b), pd.DataFrame.from_dict(c)], ignore_index=True)

print(df)
             x         y kind
0     2.500866  2.700925    a
1    -0.386057  3.322318    a
2     1.691078  2.558366    a
3     2.235042 -0.113836    a
4     3.331039  1.138366    a
...        ...       ...  ...
1195  3.703245  2.935332    c
1196  1.806040  2.842754    c
1197  5.431313  5.377297    c
1198  3.873162  6.200356    c
1199  4.111234  3.038126    c

[1200 rows x 3 columns]

sns.jointplot(data=df, x='x', y='y', hue="kind")
plt.show()

Hyunseung Kim

2 Answers

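You can use marginal_kws= to pass keyword arguments to the marginal plots. When hue is set, the marginals are drawn with sns.kdeplot, which has parameters such as common_norm and multiple. By default common_norm=True, which scales each density by its number of observations so that the total area under all curves sums to 1; with common_norm=False, each density is normalized independently.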

import seaborn as sns
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

ax, ay = 1 * np.random.randn(1000) + 2, 1 * np.random.randn(1000) + 2
bx, by = 1 * np.random.randn(100) + 3, 1 * np.random.randn(100) + 3
cx, cy = 1 * np.random.randn(100) + 4, 1 * np.random.randn(100) + 4

a = [{'x': x, 'y': y, 'kind': 'a'} for x, y in zip(ax, ay)]
b = [{'x': x, 'y': y, 'kind': 'b'} for x, y in zip(bx, by)]
c = [{'x': x, 'y': y, 'kind': 'c'} for x, y in zip(cx, cy)]

df = pd.concat([pd.DataFrame.from_dict(a), pd.DataFrame.from_dict(b), pd.DataFrame.from_dict(c)], ignore_index=True)

sns.jointplot(data=df, x='x', y='y', hue="kind", marginal_kws={'common_norm': False})
plt.show()

[Figure: sns.jointplot with common_norm=False]

JohanC

You can use sns.JointGrid and draw the marginal KDE for each class separately, so that each curve is normalized on its own. This approach gives you complete control over the axes, which is useful in many cases.

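import seaborn as sns
import matplotlib.pyplot as plt

# df is the DataFrame built in the question
g = sns.JointGrid(data=df, x='x', y='y', hue='kind')
g.plot_joint(sns.scatterplot)

# One kdeplot call per kind, so each curve is normalized independently
for k in df.kind.unique():
    data = df[df.kind == k]
    sns.kdeplot(x=data.x, fill=True, label=k, ax=g.ax_marg_x)
    sns.kdeplot(y=data.y, fill=True, label=k, ax=g.ax_marg_y)

plt.show()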

Erik Hulmák