
Is there a way to add additional subplots created with vanilla Matplotlib to (below) a Seaborn jointplot, sharing the x-axis? Ideally I'd like to control the ratio between the jointplot and the additional plots (similar to gridspec_kw={'height_ratios':[3, 1, 1]}).

I tried to fake it by tuning figsize in the Matplotlib subplots, but obviously it doesn't work well when the KDE curves in the marginal plot change. While I could manually resize the output PNG to shrink/grow one of the figures, I'd like to have everything aligned automatically.

I know this is tricky with the way the joint grid is set up, but maybe it is reasonably simple for someone fluent in the underpinnings of Seaborn.

Here is a minimal working example, but it produces two separate figures:

import matplotlib.pyplot as plt
import seaborn as sns
%matplotlib inline

Figure 1

diamonds = sns.load_dataset('diamonds')
g = sns.jointplot(
    data=diamonds,
    x="carat",
    y="price",
    hue="cut",
    xlim=(1, 2),
)
g.ax_marg_x.remove()

[Image: output of Figure 1]

Figure 2

fig, (ax1, ax2) = plt.subplots(2,1,sharex=True)
ax1.scatter(x=diamonds["carat"], y=diamonds["depth"], color="gray", edgecolor="black")
ax1.set_xlim([1, 2])
ax1.set_ylabel("depth")
ax2.scatter(x=diamonds["carat"], y=diamonds["table"], color="gray", edgecolor="black")
ax2.set_xlabel("carat")
ax2.set_ylabel("table")

[Image: output of Figure 2]

Desired output:

[Image: desired output]

a11
  • The short answer is no. See [seaborn.JointGrid](https://seaborn.pydata.org/generated/seaborn.JointGrid.html). You will probably need to build it yourself, using GridSpec and subplots. See [Matplotlib different size subplots](https://stackoverflow.com/q/10388462/7758804) – Trenton McKinney Oct 01 '21 at 00:37

2 Answers


I think this is a case where setting up the figure using matplotlib functions is going to be better than working backwards from a seaborn figure layout that doesn't really match the use-case.

If you have a non-full subplot grid, you'll have to decide whether you want to (A) set up all the subplots and then remove the ones you don't want or (B) explicitly add each of the subplots you do want. Let's go with option A here.

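import matplotlib.pyplot as plt
import seaborn as sns

df = sns.load_dataset("diamonds")  # same data as in the question

figsize = (6, 8)
# 3x2 grid: wide left column for the main plots, narrow right column for the
# marginal KDE; the top row is 4x as tall as each of the two extra rows
gridspec_kw = dict(
    nrows=3, ncols=2,
    width_ratios=[5, 1],
    height_ratios=[4, 1, 1],
)
# share x within each column and y within each row
subplot_kw = dict(sharex="col", sharey="row")
fig = plt.figure(figsize=figsize, constrained_layout=True)
axs = fig.add_gridspec(**gridspec_kw).subplots(**subplot_kw)

# marginal KDE of price, to the right of the joint plot
sns.kdeplot(data=df, y="price", hue="cut", legend=False, ax=axs[0, 1])
# joint scatter plot plus the two extra panels below it
sns.scatterplot(data=df, x="carat", y="price", hue="cut", ax=axs[0, 0])
sns.scatterplot(data=df, x="carat", y="depth", color=".2", ax=axs[1, 0])
sns.scatterplot(data=df, x="carat", y="table", color=".2", ax=axs[2, 0])

# x is shared down the left column, so this limits all three panels
axs[0, 0].set(xlim=(1, 2))

# drop the unused cells of the right column
axs[1, 1].remove()
axs[2, 1].remove()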

[Image: output of the code above]
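For comparison, option (B) could be sketched like this: build the GridSpec once and call `fig.add_subplot()` only for the cells you need, passing `sharex`/`sharey` at creation time. This is a sketch, not the answer's code: the function name `build_figure` is hypothetical, the data are synthetic NumPy stand-ins for the diamonds columns, and a horizontal `ax.hist()` stands in for `sns.kdeplot()` so the example runs with Matplotlib and NumPy alone.

```python
import matplotlib.pyplot as plt
import numpy as np


def build_figure(x, y, extra, xlabel="carat", ylabel="price"):
    """Option (B): create only the subplots that are needed (hypothetical helper).

    x, y   -- data for the main (joint) scatter plot
    extra  -- list of (label, values) pairs; one extra panel per pair
    """
    n = len(extra)
    fig = plt.figure(figsize=(6, 8), constrained_layout=True)
    # main row is 4 units tall, each extra row 1 unit; right column is narrow
    gs = fig.add_gridspec(nrows=1 + n, ncols=2,
                          width_ratios=[5, 1], height_ratios=[4] + [1] * n)

    ax_joint = fig.add_subplot(gs[0, 0])
    ax_marg = fig.add_subplot(gs[0, 1], sharey=ax_joint)  # marginal shares y
    ax_joint.scatter(x, y, s=5)
    ax_joint.set_ylabel(ylabel)
    # stand-in for sns.kdeplot(..., y=..., ax=ax_marg)
    ax_marg.hist(y, bins=30, orientation="horizontal", color="0.6")

    extra_axes = []
    for row, (label, values) in enumerate(extra, start=1):
        ax = fig.add_subplot(gs[row, 0], sharex=ax_joint)  # extras share x
        ax.scatter(x, values, s=5, color="0.2")
        ax.set_ylabel(label)
        extra_axes.append(ax)

    # put the x label on the bottom-most panel
    (extra_axes[-1] if extra_axes else ax_joint).set_xlabel(xlabel)
    ax_joint.set_xlim(1, 2)  # applies to every panel in the left column
    return fig, ax_joint, ax_marg, extra_axes


if __name__ == "__main__":
    # synthetic stand-in for the diamonds columns (not real data)
    rng = np.random.default_rng(0)
    carat = rng.uniform(0.2, 3.0, 500)
    price = 4000 * carat + rng.normal(0, 500, 500)
    depth = rng.normal(62, 1.5, 500)
    table = rng.normal(57, 2.0, 500)
    build_figure(carat, price, [("depth", depth), ("table", table)])
    plt.show()
```

In a real figure you would draw into the returned axes with the Seaborn calls from the answer instead of the histogram; the layout and the sharing stay the same.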

BTW, this would be slightly easier with plt.subplot_mosaic, but it does not yet support axis sharing.
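A sketch of the `plt.subplot_mosaic` variant, assuming Matplotlib 3.3 or newer (for `subplot_mosaic` and `Axes.sharex`/`Axes.sharey`). Because the sharing has to be selective (the extra panels share x with the joint axes, while the marginal shares only y), it is done by hand after the mosaic is created. The function name and the panel labels are arbitrary choices for this example.

```python
import matplotlib.pyplot as plt


def mosaic_layout():
    """Joint panel, right marginal and two extra panels via subplot_mosaic."""
    fig, axd = plt.subplot_mosaic(
        [["joint", "marg"],
         ["depth", "."],   # "." marks an empty cell (the default empty_sentinel)
         ["table", "."]],
        gridspec_kw=dict(width_ratios=[5, 1], height_ratios=[4, 1, 1]),
        figsize=(6, 8),
        constrained_layout=True,
    )
    # link only the pairs that should move together
    axd["depth"].sharex(axd["joint"])
    axd["table"].sharex(axd["joint"])
    axd["marg"].sharey(axd["joint"])
    # sharing after creation does not hide inner tick labels, so do it by hand
    for name in ("joint", "depth"):
        axd[name].tick_params(labelbottom=False)
    return fig, axd
```

You would then plot into `axd["joint"]`, `axd["marg"]` and so on, e.g. `sns.scatterplot(..., ax=axd["joint"])`.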

mwaskom
  • Wouldn't you also need `hue_order` (or similar) to make sure the scatterplot and kdeplot use the same colors for the same hue values? – JohanC Oct 03 '21 at 17:21
  • No, they're both consuming the full vector of hues so they'll have the same default mapping. Where you run into trouble is when different facets get subsets of a dataset, which might have values appearing in different orders. – mwaskom Oct 03 '21 at 17:29

You could take the figure created by jointplot(), adjust its padding with subplots_adjust() to free up space at the bottom, and add two extra axes there.

The example code will need some tweaking for each particular situation.

import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1.inset_locator import inset_axes
import seaborn as sns

diamonds = sns.load_dataset('diamonds')
g = sns.jointplot(data=diamonds, x="carat", y="price", hue="cut",
                  xlim=(1, 2), height=12)
g.ax_marg_x.remove()
# squeeze the jointplot into the upper part of the figure to make room below it
g.fig.subplots_adjust(left=0.08, right=0.97, top=1.05, bottom=0.45)

# two new axes below the joint axes, positioned in ax_joint's axes coordinates:
# each is 30% of its height, with lower edges at -0.4 and -0.75
axins1 = inset_axes(g.ax_joint, width="100%", height="30%",
                    bbox_to_anchor=(0, -0.4, 1, 1),
                    bbox_transform=g.ax_joint.transAxes, loc="lower left", borderpad=0)
axins2 = inset_axes(g.ax_joint, width="100%", height="30%",
                    bbox_to_anchor=(0, -0.75, 1, 1),
                    bbox_transform=g.ax_joint.transAxes, loc="lower left", borderpad=0)
# share the x-axis with the joint axes; Axes.sharex (Matplotlib >= 3.3) replaces the
# get_shared_x_axes().join()/.remove() calls, which are deprecated in recent Matplotlib
axins1.sharex(g.ax_joint)
axins2.sharex(g.ax_joint)

axins1.scatter(x=diamonds["carat"], y=diamonds["depth"], color="grey", edgecolor="black")
axins1.set_ylabel("depth")
axins2.scatter(x=diamonds["carat"], y=diamonds["table"], color="grey", edgecolor="black")
axins2.set_xlabel("carat")
axins2.set_ylabel("table")
g.ax_joint.set_xlim(1, 2)
# hide the x tick labels of the middle panel
plt.setp(axins1.get_xticklabels(), visible=False)
plt.show()

[Image: sns.jointplot with extra subplots]
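The offsets in this answer (lower edges at -0.4 and -0.75, each panel 30% tall) are hand-tuned, which is the tweaking mentioned above. A hypothetical helper, `add_axes_below` (not part of Seaborn or Matplotlib), could compute each inset's `bbox_to_anchor` from a list of height ratios instead, which also gives some control over the ratio between the joint plot and the extra panels, as the question asked. This is a sketch assuming the same `inset_axes` positioning as the answer and Matplotlib 3.3+ for `Axes.sharex`; you still need `subplots_adjust(bottom=...)` to leave enough room below the joint axes.

```python
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1.inset_locator import inset_axes


def add_axes_below(ax, ratios, gap=0.05):
    """Hypothetical helper: stack new axes under `ax`, sharing its x-axis.

    ratios -- height of each new axes as a fraction of ax's height
    gap    -- vertical space between panels, also a fraction of ax's height
    """
    new_axes = []
    top = -gap  # top edge of the next panel, in ax's axes coordinates
    for r in ratios:
        bottom = top - r
        # the inset fills the whole anchor box (0, bottom, 1, r)
        child = inset_axes(ax, width="100%", height="100%",
                           bbox_to_anchor=(0, bottom, 1, r),
                           bbox_transform=ax.transAxes,
                           loc="lower left", borderpad=0)
        child.sharex(ax)  # same x-limits and zoom as the parent
        new_axes.append(child)
        top = bottom - gap
    return new_axes
```

For the jointplot above, `axins1, axins2 = add_axes_below(g.ax_joint, [0.3, 0.3], gap=0.1)` would give two panels each 30% as tall as the joint axes, close to the hand-tuned layout (the answer uses a slightly smaller second gap).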

PS: The question "How to share x axes of two subplots after they have been created" contains some information about sharing axes (although here you get the same effect simply by setting the xlims of each subplot).
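A minimal sketch of the two options from the PS on plain Matplotlib subplots, assuming Matplotlib 3.3 or newer for `Axes.sharex`: linking the axes after creation versus only giving them the same limits. The function names are hypothetical. The two look identical until a limit changes later, e.g. through interactive zooming.

```python
import matplotlib.pyplot as plt


def linked_pair():
    """Two subplots whose x-axes are linked after creation."""
    fig, (ax1, ax2) = plt.subplots(2, 1)  # created without sharex
    ax2.sharex(ax1)  # join ax2 to ax1's shared x group
    ax1.set_xlim(1, 2)  # propagates to ax2
    return fig, ax1, ax2


def same_limits_pair():
    """Two independent subplots that merely get the same limits."""
    fig, (ax1, ax2) = plt.subplots(2, 1)
    for ax in (ax1, ax2):
        ax.set_xlim(1, 2)  # each axes is set separately
    return fig, ax1, ax2
```

After `_, a1, a2 = linked_pair()`, calling `a1.set_xlim(1.2, 1.8)` also moves `a2`; with `same_limits_pair()` the second axes keeps (1, 2).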

The code to position the new axes has been adapted from this tutorial example.

JohanC