
I have made a function that plots input variables against predicted variables.

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

dummy_data = pd.DataFrame(np.random.uniform(low=65.5,high=140.5,size=(50,4)), columns=list('ABCD'))
dummy_predicted = pd.DataFrame(np.random.uniform(low=15.5,high=17.5,size=(50,4)), columns=list('WXYZ'))

# Plot test input distributions
fig = plt.figure(figsize=(15,6))
n_rows = 1 
n_cols = 4
counter = 1
for i in dummy_data.keys():
    plt.subplot(n_rows, n_cols, counter)
    plt.scatter(dummy_data[i], dummy_predicted['Z'])

    plt.title(f'{i} vs Z')
    plt.xlabel(i)
    counter += 1

plt.tight_layout() 
plt.show()


How do I create a 4 x 4 grid of subplots covering all combinations of 'ABCD' and 'WXYZ'? dummy_data and dummy_predicted can have any number of columns, so the solution should not hard-code the grid size.

Trenton McKinney
Zizi96

2 Answers

  • Use itertools.product from the standard library to create all combinations of column names (combos).
  • Use the length of each set of columns to determine nrows and ncols for plt.subplots.
  • Flatten the array of axes so you can iterate over a 1D array instead of a 2D array.
  • zip combos and axes, then plot each pair in a single loop.
  • See this answer in How to plot in multiple subplots.
from itertools import product
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

# sample data
np.random.seed(2022)
dd = pd.DataFrame(np.random.uniform(low=65.5, high=140.5, size=(50, 4)), columns=list('ABCD'))
dp = pd.DataFrame(np.random.uniform(low=15.5, high=17.5, size=(50, 4)), columns=list('WXYZ'))

# create combinations of columns
combos = product(dd.columns, dp.columns)

# create subplots
fig, axes = plt.subplots(nrows=len(dd.columns), ncols=len(dp.columns), figsize=(15, 6))

# flatten axes into a 1d array
axes = axes.flat

# iterate and plot
for (x, y), ax in zip(combos, axes):
    ax.scatter(dd[x], dp[y])
    ax.set(title=f'{x} vs. {y}', xlabel=x, ylabel=y)
plt.tight_layout()
plt.show()
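
Since the question asks for something that works with any number of columns, the same steps can be wrapped in a function. The following is only a sketch: the name plot_grid and its cell_size parameter are hypothetical, not part of the answer above, and the Agg backend is selected purely so the example runs without a display. It passes squeeze=False so that axes stays a 2D array even when one DataFrame has a single column.

```python
from itertools import product

import matplotlib
matplotlib.use("Agg")  # assumption: headless backend; remove this line to see windows
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd


def plot_grid(inputs, predicted, cell_size=(3.5, 3.0)):
    """Hypothetical helper: scatter each column of `inputs` against each column of `predicted`."""
    n_rows, n_cols = len(inputs.columns), len(predicted.columns)
    fig, axes = plt.subplots(
        nrows=n_rows,
        ncols=n_cols,
        figsize=(cell_size[0] * n_cols, cell_size[1] * n_rows),
        squeeze=False,  # always return a 2D array of axes
    )
    # product() varies the second argument fastest, which matches the
    # row-major order in which axes.flat walks the grid
    for (x, y), ax in zip(product(inputs.columns, predicted.columns), axes.flat):
        ax.scatter(inputs[x], predicted[y])
        ax.set(title=f'{x} vs. {y}', xlabel=x, ylabel=y)
    fig.tight_layout()
    return fig, axes


# sample data with a non-square shape: 3 input columns, 2 predicted columns
rng = np.random.default_rng(2022)
inputs = pd.DataFrame(rng.uniform(65.5, 140.5, size=(50, 3)), columns=list('ABC'))
predicted = pd.DataFrame(rng.uniform(15.5, 17.5, size=(50, 2)), columns=list('YZ'))

fig, axes = plot_grid(inputs, predicted)
```

Each row of the grid corresponds to one input column and each column of the grid to one predicted column, so the figure grows with the data instead of relying on a fixed figsize.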


Trenton McKinney

Just use a nested for loop: the outer loop picks the row, the inner loop picks the column.

n_rows = len(dummy_data.columns)
n_cols = len(dummy_predicted.columns)

# squeeze=False keeps axes 2D even if either DataFrame has only one column
fig, axes = plt.subplots(n_rows, n_cols, figsize=(15,6), squeeze=False)

for row, data_col in enumerate(dummy_data):
    for col, pred_col in enumerate(dummy_predicted):
        ax = axes[row][col]
        ax.scatter(dummy_data[data_col], dummy_predicted[pred_col])
        ax.set_title(f'{data_col} vs {pred_col}')
        ax.set_xlabel(data_col)
        ax.set_ylabel(pred_col)

plt.tight_layout() 
plt.show()
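
One thing to watch with two-level indexing such as axes[row][col]: by default plt.subplots squeezes out dimensions of length 1, so a grid with a single row or a single column comes back as a 1D array. A small sketch of the difference (the Agg backend is an assumption, chosen only so this runs without a display):

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend so no window is needed
import matplotlib.pyplot as plt

# default behaviour: a 1 x 4 grid is squeezed to a 1D array of 4 axes
fig_squeezed, squeezed = plt.subplots(1, 4)

# squeeze=False always returns a 2D array, so axes[row][col] keeps working
fig_grid, grid = plt.subplots(1, 4, squeeze=False)

print(squeezed.shape)  # (4,)
print(grid.shape)      # (1, 4)
```

Passing squeeze=False is therefore the safer choice when the number of columns in either DataFrame is not known in advance.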


Quang Hoang