
At the moment I'm learning how to work with matplotlib and seaborn, and the concept behind them seems quite strange to me. I would expect the sns.countplot function to return an object that has a .plot() and a .save() function, so that one could work with the plot in a different function. Instead, it seems that every call to sns.countplot overwrites the previous object (see MWE).

So, on the one hand, it would be great if someone could provide an explanation of the matplotlib and seaborn interface (or link some good documentation), since none of the documentation I read was of much help.

On the other hand, I have a function that returns some plots, which I want to save as a .pdf file with one plot per page. I found this similar question but can't adapt its code in a way that makes my MWE work.

from matplotlib.backends.backend_pdf import PdfPages
import seaborn as sns


def generate_plots():

    penguins = sns.load_dataset("penguins")

    countplot_sex = sns.countplot(y='sex', data=penguins)
    countplot_species = sns.countplot(y='species', data=penguins)
    countplot_island = sns.countplot(y='island', data=penguins)

    # As shown by printing them:
    # print(countplot_sex) -> AxesSubplot(0.125,0.11;0.775x0.77)
    # print(countplot_species) -> AxesSubplot(0.125,0.11;0.775x0.77)
    # print(countplot_island) -> AxesSubplot(0.125,0.11;0.775x0.77)
    # All three variables contain the same object

    return(countplot_sex, countplot_species, countplot_island)


def plots2pdf(plots, fname):  # from: https://stackoverflow.com/a/21489936
    pp = PdfPages(fname)

    for plot in plots:
        pass
        # TODO save plot
        # Does not work: plot.savefig(pp, format='pdf')

    pp.savefig()
    pp.close()


def main():
    plots2pdf(generate_plots(), 'multipage.pdf')


if __name__ == '__main__':
    main()

My idea is to have a reasonably decent software architecture, with one function generating the plots and another function saving them.

Someone2

1 Answer


The problem is that by default, sns.countplot will do its plotting on the current matplotlib Axes instance. From the docs:

ax : matplotlib Axes, optional

    Axes object to draw the plot onto, otherwise uses the current Axes.
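To illustrate the quoted behavior, here is a minimal sketch. It assumes a reasonably recent seaborn and matplotlib, and it uses a tiny hypothetical DataFrame instead of downloading the penguins dataset. Without `ax=`, both calls return the very same Axes, the one `plt.gca()` reports as current; passing `ax=` makes `sns.countplot` draw onto, and return, the Axes you supply.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs headless
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

# Hypothetical stand-in for the penguins dataset (avoids a network download)
df = pd.DataFrame({"sex": ["Male", "Female", "Male"],
                   "species": ["Adelie", "Gentoo", "Adelie"]})

# Without ax=..., both calls draw onto the same current Axes
ax1 = sns.countplot(y="sex", data=df)
ax2 = sns.countplot(y="species", data=df)
print(ax1 is ax2, ax1 is plt.gca())  # True True

# With ax=..., the call draws onto the Axes it was given
fig, ax3 = plt.subplots()
ax4 = sns.countplot(y="species", data=df, ax=ax3)
print(ax4 is ax3, ax4 is ax1)  # True False
```

This is why the three variables in the question all print as the same `AxesSubplot`: they are one object that has been drawn on three times.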

One solution would be to define a small function that creates a new figure and Axes instance, then passes that to sns.countplot, to ensure it is plotted on a new figure and does not overwrite the previous one. This is what I have shown in the example below. An alternative would be to just create 3 figures and axes, and then pass each one to the sns.countplot function yourself.
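The alternative mentioned above (creating each figure and Axes yourself and passing it to `sns.countplot` via `ax=`) could look roughly like this. It is a sketch under assumptions: the column list is hard-coded, and the small DataFrame is hypothetical data standing in for `sns.load_dataset("penguins")`.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs headless
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns


def generate_plots(data):
    """Create one figure per column and draw a countplot onto each Axes."""
    axes = []
    for column in ["sex", "species", "island"]:
        fig, ax = plt.subplots()                   # fresh figure + Axes per plot
        sns.countplot(y=column, data=data, ax=ax)  # draw explicitly onto that Axes
        axes.append(ax)
    return tuple(axes)


# Hypothetical stand-in data instead of sns.load_dataset("penguins")
penguins = pd.DataFrame({
    "sex": ["Male", "Female", "Female"],
    "species": ["Adelie", "Gentoo", "Chinstrap"],
    "island": ["Biscoe", "Dream", "Torgersen"],
})
plots = generate_plots(penguins)
print(len({id(ax.figure) for ax in plots}))  # 3 distinct figures
```

Passing `ax=` explicitly (rather than relying on the current Axes that `plt.subplots()` happens to set) makes the dependency visible and keeps working if other code changes the current figure in between.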

Then in your plots2pdf function, you can iterate over the Axes, and pass their figure instances to the PdfPages instance when you save. (Note: Since you create the figures in the generate_plots function, an alternative would be to return the figure instances from that function, then you have them ready to pass into the pp.savefig function, but I did it this way so the output from your function remained the same).
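If you prefer the variant from the note above, where the generating function returns the Figure instances directly, a sketch might look like the following. The names `generate_figures` and `figures2pdf`, the output filename, and the sample DataFrame are all hypothetical; `PdfPages.savefig` and `PdfPages.get_pagecount` are standard matplotlib methods.

```python
import os

import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs headless
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
from matplotlib.backends.backend_pdf import PdfPages


def generate_figures(data, columns):
    """Hypothetical helper: return one Figure per countplot."""
    figures = []
    for column in columns:
        fig, ax = plt.subplots()
        sns.countplot(y=column, data=data, ax=ax)
        figures.append(fig)
    return figures


def figures2pdf(figures, fname):
    """Hypothetical helper: write each Figure as one PDF page, return page count."""
    with PdfPages(fname) as pp:
        for fig in figures:
            pp.savefig(fig)   # one page per figure
            plt.close(fig)    # release the figure once it has been written
        return pp.get_pagecount()


# Hypothetical stand-in data
df = pd.DataFrame({"sex": ["Male", "Female", "Male"],
                   "species": ["Adelie", "Gentoo", "Adelie"]})
pages = figures2pdf(generate_figures(df, ["sex", "species"]), "multipage_figures.pdf")
print(pages, os.path.getsize("multipage_figures.pdf") > 0)
```

Closing each figure inside the saving function is a design choice that keeps memory bounded when many plots are generated; drop the `plt.close` call if the figures are still needed afterwards.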

from matplotlib.backends.backend_pdf import PdfPages
import seaborn as sns
import matplotlib.pyplot as plt

def generate_plots():

    penguins = sns.load_dataset("penguins")

    def my_countplot(y, data):
        # Create a fresh figure and Axes so each plot gets its own page
        fig, ax = plt.subplots()
        sns.countplot(y=y, data=data, ax=ax)
        return ax

    countplot_sex = my_countplot(y='sex', data=penguins)
    countplot_species = my_countplot(y='species', data=penguins)
    countplot_island = my_countplot(y='island', data=penguins)

    return(countplot_sex, countplot_species, countplot_island)


def plots2pdf(plots, fname):

    with PdfPages(fname) as pp:

        for plot in plots:
            # Each Axes knows the Figure it belongs to; save that Figure as one page
            pp.savefig(plot.figure)


def main():
    plots2pdf(generate_plots(), 'multipage.pdf')


if __name__ == '__main__':
    main()

[Screenshot: the multipage PDF produced]

tmdavison
Thank you! Since I'm also using other `sns` functions beyond the MWE, I changed your function into `def myplot(func): _, axsis = plt.subplots(); func(); return(axsis)`, which I can call via `myplot(lambda: sns.heatmap(. . .))`. – Someone2 Oct 09 '20 at 11:52
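A possible variation on the generic helper from the comment, sketched here as an assumption rather than either author's code, passes the newly created Axes into the callable so that each seaborn function draws onto it explicitly instead of relying on the current Axes. The data below is hypothetical, and this only applies to Axes-level seaborn functions that accept an `ax` argument.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs headless
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns


def myplot(func):
    """Create a new figure/Axes and let `func` draw onto that Axes explicitly."""
    _, ax = plt.subplots()
    func(ax)  # the plotting callable receives the Axes to draw onto
    return ax


# Hypothetical data: an identity matrix and a short list of categories
heatmap_ax = myplot(lambda ax: sns.heatmap(np.eye(3), ax=ax))
count_ax = myplot(lambda ax: sns.countplot(y=["a", "b", "a"], ax=ax))
print(heatmap_ax is count_ax)  # False
```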