Similar to here, I try to create a loop that generates a figure with subplots from predefined functions. Those functions create different kind of figures (like line-plots or tables) and already use plt.subplots
. In the end, I want to create a figure with multiple subplots for every country in my dataset through a loop. The country specific figures shall then be saved on individual pages of a pdf file.
import pandas as pd
from pandas import DataFrame
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.backends.backend_pdf import PdfPages
import matplotlib.gridspec as gridspec
dataset = pd.DataFrame({'country':['USA','USA','USA','UK','UK','UK'],
'year': [2006,2007,2008,2006,2007,2008],
'gdp': [10,13,7,8,2,10],
'empowerment': [0.2,0.13,0.7,0.8,0.2,0.10],
'solidarity': [0.4,0.63,0.3,0.66,0.85,0.9],
'envir': [55,34,79,65,59,88]})
The functions are constructed as follows:
def prepare_line(countries):
plt.close()
select_country = dataset[dataset.country == countries]
select_country = select_country.round(4)
# create figure and axis objects with subplots()
fig, ax = plt.subplots(figsize=(12, 5))
# Line plots
ind1 = ax.plot(select_country.year, select_country.empowerment, color="blue",
label="Empowerment Index")
ind2 = ax.plot(select_country.year, select_country.solidarity, color="red",
label="Solidarity Index")
# set x-axis label
ax.set_xlabel("year", fontsize=14)
# set y-axis label
ax.set_ylabel("Solidarity & Agency Scores", fontsize=14)
ax2 = ax.twinx()
axes = plt.gca()
axes.yaxis.grid()
# make a plot with different y-axis using second axis object
ind3 = ax2.plot(select_country.year, select_country["gdp"], color="green",
label="GDP per Capita (const. US 2010)")
ax2.set_ylabel("GDP per Capita", fontsize=14)
plt.title(countries, fontsize=18)
plt.xticks(np.arange(min(visual.index), max(visual.index)+1, 1.0))
# add figures to get labels
ind = ind1 + ind2 + ind3
labs = [l.get_label() for l in ind]
# define location of legend
ax.legend(ind, labs, loc=2)
return fig
and
def prepare_table(countries):
select_country = dataset[dataset.country == countries]
data_table = DataFrame(select_country, columns=['year', 'empowerment', 'solidarity', 'gdp', 'envir'])
decimals = pd.Series([2, 2, 0, 0], index=['empowerment', 'solidarity', 'gdp', 'envir'])
data_round = data_table.round(decimals)
data_round['gdp'] = data_round['gdp'].astype('Int64')
data_round['envir'] = data_round['envir'].astype('Int64')
data_round = pd.DataFrame(data_round)
data_round = data_round.fillna(0)
plt.figure()
# table
fig_table = plt.table(cellText=data_round.values,
colLabels=['Year', 'Emp. Ind.',
'Sol. Ind.', 'GDP p.C.', 'Env. Ind.'],
loc='center')
fig_table.auto_set_font_size(False)
fig_table.set_fontsize(10)
fig_table.scale(1.8, 1.5)
plt.axis('off')
return fig_table
To generate a pdf with individual pages for separate countries that contain the figures generated from above functions as subplots, I use the following:
!pip import simplejson as json
import webbrowser
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib.backends.backend_pdf import PdfPages
import matplotlib.gridspec as gridspec
country = dataset.country.unique()
if __name__ == "__main__":
with PdfPages('Summary.pdf') as pdf:
for i in country:
gs = gridspec.GridSpec(2, 2)
ax1 = prepare_line(i)
ax2 = prepare_table(i)
pdf.savefig(gs)
Unfortunately, neither the lineplot nor the table are saved in the pdf file but only generated iteratively for all countries.
I tried several other configurations including constructions where an 'ax' argument is included in the individual functions. Any help is greatly appreciated. Sorry for the messy functions.
############################################################################
Edit: The solution by @gepcel works fine for the above. However, I run into a problem while trying to embed a radar graph as subplot. The function I use for the radar graph is as follows:
def prepare_spider(ax, countries):
#plt.close()
select_country = spider_data.loc[(spider_data['country'] == countries)]
base = select_country.replace({'year': {3000:"Baseline 2009"}})
year = base.year
data_legend = list(year)
data_red = DataFrame(select_country, columns=['spidersolidarity', 'spidergdp', 'spiderempowerment', 'spiderenvir'])
data = data_red.values.tolist()
N = 4
theta = radar_factory(N, frame='polygon')
spoke_labels = ["Solidarity", "GDP", "Agency", "EPI"]
title = countries
plt.sca(ax)
fig_spider, ax = plt.subplots(figsize=(6, 6), subplot_kw=dict(projection='radar'))
fig_spider.subplots_adjust(top=0.85, bottom=0.05)
ax.set_ylim(-3, 3)
ax.set_rgrids([])
ax.set_title(title, position=(0.5, 1.1), ha='center')
for d in data:
line = ax.plot(theta, d)
ax.set_varlabels(spoke_labels)
labels = (data_legend)
legend = ax.legend(labels, loc=(0.8, .08),
labelspacing=0.1, fontsize='small')
return fig
Specifically, I do not see where to embed plt.sca(ax)
with fig_spider, ax = plt.subplots(figsize=(6, 6), subplot_kw=dict(projection='radar'))
also being contained.
The function calls radar_factory(num_vars, frame='circle')
as defined here and uses normalized data like this:
spider_data = pd.DataFrame({'country':['USA','USA','USA','UK','UK','UK'],
'year': [2009,2019,3000,2009,2019,3000],
'spiderempowerment': [0.5,0.6,0.7,0.8,0.2,0.10],
'spidersolidarity': [0.4,0.63,0.3,0.66,0.85,0.9],
'spidergdp': [0.10,0.13,0.7,0.8,0.2,0.10],
'spiderenvir': [0.55,0.34,0.79,0.65,0.59,0.88]})
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.patches import Circle, RegularPolygon
from matplotlib.path import Path
from matplotlib.projections.polar import PolarAxes
from matplotlib.projections import register_projection
from matplotlib.spines import Spine
from matplotlib.transforms import Affine2D
def radar_factory(num_vars, frame='circle'):
"""Create a radar chart with `num_vars` axes.
This function creates a RadarAxes projection and registers it.
Parameters
----------
num_vars : int
Number of variables for radar chart.
frame : {'circle' | 'polygon'}
Shape of frame surrounding axes.
"""
# calculate evenly-spaced axis angles
theta = np.linspace(0, 2*np.pi, num_vars, endpoint=False)
class RadarAxes(PolarAxes):
name = 'radar'
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
# rotate plot such that the first axis is at the top
self.set_theta_zero_location('N')
def fill(self, *args, closed=True, **kwargs):
"""Override fill so that line is closed by default"""
return super().fill(closed=closed, *args, **kwargs)
def plot(self, *args, **kwargs):
"""Override plot so that line is closed by default"""
lines = super().plot(*args, **kwargs)
for line in lines:
self._close_line(line)
def _close_line(self, line):
x, y = line.get_data()
# FIXME: markers at x[0], y[0] get doubled-up
if x[0] != x[-1]:
x = np.concatenate((x, [x[0]]))
y = np.concatenate((y, [y[0]]))
line.set_data(x, y)
def set_varlabels(self, labels):
self.set_thetagrids(np.degrees(theta), labels)
def _gen_axes_patch(self):
# The Axes patch must be centered at (0.5, 0.5) and of radius 0.5
# in axes coordinates.
if frame == 'circle':
return Circle((0.5, 0.5), 0.5)
elif frame == 'polygon':
return RegularPolygon((0.5, 0.5), num_vars,
radius=.5, edgecolor="k")
else:
raise ValueError("unknown value for 'frame': %s" % frame)
def draw(self, renderer):
""" Draw. If frame is polygon, make gridlines polygon-shaped """
if frame == 'polygon':
gridlines = self.yaxis.get_gridlines()
for gl in gridlines:
gl.get_path()._interpolation_steps = num_vars
super().draw(renderer)
def _gen_axes_spines(self):
if frame == 'circle':
return super()._gen_axes_spines()
elif frame == 'polygon':
# spine_type must be 'left'/'right'/'top'/'bottom'/'circle'.
spine = Spine(axes=self,
spine_type='circle',
path=Path.unit_regular_polygon(num_vars))
# unit_regular_polygon gives a polygon of radius 1 centered at
# (0, 0) but we want a polygon of radius 0.5 centered at (0.5,
# 0.5) in axes coordinates.
spine.set_transform(Affine2D().scale(.5).translate(.5, .5)
+ self.transAxes)
return {'polar': spine}
else:
raise ValueError("unknown value for 'frame': %s" % frame)
register_projection(RadarAxes)
return theta