
Similar to here, I am trying to create a loop that generates a figure with subplots from predefined functions. Those functions create different kinds of figures (such as line plots or tables) and already use plt.subplots. In the end, I want the loop to create one figure with multiple subplots for every country in my dataset. Each country-specific figure should then be saved on its own page of a PDF file.

import pandas as pd
from pandas import DataFrame
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.backends.backend_pdf import PdfPages
import matplotlib.gridspec as gridspec


dataset = pd.DataFrame({'country':['USA','USA','USA','UK','UK','UK'],
                        'year': [2006,2007,2008,2006,2007,2008],
                        'gdp':   [10,13,7,8,2,10],
                        'empowerment':   [0.2,0.13,0.7,0.8,0.2,0.10],
                        'solidarity':   [0.4,0.63,0.3,0.66,0.85,0.9],
                        'envir':   [55,34,79,65,59,88]})

The functions are constructed as follows:

def prepare_line(countries):
    plt.close()
    select_country = dataset[dataset.country == countries]
    select_country = select_country.round(4)
    # create figure and axis objects with subplots()
    fig, ax = plt.subplots(figsize=(12, 5))

    # Line plots
    ind1 = ax.plot(select_country.year, select_country.empowerment, color="blue",  
                   label="Empowerment Index")
    ind2 = ax.plot(select_country.year, select_country.solidarity, color="red", 
                   label="Solidarity Index")

    # set x-axis label
    ax.set_xlabel("year", fontsize=14)
    # set y-axis label
    ax.set_ylabel("Solidarity & Agency Scores", fontsize=14)

    ax2 = ax.twinx()
    axes = plt.gca()
    axes.yaxis.grid()
    # make a plot with different y-axis using second axis object
    ind3 = ax2.plot(select_country.year, select_country["gdp"], color="green", 
                    label="GDP per Capita (const. US 2010)")
    ax2.set_ylabel("GDP per Capita", fontsize=14)
    plt.title(countries, fontsize=18)
    plt.xticks(np.arange(select_country.year.min(), select_country.year.max() + 1, 1.0))
    # add figures to get labels
    ind = ind1 + ind2 + ind3
    labs = [l.get_label() for l in ind]
    # define location of legend
    ax.legend(ind, labs, loc=2)

    return fig

and

def prepare_table(countries):
    select_country = dataset[dataset.country == countries]
    data_table = DataFrame(select_country, columns=['year', 'empowerment', 'solidarity', 'gdp', 'envir'])

    decimals = pd.Series([2, 2, 0, 0], index=['empowerment', 'solidarity', 'gdp', 'envir'])
    data_round = data_table.round(decimals)
    data_round['gdp'] = data_round['gdp'].astype('Int64')
    data_round['envir'] = data_round['envir'].astype('Int64')
    
    data_round = pd.DataFrame(data_round)

    data_round = data_round.fillna(0)

    plt.figure()

    # table
    
    fig_table = plt.table(cellText=data_round.values, 
                          colLabels=['Year', 'Emp. Ind.',
                         'Sol. Ind.', 'GDP p.C.', 'Env. Ind.'], 
                          loc='center')
    fig_table.auto_set_font_size(False)
    fig_table.set_fontsize(10)
    fig_table.scale(1.8, 1.5)
    plt.axis('off')
    
    
    return fig_table

To generate a PDF with one page per country, where each page contains the figures generated by the functions above as subplots, I use the following:

import simplejson as json
import webbrowser

import pandas as pd
import matplotlib.pyplot as plt
from matplotlib.backends.backend_pdf import PdfPages
import matplotlib.gridspec as gridspec

country = dataset.country.unique()


if __name__ == "__main__":
  
  with PdfPages('Summary.pdf') as pdf:
    
    for i in country:
      gs = gridspec.GridSpec(2, 2)
      ax1 = prepare_line(i)
      ax2 = prepare_table(i) 
    pdf.savefig(gs)

Unfortunately, neither the line plot nor the table is saved in the PDF file; they are only generated one after another for each country.

I tried several other configurations, including versions where an 'ax' argument is passed to the individual functions. Any help is greatly appreciated, and sorry for the messy functions.

############################################################################

Edit: The solution by @gepcel works fine for the above. However, I run into a problem when trying to embed a radar graph as a subplot. The function I use for the radar graph is as follows:

def prepare_spider(ax, countries):
    #plt.close()
    select_country = spider_data.loc[(spider_data['country'] == countries)]
    base = select_country.replace({'year': {3000:"Baseline 2009"}})
    year = base.year
    data_legend = list(year)

    data_red = DataFrame(select_country, columns=['spidersolidarity',  'spidergdp', 'spiderempowerment', 'spiderenvir'])
    data = data_red.values.tolist()

    N = 4
    theta = radar_factory(N, frame='polygon')

    spoke_labels = ["Solidarity", "GDP", "Agency", "EPI"]
    title = countries

    plt.sca(ax)

    fig_spider, ax = plt.subplots(figsize=(6, 6), subplot_kw=dict(projection='radar'))
    fig_spider.subplots_adjust(top=0.85, bottom=0.05)

    ax.set_ylim(-3, 3)
    ax.set_rgrids([])
    ax.set_title(title, position=(0.5, 1.1), ha='center')

    for d in data:
        line = ax.plot(theta, d)
    ax.set_varlabels(spoke_labels)

    labels = (data_legend)
    legend = ax.legend(labels, loc=(0.8, .08),
                       labelspacing=0.1, fontsize='small')
    
    return fig_spider

Specifically, I do not see where to put plt.sca(ax) when the same function also contains fig_spider, ax = plt.subplots(figsize=(6, 6), subplot_kw=dict(projection='radar')).
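A likely reason no placement of plt.sca(ax) helps: plt.subplots() always creates a brand-new figure, so anything drawn after it ignores the ax that was passed in. The sketch below (helper names are hypothetical, for illustration only) contrasts that pattern with a helper that simply draws on the Axes it receives.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend: nothing is shown on screen
import matplotlib.pyplot as plt


def helper_that_makes_its_own_figure(ax):
    """Hypothetical helper mimicking the question: plt.sca(ax), then plt.subplots()."""
    plt.sca(ax)                        # makes `ax` the current Axes ...
    new_fig, new_ax = plt.subplots()   # ... but this call creates a brand-new Figure
    new_ax.plot([0, 1], [0, 1])        # so the line lands on new_ax, not on `ax`
    return new_fig


def helper_that_uses_ax(ax):
    """Hypothetical helper that only draws on the Axes it is given."""
    ax.plot([0, 1], [0, 1])
    return ax


page, (ax_a, ax_b) = plt.subplots(1, 2)   # the figure we actually want to save
stray = helper_that_makes_its_own_figure(ax_a)
helper_that_uses_ax(ax_b)

print(stray is page)     # False: the first helper drew on a different Figure
print(len(ax_a.lines))   # 0: the Axes meant for the page stayed empty
print(len(ax_b.lines))   # 1
```

In other words, a helper meant to fill one subplot should not create a figure at all; the caller creates the figure (with the radar projection where needed) and passes the Axes in.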

The function calls radar_factory(num_vars, frame='circle'), as defined here (the code is reproduced below), and uses normalized data like this:

spider_data = pd.DataFrame({'country':['USA','USA','USA','UK','UK','UK'],
                            'year': [2009,2019,3000,2009,2019,3000],
                            'spiderempowerment':   [0.5,0.6,0.7,0.8,0.2,0.10],
                            'spidersolidarity':   [0.4,0.63,0.3,0.66,0.85,0.9],
                            'spidergdp':   [0.10,0.13,0.7,0.8,0.2,0.10],
                            'spiderenvir':   [0.55,0.34,0.79,0.65,0.59,0.88]})

import numpy as np

import matplotlib.pyplot as plt
from matplotlib.patches import Circle, RegularPolygon
from matplotlib.path import Path
from matplotlib.projections.polar import PolarAxes
from matplotlib.projections import register_projection
from matplotlib.spines import Spine
from matplotlib.transforms import Affine2D


def radar_factory(num_vars, frame='circle'):
    """Create a radar chart with `num_vars` axes.

    This function creates a RadarAxes projection and registers it.

    Parameters
    ----------
    num_vars : int
        Number of variables for radar chart.
    frame : {'circle' | 'polygon'}
        Shape of frame surrounding axes.

    """
    # calculate evenly-spaced axis angles
    theta = np.linspace(0, 2*np.pi, num_vars, endpoint=False)

    class RadarAxes(PolarAxes):

        name = 'radar'

        def __init__(self, *args, **kwargs):
            super().__init__(*args, **kwargs)
            # rotate plot such that the first axis is at the top
            self.set_theta_zero_location('N')

        def fill(self, *args, closed=True, **kwargs):
            """Override fill so that line is closed by default"""
            return super().fill(closed=closed, *args, **kwargs)

        def plot(self, *args, **kwargs):
            """Override plot so that line is closed by default"""
            lines = super().plot(*args, **kwargs)
            for line in lines:
                self._close_line(line)

        def _close_line(self, line):
            x, y = line.get_data()
            # FIXME: markers at x[0], y[0] get doubled-up
            if x[0] != x[-1]:
                x = np.concatenate((x, [x[0]]))
                y = np.concatenate((y, [y[0]]))
                line.set_data(x, y)

        def set_varlabels(self, labels):
            self.set_thetagrids(np.degrees(theta), labels)

        def _gen_axes_patch(self):
            # The Axes patch must be centered at (0.5, 0.5) and of radius 0.5
            # in axes coordinates.
            if frame == 'circle':
                return Circle((0.5, 0.5), 0.5)
            elif frame == 'polygon':
                return RegularPolygon((0.5, 0.5), num_vars,
                                      radius=.5, edgecolor="k")
            else:
                raise ValueError("unknown value for 'frame': %s" % frame)

        def draw(self, renderer):
            """ Draw. If frame is polygon, make gridlines polygon-shaped """
            if frame == 'polygon':
                gridlines = self.yaxis.get_gridlines()
                for gl in gridlines:
                    gl.get_path()._interpolation_steps = num_vars
            super().draw(renderer)


        def _gen_axes_spines(self):
            if frame == 'circle':
                return super()._gen_axes_spines()
            elif frame == 'polygon':
                # spine_type must be 'left'/'right'/'top'/'bottom'/'circle'.
                spine = Spine(axes=self,
                              spine_type='circle',
                              path=Path.unit_regular_polygon(num_vars))
                # unit_regular_polygon gives a polygon of radius 1 centered at
                # (0, 0) but we want a polygon of radius 0.5 centered at (0.5,
                # 0.5) in axes coordinates.
                spine.set_transform(Affine2D().scale(.5).translate(.5, .5)
                                    + self.transAxes)


                return {'polar': spine}
            else:
                raise ValueError("unknown value for 'frame': %s" % frame)

    register_projection(RadarAxes)
    return theta
lomi

1 Answer


There are numerous mistakes in the code. PdfPages saves one figure per page, so you should generate one figure per country, with two axes (line plot and table). The full code is as follows:

import pandas as pd
from pandas import DataFrame
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.backends.backend_pdf import PdfPages
import matplotlib.gridspec as gridspec


dataset = pd.DataFrame({'country':['USA','USA','USA','UK','UK','UK'],
                        'year': [2006,2007,2008,2006,2007,2008],
                        'gdp':   [10,13,7,8,2,10],
                        'empowerment':   [0.2,0.13,0.7,0.8,0.2,0.10],
                        'solidarity':   [0.4,0.63,0.3,0.66,0.85,0.9],
                        'envir':   [55,34,79,65,59,88]})

def prepare_line(ax, countries):
    # plt.close()
    select_country = dataset[dataset.country == countries]
    select_country = select_country.round(4)
    # create figure and axis objects with subplots()
    # fig, ax = plt.subplots(figsize=(12, 5))
    plt.sca(ax)
    # Line plots
    ind1 = ax.plot(select_country.year, select_country.empowerment, color="blue",  
                   label="Empowerment Index")
    ind2 = ax.plot(select_country.year, select_country.solidarity, color="red", 
                   label="Solidarity Index")

    # set x-axis label
    ax.set_xlabel("year", fontsize=14)
    # set y-axis label
    ax.set_ylabel("Solidarity & Agency Scores", fontsize=14)

    ax2 = ax.twinx()
    axes = plt.gca()
    axes.yaxis.grid()
    # make a plot with different y-axis using second axis object
    ind3 = ax2.plot(select_country.year, select_country["gdp"], color="green", 
                    label="GDP per Capita (const. US 2010)")
    ax2.set_ylabel("GDP per Capita", fontsize=14)
    plt.title(countries, fontsize=18)
    # plt.xticks(np.arange(min(visual.index), max(visual.index)+1, 1.0))
    # add figures to get labels
    ind = ind1 + ind2 + ind3
    labs = [l.get_label() for l in ind]
    # define location of legend
    ax.legend(ind, labs, loc=2)

    return ax


def prepare_table(ax, countries):
    select_country = dataset[dataset.country == countries]
    data_table = DataFrame(select_country, columns=['year', 'empowerment', 'solidarity', 'gdp', 'envir'])

    decimals = pd.Series([2, 2, 0, 0], index=['empowerment', 'solidarity', 'gdp', 'envir'])
    data_round = data_table.round(decimals)
    data_round['gdp'] = data_round['gdp'].astype('Int64')
    data_round['envir'] = data_round['envir'].astype('Int64')
    
    data_round = pd.DataFrame(data_round)

    data_round = data_round.fillna(0)
    plt.sca(ax)
    # plt.figure()

    # table
    
    fig_table = plt.table(cellText=data_round.values, 
                          colLabels=['Year', 'Emp. Ind.',
                         'Sol. Ind.', 'GDP p.C.', 'Env. Ind.'], 
                          loc='center')
    fig_table.auto_set_font_size(False)
    fig_table.set_fontsize(10)
    fig_table.scale(1.8, 1.5)
    plt.axis('off')
    
    
    return fig_table


country = dataset.country.unique()

with PdfPages('Summary.pdf') as pdf:

    for i in country:
        fig, axs = plt.subplots(2, 1, figsize=(12, 9))
        ax1 = prepare_line(axs[0], i)
        ax2 = prepare_table(axs[1], i) 
        fig.tight_layout()
        pdf.savefig(fig)
        plt.close()
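Stripped of the plotting details, the pattern above is: create one Figure per country, draw into its Axes, hand that Figure to pdf.savefig(), then close it. A minimal sketch of just that loop, assuming a cut-down version of the dataset (only the gdp column) and a temporary output path:

```python
import os
import tempfile

import matplotlib
matplotlib.use("Agg")  # non-interactive backend: render to files only
import matplotlib.pyplot as plt
import pandas as pd
from matplotlib.backends.backend_pdf import PdfPages

# Cut-down copy of `dataset` from the question (same values for these columns).
dataset = pd.DataFrame({"country": ["USA", "USA", "USA", "UK", "UK", "UK"],
                        "year": [2006, 2007, 2008, 2006, 2007, 2008],
                        "gdp": [10, 13, 7, 8, 2, 10]})


def write_country_pages(df, path):
    """Write one PDF page per country and return the number of pages."""
    with PdfPages(path) as pdf:
        for name, group in df.groupby("country", sort=False):
            # One Figure per page; each page gets its own pair of Axes.
            fig, (ax_line, ax_table) = plt.subplots(2, 1, figsize=(8, 6))
            ax_line.plot(group["year"], group["gdp"], color="green")
            ax_line.set_title(name)
            ax_table.table(cellText=group[["year", "gdp"]].astype(str).values.tolist(),
                           colLabels=["Year", "GDP"], loc="center")
            ax_table.axis("off")
            pdf.savefig(fig)   # this Figure becomes the next page
            plt.close(fig)     # release it before the next country
        return pdf.get_pagecount()


out = os.path.join(tempfile.mkdtemp(), "Summary.pdf")
pages = write_country_pages(dataset, out)
print(pages)  # 2
```

Passing the Figure explicitly to pdf.savefig() and plt.close() avoids depending on which figure pyplot currently considers "current".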

Edit: added the radar chart. Note that the example code I gave is just the minimal modification of your code that works; it could be optimized further.

import pandas as pd
from pandas import DataFrame
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.backends.backend_pdf import PdfPages
import matplotlib.gridspec as gridspec


dataset = pd.DataFrame({'country':['USA','USA','USA','UK','UK','UK'],
                        'year': [2006,2007,2008,2006,2007,2008],
                        'gdp':   [10,13,7,8,2,10],
                        'empowerment':   [0.2,0.13,0.7,0.8,0.2,0.10],
                        'solidarity':   [0.4,0.63,0.3,0.66,0.85,0.9],
                        'envir':   [55,34,79,65,59,88]})

def prepare_line(ax, countries):
    # plt.close()
    select_country = dataset[dataset.country == countries]
    select_country = select_country.round(4)
    # create figure and axis objects with subplots()
    # fig, ax = plt.subplots(figsize=(12, 5))
    plt.sca(ax)
    # Line plots
    ind1 = ax.plot(select_country.year, select_country.empowerment, color="blue",  
                   label="Empowerment Index")
    ind2 = ax.plot(select_country.year, select_country.solidarity, color="red", 
                   label="Solidarity Index")

    # set x-axis label
    ax.set_xlabel("year", fontsize=14)
    # set y-axis label
    ax.set_ylabel("Solidarity & Agency Scores", fontsize=14)

    ax2 = ax.twinx()
    axes = plt.gca()
    axes.yaxis.grid()
    # make a plot with different y-axis using second axis object
    ind3 = ax2.plot(select_country.year, select_country["gdp"], color="green", 
                    label="GDP per Capita (const. US 2010)")
    ax2.set_ylabel("GDP per Capita", fontsize=14)
    plt.title(countries, fontsize=18)
    # plt.xticks(np.arange(min(visual.index), max(visual.index)+1, 1.0))
    # add figures to get labels
    ind = ind1 + ind2 + ind3
    labs = [l.get_label() for l in ind]
    # define location of legend
    ax.legend(ind, labs, loc=2)

    return ax


def prepare_table(ax, countries):
    select_country = dataset[dataset.country == countries]
    data_table = DataFrame(select_country, columns=['year', 'empowerment', 'solidarity', 'gdp', 'envir'])

    decimals = pd.Series([2, 2, 0, 0], index=['empowerment', 'solidarity', 'gdp', 'envir'])
    data_round = data_table.round(decimals)
    data_round['gdp'] = data_round['gdp'].astype('Int64')
    data_round['envir'] = data_round['envir'].astype('Int64')
    
    data_round = pd.DataFrame(data_round)

    data_round = data_round.fillna(0)
    plt.sca(ax)
    # plt.figure()

    # table
    
    fig_table = plt.table(cellText=data_round.values, 
                          colLabels=['Year', 'Emp. Ind.',
                         'Sol. Ind.', 'GDP p.C.', 'Env. Ind.'], 
                          loc='center')
    fig_table.auto_set_font_size(False)
    fig_table.set_fontsize(10)
    fig_table.scale(1.8, 1.5)
    plt.axis('off')
    
    
    return fig_table


spider_data = pd.DataFrame({'country':['USA','USA','USA','UK','UK','UK'],
                            'year': [2009,2019,3000,2009,2019,3000],
                            'spiderempowerment':   [0.5,0.6,0.7,0.8,0.2,0.10],
                            'spidersolidarity':   [0.4,0.63,0.3,0.66,0.85,0.9],
                            'spidergdp':   [0.10,0.13,0.7,0.8,0.2,0.10],
                            'spiderenvir':   [0.55,0.34,0.79,0.65,0.59,0.88]})

import numpy as np

import matplotlib.pyplot as plt
from matplotlib.patches import Circle, RegularPolygon
from matplotlib.path import Path
from matplotlib.projections.polar import PolarAxes
from matplotlib.projections import register_projection
from matplotlib.spines import Spine
from matplotlib.transforms import Affine2D


def radar_factory(num_vars, frame='circle'):
    """Create a radar chart with `num_vars` axes.

    This function creates a RadarAxes projection and registers it.

    Parameters
    ----------
    num_vars : int
        Number of variables for radar chart.
    frame : {'circle' | 'polygon'}
        Shape of frame surrounding axes.

    """
    # calculate evenly-spaced axis angles
    theta = np.linspace(0, 2*np.pi, num_vars, endpoint=False)

    class RadarAxes(PolarAxes):

        name = 'radar'

        def __init__(self, *args, **kwargs):
            super().__init__(*args, **kwargs)
            # rotate plot such that the first axis is at the top
            self.set_theta_zero_location('N')

        def fill(self, *args, closed=True, **kwargs):
            """Override fill so that line is closed by default"""
            return super().fill(closed=closed, *args, **kwargs)

        def plot(self, *args, **kwargs):
            """Override plot so that line is closed by default"""
            lines = super().plot(*args, **kwargs)
            for line in lines:
                self._close_line(line)

        def _close_line(self, line):
            x, y = line.get_data()
            # FIXME: markers at x[0], y[0] get doubled-up
            if x[0] != x[-1]:
                x = np.concatenate((x, [x[0]]))
                y = np.concatenate((y, [y[0]]))
                line.set_data(x, y)

        def set_varlabels(self, labels):
            self.set_thetagrids(np.degrees(theta), labels)

        def _gen_axes_patch(self):
            # The Axes patch must be centered at (0.5, 0.5) and of radius 0.5
            # in axes coordinates.
            if frame == 'circle':
                return Circle((0.5, 0.5), 0.5)
            elif frame == 'polygon':
                return RegularPolygon((0.5, 0.5), num_vars,
                                      radius=.5, edgecolor="k")
            else:
                raise ValueError("unknown value for 'frame': %s" % frame)

        def draw(self, renderer):
            """ Draw. If frame is polygon, make gridlines polygon-shaped """
            if frame == 'polygon':
                gridlines = self.yaxis.get_gridlines()
                for gl in gridlines:
                    gl.get_path()._interpolation_steps = num_vars
            super().draw(renderer)


        def _gen_axes_spines(self):
            if frame == 'circle':
                return super()._gen_axes_spines()
            elif frame == 'polygon':
                # spine_type must be 'left'/'right'/'top'/'bottom'/'circle'.
                spine = Spine(axes=self,
                              spine_type='circle',
                              path=Path.unit_regular_polygon(num_vars))
                # unit_regular_polygon gives a polygon of radius 1 centered at
                # (0, 0) but we want a polygon of radius 0.5 centered at (0.5,
                # 0.5) in axes coordinates.
                spine.set_transform(Affine2D().scale(.5).translate(.5, .5)
                                    + self.transAxes)


                return {'polar': spine}
            else:
                raise ValueError("unknown value for 'frame': %s" % frame)

    register_projection(RadarAxes)
    return theta


def prepare_spider(ax, countries):
    #plt.close()
    select_country = spider_data.loc[(spider_data['country'] == countries)]
    base = select_country.replace({'year': {3000:"Baseline 2009"}})
    year = base.year
    data_legend = list(year)

    data_red = DataFrame(select_country, columns=['spidersolidarity',  'spidergdp', 'spiderempowerment', 'spiderenvir'])
    data = data_red.values.tolist()

    N = 4
    theta = radar_factory(N, frame='polygon')

    spoke_labels = ["Solidarity", "GDP", "Agency", "EPI"]
    title = countries

    # plt.sca(ax)
    ax.set_ylim(-3, 3)
    ax.set_rgrids([])
    ax.set_title(title, position=(0.5, 1.1), ha='center')

    for d in data:
        line = ax.plot(theta, d)
    ax.set_varlabels(spoke_labels)

    labels = (data_legend)
    legend = ax.legend(labels, loc=(0.8, .08),
                       labelspacing=0.1, fontsize='small')
    
    return ax

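# Register the 'radar' projection before any radar axes are created;
# otherwise plt.subplot(313, projection='radar') fails with "Unknown projection".
radar_factory(4, frame='polygon')

country = dataset.country.unique()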

with PdfPages('Summary.pdf') as pdf:

    for i in country:
        fig = plt.figure(figsize=(12, 12))
        ax1 = plt.subplot(311)
        ax1 = prepare_line(ax1, i)
        ax2 = plt.subplot(312)
        ax2 = prepare_table(ax2, i) 
        ax3 = plt.subplot(313, projection='radar')
        prepare_spider(ax3, i)
        fig.tight_layout()
        pdf.savefig(fig)
        plt.close()
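A point worth keeping in mind with the radar subplot: the 'radar' projection only exists once radar_factory() has called register_projection(), so any axes created with projection='radar' must come after that call, or matplotlib raises a ValueError. The sketch below shows the effect with a hypothetical minimal projection class (DemoRadarAxes, not part of the original code) instead of the full RadarAxes.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend
import matplotlib.pyplot as plt
from matplotlib.projections import register_projection
from matplotlib.projections.polar import PolarAxes


class DemoRadarAxes(PolarAxes):
    """Hypothetical stand-in for RadarAxes: polar axes with 0 degrees at the top."""
    name = "demo_radar"

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.set_theta_zero_location("N")


fig = plt.figure()
try:
    fig.add_subplot(111, projection="demo_radar")  # not registered yet
    failed_before = False
except ValueError:                                 # "Unknown projection 'demo_radar'"
    failed_before = True

register_projection(DemoRadarAxes)  # the step radar_factory() performs internally
ax = fig.add_subplot(111, projection="demo_radar")
print(failed_before, isinstance(ax, DemoRadarAxes))  # True True
plt.close(fig)
```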
gepcel
  • Thanks so much! This works just fine for the above. However, I run into a problem while trying to embed another function that draws a radar plot into this framework. – lomi Jan 06 '21 at 13:11
  • Can you provide any example? – gepcel Jan 06 '21 at 13:20
  • The line plot, table, radar, do you need them on the same page or one plot per page? – gepcel Jan 07 '21 at 00:59
  • As for `plt.sca()`, it's not a good idea; it's just the easiest modification of your original code. For the difference between the `pyplot` and the object-oriented API, see [here](https://matplotlib.org/tutorials/introductory/pyplot.html#sphx-glr-tutorials-introductory-pyplot-py) – gepcel Jan 07 '21 at 01:03
  • Yes, I will need plot, table and radar on the same page. Thanks for the link. – lomi Jan 07 '21 at 09:30
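Following up on the comment about plt.sca(): a sketch of the line-plot and table helpers written purely against the object-oriented API, so each one draws only on the Axes it is given and never touches pyplot's "current axes". This is an illustrative rewrite, not the answer's code; the per-year ticks via ax.set_xticks and the string conversion of table cells are choices made here.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend for scripts and tests
import matplotlib.pyplot as plt
import pandas as pd

dataset = pd.DataFrame({'country': ['USA', 'USA', 'USA', 'UK', 'UK', 'UK'],
                        'year': [2006, 2007, 2008, 2006, 2007, 2008],
                        'gdp': [10, 13, 7, 8, 2, 10],
                        'empowerment': [0.2, 0.13, 0.7, 0.8, 0.2, 0.10],
                        'solidarity': [0.4, 0.63, 0.3, 0.66, 0.85, 0.9],
                        'envir': [55, 34, 79, 65, 59, 88]})


def prepare_line(ax, country):
    """Draw both indices on `ax` and GDP on a twin y-axis; return `ax`."""
    sel = dataset[dataset['country'] == country]
    l1 = ax.plot(sel['year'], sel['empowerment'], color='blue',
                 label='Empowerment Index')
    l2 = ax.plot(sel['year'], sel['solidarity'], color='red',
                 label='Solidarity Index')
    ax.set_xlabel('year', fontsize=14)
    ax.set_ylabel('Solidarity & Agency Scores', fontsize=14)
    ax.set_xticks(sel['year'])          # one tick per year
    ax.set_title(country, fontsize=18)

    ax2 = ax.twinx()                    # second y-axis sharing the x-axis
    l3 = ax2.plot(sel['year'], sel['gdp'], color='green',
                  label='GDP per Capita (const. US 2010)')
    ax2.set_ylabel('GDP per Capita', fontsize=14)
    ax2.yaxis.grid(True)

    lines = l1 + l2 + l3                # one legend covering both y-axes
    ax.legend(lines, [l.get_label() for l in lines], loc='upper left')
    return ax


def prepare_table(ax, country):
    """Render the country's rows as a table on `ax`; return the Table."""
    cols = ['year', 'empowerment', 'solidarity', 'gdp', 'envir']
    sel = dataset.loc[dataset['country'] == country, cols]
    sel = sel.round({'empowerment': 2, 'solidarity': 2})
    table = ax.table(cellText=sel.astype(str).values.tolist(),
                     colLabels=['Year', 'Emp. Ind.', 'Sol. Ind.',
                                'GDP p.C.', 'Env. Ind.'],
                     loc='center')
    table.auto_set_font_size(False)
    table.set_fontsize(10)
    table.scale(1.8, 1.5)
    ax.axis('off')
    return table


fig, axs = plt.subplots(2, 1, figsize=(12, 9))
prepare_line(axs[0], 'USA')
prepare_table(axs[1], 'USA')
print(len(fig.axes))  # 3: the two subplots plus the twin y-axis
```

Because neither helper creates a figure, the same functions can be combined with a radar Axes on one page, or called inside the PdfPages loop unchanged.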