
I'm working with the Titanic data and I'm trying to use a combination of pyplot and seaborn to produce some subplots. I've written the following code to create 6 subplots in a 3x2 grid:

import matplotlib.pyplot as plt
import seaborn as sns

# train_df and col_pal are defined elsewhere (not shown)
plt.rcParams['figure.figsize'] = [12, 8]
fig, axes = plt.subplots(nrows=3, ncols=2)
plt.tight_layout()

_ = sns.catplot(x='Pclass', y='Age', data=train_df, kind='box', height=8, palette=col_pal, ax=axes[0, 0])
_ = sns.catplot(x='Embarked', y='Age', data=train_df, kind='box', height=8, palette=col_pal, ax=axes[0, 1])
_ = sns.catplot(x='Sex', y='Age', data=train_df, kind='box', height=8, palette=col_pal, ax=axes[1, 0])
_ = sns.catplot(x='Sex', y='Age', hue='Pclass', data=train_df, kind='box', height=8, palette=col_pal, ax=axes[1, 1])
_ = sns.catplot(x='SibSp', y='Age', data=train_df, kind='box', height=8, palette=col_pal, ax=axes[2, 0])
_ = sns.catplot(x='Parch', y='Age', data=train_df, kind='box', height=8, palette=col_pal, ax=axes[2, 1])
plt.show()

When I run this in my notebook, it successfully creates the desired plot; however, it also prints 6 blank plots afterwards.

How can I suppress these empty plots from printing into my output?

2 Answers


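Unlike axes-level seaborn functions such as boxplot, catplot is a figure-level function: every call creates a new figure of its own (wrapped in the FacetGrid it returns) instead of just drawing on an Axes. Because your boxes are drawn into the ax you pass, those extra figures stay empty, and they are the blank plots you see. To get rid of them, call plt.close() after each catplot call; at that point the figure catplot just created is the current one, so that is the figure that gets closed: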

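import matplotlib.pyplot as plt
import seaborn as sns

# assumed: seaborn's copy of the Titanic data (lowercase column names)
data = sns.load_dataset('titanic')

fig, axes = plt.subplots(nrows=3, ncols=2, figsize=(8, 12))

# right after each catplot call, the figure it created is the current one,
# so plt.close() closes that extra figure
sns.catplot(x='pclass', y='age', data=data, kind='box', ax=axes[0, 0])
plt.close()
sns.catplot(x='embarked', y='age', data=data, kind='box', ax=axes[0, 1])
plt.close()
sns.catplot(x='sex', y='age', data=data, kind='box', ax=axes[1, 0])
plt.close()
sns.catplot(x='sex', y='age', hue='pclass', data=data, kind='box', ax=axes[1, 1])
plt.close()
sns.catplot(x='sibsp', y='age', data=data, kind='box', ax=axes[2, 0])
plt.close()
sns.catplot(x='parch', y='age', data=data, kind='box', ax=axes[2, 1])
plt.close()

fig.tight_layout()  # call after drawing so the axis labels are accounted for
plt.show()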

Out: (image: catplots)
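A quick way to see what is going on is to count the open figures around a single catplot call. The sketch below uses a small made-up DataFrame (`df`, with hypothetical random values) instead of the real Titanic data, and it does not pass `ax`, so it should behave the same across seaborn versions:

```python
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

# Hypothetical stand-in for the Titanic data: random classes and ages
rng = np.random.default_rng(0)
df = pd.DataFrame({
    "pclass": rng.integers(1, 4, size=60),
    "age": rng.uniform(1, 80, size=60),
})

fig, axes = plt.subplots(nrows=3, ncols=2)  # the 3x2 grid you actually want
n_before = len(plt.get_fignums())           # open figures before catplot

g = sns.catplot(x="pclass", y="age", data=df, kind="box")
print(type(g).__name__)                     # FacetGrid
print(g.fig is fig)                         # False: catplot made its own figure
print(len(plt.get_fignums()) - n_before)    # 1 extra figure is now open

plt.close(g.fig)                            # close that figure by its handle
print(len(plt.get_fignums()) - n_before)    # 0 extra figures remain
```

Closing the figure through its handle (`g.fig`) has the same effect as the bare `plt.close()` used above, but it does not depend on which figure happens to be current.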

trsvchn

Assign each of your plots to a variable such as g and use plt.close(g.fig) to close the unwanted extra figures. Alternatively, iterate over all variables of type sns.axisgrid.FacetGrid and close their figures like so:

for p in plots_names:
    plt.close(vars()[p].fig)

The complete snippet below does just that. Note that I'm loading the Titanic dataset using train_df = sns.load_dataset("titanic"), where all column names are lowercase, unlike in your example. I've also removed the palette=col_pal argument, since col_pal is not defined in your snippet.

Plot: (image not available)

Code:

import seaborn as sns
import matplotlib.pyplot as plt

plt.rcParams['figure.figsize'] = [12, 8]
fig, axes = plt.subplots(nrows=3, ncols=2)

train_df = sns.load_dataset("titanic")

g = sns.catplot(x='pclass', y='age', data=train_df, kind='box', height=8, ax=axes[0, 0])
h = sns.catplot(x='embarked', y='age', data=train_df, kind='box', height=8, ax=axes[0, 1])
i = sns.catplot(x='sex', y='age', data=train_df, kind='box', height=8, ax=axes[1, 0])
j = sns.catplot(x='sex', y='age', hue='pclass', data=train_df, kind='box', height=8, ax=axes[1, 1])
k = sns.catplot(x='sibsp', y='age', data=train_df, kind='box', height=8, ax=axes[2, 0])
l = sns.catplot(x='parch', y='age', data=train_df, kind='box', height=8, ax=axes[2, 1])

# find every variable bound to a FacetGrid (the object catplot returns)
# and close its figure, so only the 3x2 grid from plt.subplots is shown
var_dict = vars().copy()
var_keys = var_dict.keys()
plots_names = [x for x in var_keys if isinstance(var_dict[x], sns.axisgrid.FacetGrid)]
for p in plots_names:
    plt.close(vars()[p].fig)

fig.tight_layout()  # after drawing, so the axis labels are accounted for
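A variant of the same idea that avoids scanning vars() is to collect the grids in a list as you create them and close each one's figure. This is a sketch: the random `train_df` below is a hypothetical stand-in for `sns.load_dataset("titanic")` with the same lowercase column names, and, like the code above, it relies on catplot forwarding `ax` to the underlying boxplot. As far as I know, newer seaborn releases instead warn that catplot is figure-level and ignore `ax`; on those versions an axes-level function such as `sns.boxplot` is the usual route.

```python
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

# Hypothetical stand-in for sns.load_dataset("titanic"), same lowercase names
rng = np.random.default_rng(0)
n = 100
train_df = pd.DataFrame({
    "pclass": rng.integers(1, 4, size=n),
    "embarked": rng.choice(["S", "C", "Q"], size=n),
    "sex": rng.choice(["male", "female"], size=n),
    "sibsp": rng.integers(0, 4, size=n),
    "parch": rng.integers(0, 3, size=n),
    "age": rng.uniform(1, 80, size=n),
})

fig, axes = plt.subplots(nrows=3, ncols=2, figsize=(12, 8))

# (x, hue, target axes) for each panel, as in the code above
specs = [
    ("pclass", None, axes[0, 0]),
    ("embarked", None, axes[0, 1]),
    ("sex", None, axes[1, 0]),
    ("sex", "pclass", axes[1, 1]),
    ("sibsp", None, axes[2, 0]),
    ("parch", None, axes[2, 1]),
]

grids = []  # keep a reference to every FacetGrid that catplot returns
for x, hue, ax in specs:
    grids.append(sns.catplot(x=x, y="age", hue=hue, data=train_df, kind="box", ax=ax))

# close the figure each grid created; only `fig` stays open
for g in grids:
    plt.close(g.fig)

fig.tight_layout()
```

Holding the grids in a list means you don't need a separate variable name per plot, and nothing depends on what else happens to live in the namespace.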

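Note that the vars() approach only works if each plot is assigned to its own variable name. If you just append the closing snippet to your original code, where every plot is assigned to _, each assignment overwrites the previous one, so only the last FacetGrid is found and closed; the other empty figures are still displayed, as Code 2 below shows.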

Code 2:

import seaborn as sns
import matplotlib.pyplot as plt

plt.rcParams['figure.figsize'] = [12, 8]
fig, axes = plt.subplots(nrows=3, ncols=2)
plt.tight_layout()

train_df = sns.load_dataset("titanic")

_ = sns.catplot(x='pclass', y='age', data=train_df, kind='box', height=8, ax=axes[0, 0])
_ = sns.catplot(x='embarked', y='age', data=train_df, kind='box', height=8, ax=axes[0, 1])
_ = sns.catplot(x='sex', y='age', data=train_df, kind='box', height=8, ax=axes[1, 0])
_ = sns.catplot(x='sex', y='age', hue='pclass', data=train_df, kind='box', height=8, ax=axes[1, 1])
_ = sns.catplot(x='sibsp', y='age', data=train_df, kind='box', height=8, ax=axes[2, 0])
_ = sns.catplot(x='parch', y='age', data=train_df, kind='box', height=8, ax=axes[2, 1])

# find every variable bound to a FacetGrid (the object catplot returns)
# and close its figure; here only `_` holds one, so only the last grid is closed
var_dict = vars().copy()
var_keys = var_dict.keys()
plots_names = [x for x in var_keys if isinstance(var_dict[x], sns.axisgrid.FacetGrid)]
for p in plots_names:
    plt.close(vars()[p].fig)

Plot 2: (image not available)
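If you don't need catplot specifically, a simpler route (not the one used in the answers above) is to call the axes-level `sns.boxplot` directly: it draws into the `ax` you pass and opens no extra figure, so there is nothing to close. A minimal sketch, again with a hypothetical random `train_df` in place of the Titanic data and a loop over `axes.flat`:

```python
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

# Hypothetical stand-in for sns.load_dataset("titanic"), same lowercase names
rng = np.random.default_rng(0)
n = 100
train_df = pd.DataFrame({
    "pclass": rng.integers(1, 4, size=n),
    "embarked": rng.choice(["S", "C", "Q"], size=n),
    "sex": rng.choice(["male", "female"], size=n),
    "sibsp": rng.integers(0, 4, size=n),
    "parch": rng.integers(0, 3, size=n),
    "age": rng.uniform(1, 80, size=n),
})

fig, axes = plt.subplots(nrows=3, ncols=2, figsize=(12, 8))

# (x, hue) for each panel, in the same row-major order as axes.flat
panels = [
    ("pclass", None),
    ("embarked", None),
    ("sex", None),
    ("sex", "pclass"),
    ("sibsp", None),
    ("parch", None),
]

# sns.boxplot is axes-level: it draws into `ax` and creates no new figure
for ax, (x, hue) in zip(axes.flat, panels):
    sns.boxplot(x=x, y="age", hue=hue, data=train_df, ax=ax)

fig.tight_layout()  # after drawing, so the axis labels are accounted for
```

In a notebook the single figure is displayed at the end of the cell; in a script, finish with plt.show() as in the original code.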

vestland