1

I want to create a single figure with two different joint plots placed side by side, so I used `plt.subplots(1, 2)`.

However, the result has two problems:

  1. Two unnecessary blank plots appear at the top for an unknown reason, and I want to remove them.
  2. The joint plots are stacked vertically instead of being placed side by side horizontally.

How can I modify my code to fix these two problems? Thanks in advance!

import seaborn as sns
import numpy as np
# `plt` is used below but was never imported, so the snippet raised NameError.
import matplotlib.pyplot as plt

sns.set(style="darkgrid")
iris = sns.load_dataset("iris")

# NOTE: this is the code from the question and still reproduces the layout
# problem it describes. jointplot is a figure-level function: each call
# creates a figure of its own instead of drawing into the Axes passed as
# `ax` (presumably seaborn does not use that argument for placement; recent
# versions warn that it is ignored -- confirm for your version). The two Axes
# from plt.subplots(1, 2) therefore stay empty (the "blank plots"), and the
# two joint plots end up in two separate figures (in a notebook these render
# one below the other, which looks like vertical stacking).
fig, axes = plt.subplots(1, 2)
g = sns.jointplot(ax = axes[0], x="sepal_width", y="sepal_length", data=iris, kind="reg", color='k')
# Clear what jointplot drew in the joint axes and redraw it as a scatter whose
# marker size encodes petal_length.
g.ax_joint.cla()
sns.scatterplot(data=iris, x='sepal_width', y='sepal_length', size='petal_length', sizes=(10, 200), ax=g.ax_joint)

g = sns.jointplot(ax = axes[1], x="sepal_width", y="sepal_length", data=iris, kind="reg", color='k')
# Same as above, but with marker size encoding petal_width.
g.ax_joint.cla()
sns.scatterplot(data=iris, x='sepal_width', y='sepal_length', size='petal_width', sizes=(10, 200), ax=g.ax_joint)

enter image description here

H42
  • 725
  • 2
  • 9
  • 28
  • 1
    `jointplot` is a [figure-level function](https://seaborn.pydata.org/tutorial/function_overview.html#figure-level-vs-axes-level-functions) so putting them in subplots is very hacky: [How to plot multiple seaborn jointplot in subplots](https://stackoverflow.com/q/35042255/13138364) – tdy Nov 04 '21 at 02:36

1 Answer

2

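Starting from your question, I first created each graph as a separate joint plot. I then researched how to combine joint plots into subplots and found an inspiring answer here, which I applied. The idea is to build each seaborn `JointGrid` in its own figure and then move its axes into one cell of a shared `GridSpec`. This removes the blank plots and places the two joint plots side by side. Thank you, @ImportanceOfBeingErnest!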

import seaborn as sns
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec

class SeabornFig2Grid:
    """Relocate the axes of an already-drawn seaborn grid into a cell of another figure.

    Seaborn grid objects (FacetGrid, PairGrid, JointGrid) draw into a figure
    of their own. This helper moves every Axes of such a grid into
    ``subplot_spec`` of ``fig``. It then closes the grid's original figure, so
    several grids can be combined side by side in one figure.

    Parameters
    ----------
    seaborngrid : FacetGrid, PairGrid or JointGrid
        The seaborn grid whose axes should be moved.
    fig : matplotlib.figure.Figure
        The target figure.
    subplot_spec : matplotlib.gridspec.SubplotSpec
        The cell of ``fig`` (for example ``GridSpec(1, 2)[0]``) the grid
        should occupy.

    Raises
    ------
    TypeError
        If ``seaborngrid`` is not a FacetGrid, PairGrid or JointGrid.

    Notes
    -----
    Keep a reference to each instance. The target figure's ``resize_event``
    is connected to the bound method ``self._resize``, and matplotlib holds
    bound-method callbacks only weakly.
    """

    def __init__(self, seaborngrid, fig, subplot_spec):
        self.fig = fig
        self.sg = seaborngrid
        self.subplot = subplot_spec
        if isinstance(self.sg, (sns.axisgrid.FacetGrid, sns.axisgrid.PairGrid)):
            self._movegrid()
        elif isinstance(self.sg, sns.axisgrid.JointGrid):
            self._movejointgrid()
        else:
            # Fail loudly: with no matching branch nothing would be moved,
            # yet _finalize() would still close the grid's figure and the
            # plot would silently disappear.
            raise TypeError(
                "SeabornFig2Grid expects a seaborn FacetGrid, PairGrid or "
                f"JointGrid, got {type(seaborngrid).__name__}"
            )
        self._finalize()

    def _sgfig(self):
        """Return the seaborn grid's own figure.

        Prefers the ``figure`` attribute (the name current seaborn documents)
        and falls back to the older ``fig`` alias for older seaborn versions.
        """
        figure = getattr(self.sg, "figure", None)
        return figure if figure is not None else self.sg.fig

    def _movegrid(self):
        """Move a PairGrid or FacetGrid, keeping its n x m axes layout."""
        self._resize()
        n = self.sg.axes.shape[0]
        m = self.sg.axes.shape[1]
        # Subdivide the target cell into the same n x m layout as the grid.
        self.subgrid = gridspec.GridSpecFromSubplotSpec(n, m, subplot_spec=self.subplot)
        for i in range(n):
            for j in range(m):
                self._moveaxes(self.sg.axes[i, j], self.subgrid[i, j])

    def _movejointgrid(self):
        """Move a JointGrid: the joint axes plus both marginal axes."""
        # Recover the joint/marginal height ratio from the current layout so
        # the relocated grid keeps the same proportions.
        h = self.sg.ax_joint.get_position().height
        h2 = self.sg.ax_marg_x.get_position().height
        # Builtin round() rounds half to even (like np.round) and already
        # returns an int, so no numpy dependency is needed. Clamp to 1 so the
        # slices below are never empty.
        r = max(1, round(h / h2))
        self._resize()
        self.subgrid = gridspec.GridSpecFromSubplotSpec(r + 1, r + 1, subplot_spec=self.subplot)

        # Layout: x-marginal on the top row, y-marginal in the last column,
        # joint axes filling the remaining r x r cells.
        self._moveaxes(self.sg.ax_joint, self.subgrid[1:, :-1])
        self._moveaxes(self.sg.ax_marg_x, self.subgrid[0, :-1])
        self._moveaxes(self.sg.ax_marg_y, self.subgrid[1:, -1])

    def _moveaxes(self, ax, gs):
        """Detach *ax* from its current figure and place it at *gs* in ``self.fig``."""
        # https://stackoverflow.com/a/46906599/4124317
        ax.remove()
        # NOTE(review): rebinding ``ax.figure`` relies on matplotlib
        # internals; newer matplotlib releases may refuse to move an artist
        # between figures this way -- confirm against the installed version.
        ax.figure = self.fig
        # add_axes() is what registers the axes with the figure. ``fig.axes``
        # returns a copy of the axes list, so appending to it has no effect.
        self.fig.add_axes(ax)
        ax._subplotspec = gs
        ax.set_position(gs.get_position(self.fig))
        ax.set_subplotspec(gs)

    def _finalize(self):
        """Close the grid's original figure, keep sizes in sync, and redraw."""
        plt.close(self._sgfig())
        self.fig.canvas.mpl_connect("resize_event", self._resize)
        self.fig.canvas.draw()

    def _resize(self, evt=None):
        """Give the seaborn figure the target figure's size.

        Called once before the axes are moved and afterwards as the target
        figure's ``resize_event`` handler (``evt`` is unused).
        """
        self._sgfig().set_size_inches(self.fig.get_size_inches())
        
# set_theme() is seaborn's preferred interface; sns.set is kept only as an alias.
sns.set_theme(style="darkgrid")
iris = sns.load_dataset("iris")


def make_sized_jointgrid(data, size):
    """Return a sepal_width-vs-sepal_length JointGrid drawn from *data*.

    The joint axes show a scatter plot whose marker size encodes the column
    named *size*. The marginals show histograms with a KDE overlay.
    """
    grid = sns.JointGrid(x="sepal_width", y="sepal_length", data=data)
    grid.plot_joint(sns.scatterplot, sizes=(10, 200), size=data[size], legend='brief')
    grid.plot_marginals(sns.histplot, kde=True, color='k')
    return grid


# Each seaborn grid is first drawn in a figure of its own...
g0 = make_sized_jointgrid(iris, 'petal_length')
g1 = make_sized_jointgrid(iris, 'petal_width')

# ...then its axes are moved into one cell of a 1x2 GridSpec, which places
# the two joint plots side by side in a single figure.
fig = plt.figure(figsize=(13, 8))
gs = gridspec.GridSpec(1, 2)

# Keep references to these objects: each connects its bound `_resize` method
# to the figure's resize_event, and matplotlib holds bound-method callbacks
# only weakly.
mg0 = SeabornFig2Grid(g0, fig, gs[0])
mg1 = SeabornFig2Grid(g1, fig, gs[1])

gs.tight_layout(fig)

plt.show()

enter image description here

r-beginners
  • 31,170
  • 3
  • 14
  • 32
  • It was marked as a duplicate, but I can accept Tawashi's answer. Please give your vote to ImportanceOfBeingErnest, whose answer is what fundamentally solved your issue. – r-beginners Nov 08 '21 at 03:49