0

This is related to this question

I'm trying to add in the color bar legends to a seaborn jointplot to the overal SeabornFig2Grid. Here is the code and output, as you can see the color bars are overlapping the 2d histograms, and the right label of the color bar is cut off.

I've tried using plt.subplots_adjust, or ax.ax_joint.set_position / ax.set_position, but with no luck.

I've updated SeabornFig2Grid from the original question to include the method _moveaxis so that the color bar could be added.

here is the code

import matplotlib
import seaborn as sns
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
import matplotlib.gridspec as gridspec

class SeabornFig2Grid():
    """Allow seaborn figure-level figs to be suplots.
    
    thanks - https://stackoverflow.com/questions/35042255/how-to-plot-multiple-seaborn-jointplot-in-subplot/47664533#47664533
    """
    
    def __init__(self, seaborngrid, fig,  subplot_spec):
        
        self.fig = fig
        self.sg = seaborngrid
        self.subplot = subplot_spec
        if isinstance(self.sg, sns.axisgrid.FacetGrid) or \
            isinstance(self.sg, sns.axisgrid.PairGrid):
            self._movegrid()
        elif isinstance(self.sg, sns.axisgrid.JointGrid):
            self._movejointgrid()
        elif isinstance(self.sg, matplotlib.axes._axes.Axes):
            self._moveaxis()
        self._finalize()
        pass
    
    def _moveaxis(self):
        self._resize()
        self.subgrid = gridspec.GridSpecFromSubplotSpec(1, 1, subplot_spec=self.subplot)
        self._moveaxes(self.sg, self.subgrid[0, 0])

    def _movegrid(self):
        """Move PairGrid or Facetgrid."""
        self._resize()
        n = self.sg.axes.shape[0]
        m = self.sg.axes.shape[1]
        self.subgrid = gridspec.GridSpecFromSubplotSpec(n, m, subplot_spec=self.subplot)
        for i in range(n):
            for j in range(m):
                self._moveaxes(self.sg.axes[i, j], self.subgrid[i, j])
        pass

    def _movejointgrid(self):
        """Move Jointgrid."""
        h = self.sg.ax_joint.get_position().height
        h2 = self.sg.ax_marg_x.get_position().height
        r = int(np.round(h / h2))
        self._resize()
        self.subgrid = gridspec.GridSpecFromSubplotSpec(r + 1, r + 1, subplot_spec=self.subplot)

        self._moveaxes(self.sg.ax_joint, self.subgrid[1:, :-1])
        self._moveaxes(self.sg.ax_marg_x, self.subgrid[0, :-1])
        self._moveaxes(self.sg.ax_marg_y, self.subgrid[1:, -1])
        pass

    def _moveaxes(self, ax, grid_spec):

        ax.remove()
        ax.figure = self.fig
        self.fig.axes.append(ax)
        self.fig.add_axes(ax)
        ax._subplotspec = grid_spec
        ax.set_position(grid_spec.get_position(self.fig))
        try:
            ax.set_subplotspec(grid_spec)
        except AttributeError:
            ax._subplotspec = grid_spec
        pass

    def _finalize(self):
        try:
            plt.close(self.sg.fig)
        except AttributeError:
            pass
        self.fig.canvas.mpl_connect("resize_event", self._resize)
        self.fig.canvas.draw()
        pass

    def _resize(self, evt=None):
        self.sg.figure.set_size_inches(self.fig.get_size_inches())
        pass
    
    pass


# create data
x = [np.random.random() for x in range(1000)]
y = [np.random.random() for x in range(1000)]

# add 2d-hist and color bars to `axes` list
axes = []
for i in range(2):
    # create 2d histogram
    ax1 = sns.jointplot(x=x, y=y, marginal_kws={'bins' : 20})
    ax1.ax_joint.cla()
    ax1.fig.set_size_inches((5, 4))
    plt.sca(ax1.ax_joint)
    plt.hist2d(x, y, bins=20, norm=mcolors.LogNorm(*(None, None)), cmap='jet')

    # set up scale bar legend
    cbar_ax = ax1.fig.add_axes([1, 0.1, 0.03, 0.7])
    cb = plt.colorbar(cax=cbar_ax)
    cb.set_label(r"$\log_{10}$ density of points", fontsize=13)

    axes.extend([ax1, cbar_ax])

# adjust gridspec so ratios are more visually appealing
fig = plt.figure(figsize=(7, 3))
gs = gridspec.GridSpec(1, len(axes), width_ratios=[1, 0.1] * int(len(axes) / 2))

# add axes to SeabornFig2Grid
for ax, g in zip(axes, gs):
    _ = SeabornFig2Grid(ax, fig, g)
    
plt.show() 

enter image description here

BML
  • 191
  • 2
  • 12
  • What is the `h` in `_moveaxis` supposed to be used for? – jared Jul 26 '23 at 16:41
  • sorry @jared, that was a line of code that i forgot to remove. I've removed it. (It was related to the `_movejointgrid` when I was building `_moveaxis`, but I didn't need it) – BML Jul 26 '23 at 16:43
  • Can you get this to work with just one plot in a figure? If so, you could use subfigures: https://matplotlib.org/stable/gallery/subplots_axes_and_figures/subfigures.html – Jody Klymak Jul 26 '23 at 17:11
  • @JodyKlymak no unfortunately this won't work with a seaborn figure. You can't pass an axes subplot to sns.jointplot. – BML Jul 26 '23 at 18:36
  • Can you pass a figure? – Jody Klymak Jul 26 '23 at 23:31
  • @JodyKlymak not that I can tell from documentation. https://seaborn.pydata.org/generated/seaborn.jointplot.html – BML Jul 27 '23 at 20:50
  • This seems like a pretty straightforward thing for seaborn to add – Jody Klymak Jul 30 '23 at 14:33
  • I also posted the answer here(https://stackoverflow.com/a/70816718/18004774), but I believe the issue can be resolved by using the patchworklib I developed previously. How about that? – Hideto Aug 02 '23 at 01:47

0 Answers0