One (admittedly not particularly neat) solution is to define your own `FacetGrid`
subclass (based on seaborn's `FacetGrid` source code) whose constructor takes an additional `fig`
argument, so that you can pass it a matplotlib subfigure. For example:
from itertools import product
import warnings
import numpy as np
import matplotlib.pyplot as plt
from seaborn.axisgrid import FacetGrid, Grid
from seaborn._oldcore import categorical_order
from seaborn.utils import _disable_autolayout
class FacetGridWithFigure(FacetGrid):
    """A seaborn ``FacetGrid`` that can draw into a caller-supplied figure.

    The constructor is a copy of ``FacetGrid.__init__`` with one extra
    keyword argument, ``fig``:

    * If ``fig`` is given (for example one of the objects returned by
      ``Figure.subfigures``), every facet axes is created on it.
    * If ``fig`` is ``None``, a new figure of size
      ``(ncol * height * aspect, nrow * height)`` is created with
      ``plt.figure``, as ``FacetGrid`` itself does.

    NOTE(review): this mirrors one seaborn release's ``FacetGrid.__init__``
    and relies on private seaborn helpers (``_oldcore``,
    ``_disable_autolayout``, ``_get_palette``, ...). It may need updating
    when seaborn changes. ``add_legend(legend_out=True)`` on a subfigure is
    also untested here.
    """

    def __init__(
        self, data, *,
        row=None, col=None, hue=None, col_wrap=None,
        sharex=True, sharey=True, height=3, aspect=1, palette=None,
        row_order=None, col_order=None, hue_order=None, hue_kws=None,
        dropna=False, legend_out=True, despine=True,
        margin_titles=False, xlim=None, ylim=None, subplot_kws=None,
        gridspec_kws=None, fig=None,  # additional fig argument
    ):
        # Deliberately skip FacetGrid.__init__, because it always builds its
        # own figure. Run only the initialisation of FacetGrid's parent.
        super(FacetGrid, self).__init__()

        # Determine the hue facet layer information
        hue_var = hue
        if hue is None:
            hue_names = None
        else:
            hue_names = categorical_order(data[hue], hue_order)

        colors = self._get_palette(data, hue, hue_order, palette)

        # Set up the lists of names for the row and column facet variables
        if row is None:
            row_names = []
        else:
            row_names = categorical_order(data[row], row_order)

        if col is None:
            col_names = []
        else:
            col_names = categorical_order(data[col], col_order)

        # Additional dict of kwarg -> list of values for mapping the hue var
        hue_kws = hue_kws if hue_kws is not None else {}

        # Make a boolean mask that is True anywhere there is an NA
        # value in one of the faceting variables, but only if dropna is True
        none_na = np.zeros(len(data), bool)
        if dropna:
            row_na = none_na if row is None else data[row].isnull()
            col_na = none_na if col is None else data[col].isnull()
            hue_na = none_na if hue is None else data[hue].isnull()
            not_na = ~(row_na | col_na | hue_na)
        else:
            not_na = ~none_na

        # Compute the grid shape
        ncol = 1 if col is None else len(col_names)
        nrow = 1 if row is None else len(row_names)
        self._n_facets = ncol * nrow

        self._col_wrap = col_wrap
        if col_wrap is not None:
            if row is not None:
                err = "Cannot use `row` and `col_wrap` together."
                raise ValueError(err)
            ncol = col_wrap
            nrow = int(np.ceil(len(col_names) / col_wrap))
        self._ncol = ncol
        self._nrow = nrow

        # Calculate the base figure size. It can be stretched later by a
        # legend. It is only used when no `fig` is supplied.
        # TODO this doesn't account for axis labels
        figsize = (ncol * height * aspect, nrow * height)

        # Validate some inputs
        if col_wrap is not None:
            margin_titles = False

        # Build the subplot keyword dictionary (copies, so the caller's
        # dicts are never mutated by the sharex/sharey entries added below)
        subplot_kws = {} if subplot_kws is None else subplot_kws.copy()
        gridspec_kws = {} if gridspec_kws is None else gridspec_kws.copy()
        if xlim is not None:
            subplot_kws["xlim"] = xlim
        if ylim is not None:
            subplot_kws["ylim"] = ylim

        # --- Initialize the subplot grid

        # create figure if one not given as argument
        if fig is None:
            with _disable_autolayout():
                fig = plt.figure(figsize=figsize)

        if col_wrap is None:

            kwargs = dict(squeeze=False,
                          sharex=sharex, sharey=sharey,
                          subplot_kw=subplot_kws,
                          gridspec_kw=gridspec_kws)

            axes = fig.subplots(nrow, ncol, **kwargs)

            if col is None and row is None:
                axes_dict = {}
            elif col is None:
                axes_dict = dict(zip(row_names, axes.flat))
            elif row is None:
                axes_dict = dict(zip(col_names, axes.flat))
            else:
                facet_product = product(row_names, col_names)
                axes_dict = dict(zip(facet_product, axes.flat))

        else:

            # If wrapping the col variable we need to make the grid ourselves
            if gridspec_kws:
                warnings.warn("`gridspec_kws` ignored when using `col_wrap`")

            n_axes = len(col_names)
            axes = np.empty(n_axes, object)
            axes[0] = fig.add_subplot(nrow, ncol, 1, **subplot_kws)
            # Every later axes shares its x/y limits with the first one
            if sharex:
                subplot_kws["sharex"] = axes[0]
            if sharey:
                subplot_kws["sharey"] = axes[0]
            for i in range(1, n_axes):
                axes[i] = fig.add_subplot(nrow, ncol, i + 1, **subplot_kws)

            axes_dict = dict(zip(col_names, axes))

        # --- Set up the class attributes

        # Attributes that are part of the public API but accessed through
        # a property so that Sphinx adds them to the auto class doc
        self._figure = fig
        self._axes = axes
        self._axes_dict = axes_dict
        self._legend = None

        # Public attributes that aren't explicitly documented
        # (It's not obvious that having them be public was a good idea)
        self.data = data
        self.row_names = row_names
        self.col_names = col_names
        self.hue_names = hue_names
        self.hue_kws = hue_kws

        # Next the private variables
        self._nrow = nrow
        self._row_var = row
        self._ncol = ncol
        self._col_var = col

        self._margin_titles = margin_titles
        self._margin_titles_texts = []
        self._col_wrap = col_wrap
        self._hue_var = hue_var
        self._colors = colors
        self._legend_out = legend_out
        self._legend_data = {}
        self._x_var = None
        self._y_var = None
        self._sharex = sharex
        self._sharey = sharey
        self._dropna = dropna
        self._not_na = not_na

        # --- Make the axes look good

        self.set_titles()
        self.tight_layout()

        if despine:
            self.despine()

        # Hide tick labels / axis labels on interior axes of shared axes
        if sharex in [True, 'col']:
            for ax in self._not_bottom_axes:
                for label in ax.get_xticklabels():
                    label.set_visible(False)
                ax.xaxis.offsetText.set_visible(False)
                ax.xaxis.label.set_visible(False)

        if sharey in [True, 'row']:
            for ax in self._not_left_axes:
                for label in ax.get_yticklabels():
                    label.set_visible(False)
                ax.yaxis.offsetText.set_visible(False)
                ax.yaxis.label.set_visible(False)

    def tight_layout(self, *args, **kwargs):
        """Apply tight layout when the target figure supports it.

        Matplotlib subfigures have no ``tight_layout`` method, so for them
        this is a no-op that returns ``self``. For an ordinary figure (e.g.
        the one created when ``fig`` is not given), the call and its
        arguments are forwarded to the parent implementation, and its
        return value is returned.
        """
        if not hasattr(self._figure, "tight_layout"):
            # subfigures don't have a tight layout option
            return self
        return super().tight_layout(*args, **kwargs)
You can then use it like this:
import seaborn as sns

# Load the "tips" example dataset plotted by both grids below.
# (The original snippet used `tips` without ever defining it.)
tips = sns.load_dataset("tips")

fig = plt.figure(figsize=(10, 4))
# create two side-by-side subfigures, one per facet grid
subfigs = fig.subfigures(1, 2, wspace=0.07)

# Left subfigure: scatter plots faceted by time (columns) and sex (rows)
g1 = FacetGridWithFigure(tips, col="time", row="sex", fig=subfigs[0])
g1.map(sns.scatterplot, "total_bill", "tip")

# Right subfigure: histograms faceted by party size, wrapped at 3 columns
g2 = FacetGridWithFigure(
    tips, col="size", height=2.5, col_wrap=3, fig=subfigs[1]
)
g2.map(sns.histplot, "total_bill")

plt.show()
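which produces a single figure with the two facet grids side by side (the original screenshot is not reproduced here).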
