
I have a function that generates one or more plots depending on user input. The user passes in a data frame and a list of column names (categorySubset) to graph; the list can contain a single column name or several.

import matplotlib.pyplot as plt

def generateMultidistribution(df, primaryDistribution, categoricalColumnName, categorySubset):

    # Initialize subplots
    fig, ax = plt.subplots(nrows=len(categorySubset), ncols=1)
    fig.suptitle('Line Item Distributions')

    # Specify subplot configuration
    fig.set_figheight(len(categorySubset) * 4.95)
    fig.set_figwidth(9.2)
    fig.tight_layout(pad=5)
    graphIndex = 0

    for category in range(0, len(categorySubset)):
        ...  # graphing code removed for brevity

If I pass a list containing only one column name as categorySubset, the function fails, since subplots needs the number of rows (or columns) to be more than 1. Is there a good fix that makes this work regardless of the number of columns (i.e., len(categorySubset))? Why does plt.subplots() require more than 1 row or column?
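A minimal reproduction helps locate where the failure actually happens. Because the graphing code was removed from the question, this sketch assumes the loop indexes the returned axes as `ax[graphIndex]`; the column name `price` is a hypothetical placeholder. With a single row, `plt.subplots()` itself succeeds and returns one `Axes` object, and it is the indexing step that raises `TypeError`.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs headless
import matplotlib.pyplot as plt
from matplotlib.axes import Axes

categorySubset = ["price"]  # hypothetical one-element list of column names

# This call does not fail: nrows=1 is a valid argument
fig, ax = plt.subplots(nrows=len(categorySubset), ncols=1)
print("single Axes returned:", isinstance(ax, Axes))

failed = False
try:
    ax[0]  # assumed pattern from the removed loop: ax[graphIndex]
except TypeError as err:
    # A lone Axes object does not support indexing
    failed = True
    print("indexing failed:", err)
```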

`fig, axs = plt.subplots(...)` also works when there is only one subplot. Note that by default, `axs` will be a single `Axes` object when both `nrows` and `ncols` are 1, a one-dimensional array when exactly one of `nrows` or `ncols` is 1 (and the other is larger than 1), and a two-dimensional array when both are larger than 1. Use `fig, axs = plt.subplots(..., squeeze=False)` to make `axs` always two-dimensional. – JohanC Nov 26 '22 at 17:18
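Applied to the function from the question, the `squeeze=False` suggestion might look like the sketch below. Since the original graphing code was not shown, the loop only sets a title on each subplot as a stand-in, and the data arguments go unused. Two small liberties are taken: `enumerate` replaces the manual `graphIndex` counter, and `tight_layout` is called after drawing so it can account for the titles. Indexing `axs[graphIndex, 0]` then behaves the same for one column name or many.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs anywhere
import matplotlib.pyplot as plt

def generateMultidistribution(df, primaryDistribution, categoricalColumnName, categorySubset):
    # squeeze=False: axs is always a 2-D array of shape (nrows, ncols),
    # even when categorySubset holds a single column name
    fig, axs = plt.subplots(nrows=len(categorySubset), ncols=1, squeeze=False)
    fig.suptitle('Line Item Distributions')
    fig.set_figheight(len(categorySubset) * 4.95)
    fig.set_figwidth(9.2)

    for graphIndex, category in enumerate(categorySubset):
        ax = axs[graphIndex, 0]  # row graphIndex, the only column
        ax.set_title(str(category))  # stand-in for the removed graphing code

    # Called after drawing so the layout accounts for titles and labels
    fig.tight_layout(pad=5)
    return fig, axs

# df and the other data arguments are unused in this stand-in
fig_one, axs_one = generateMultidistribution(None, None, None, ["a"])
fig_many, axs_many = generateMultidistribution(None, None, None, ["a", "b", "c"])
print(axs_one.shape)   # (1, 1)
print(axs_many.shape)  # (3, 1)
```

If a flat sequence is more convenient than 2-D indexing, `axs[:, 0]` gives the single column of axes as a 1-D array.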
