I'm trying (and failing) to create a nested boxplot from a 3-dimensional numpy array, for example A = np.random.uniform(size=(4, 100, 2)).

The kind of plot I mean is shown in the picture below, which comes from the seaborn boxplot docs.

nested boxplot

JohanC
deppep

1 Answer

You can use np.meshgrid() to generate 3 arrays which index the 3D array. Flattening these arrays with ravel() makes them suitable as input for seaborn. Optionally, the flattened arrays can be converted to a dataframe, which helps in automatically generating labels.

import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import numpy as np

A = np.random.normal(0.02, 1, size=(4, 100, 2)).reshape(-1).cumsum().reshape(4, -1, 2)
x_names = ['A', 'B', 'C', 'D']
hue_names = ['x', 'y']
dim1, dim2, dim3 = np.meshgrid(x_names, np.arange(A.shape[1]), hue_names, indexing='ij')
sns.boxplot(x=dim1.ravel(), y=A.ravel(), hue=dim3.ravel())

plt.tight_layout()
plt.show()

boxplot from 3d array
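
To see why the flattened index arrays line up with A.ravel(), here is a small check on a toy array. The 2x3x2 shape and the shortened name lists are made up for illustration only. Note that indexing='ij' matters: with the default indexing='xy', np.meshgrid() would swap the first two axes and the labels would no longer match the values.

```python
import numpy as np

# Toy array, small enough to inspect by hand (shape chosen only for illustration)
A = np.arange(12).reshape(2, 3, 2)
x_names = ['A', 'B']
hue_names = ['x', 'y']

# indexing='ij' keeps the output shape equal to A.shape, so element [i, j, k]
# of each index array describes element [i, j, k] of A
dim1, dim2, dim3 = np.meshgrid(x_names, np.arange(A.shape[1]), hue_names, indexing='ij')
assert dim1.shape == dim2.shape == dim3.shape == A.shape

# After ravel(), position n in every array refers to the same original element
flat = list(zip(dim1.ravel(), dim2.ravel(), dim3.ravel(), A.ravel()))
for i_name, j, k_name, value in flat:
    i = x_names.index(i_name)
    k = hue_names.index(k_name)
    assert A[i, j, k] == value
```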

To create a dataframe, the code could look like the following. Note that the numeric second dimension isn't needed explicitly for the boxplot.

df = pd.DataFrame({'dim1': dim1.ravel(),
                   'dim2': dim2.ravel(),
                   'dim3': dim3.ravel(),
                   'A': A.ravel()})

# sanity checks that the 3D array was flattened in the expected order
# (the values live in column 'A'; np.isclose avoids spurious failures from float rounding)
assert np.isclose(A[0, :, 0].sum(), df[(df['dim1'] == 'A') & (df['dim3'] == 'x')]['A'].sum())
assert np.isclose(A[2, :, 1].sum(), df[(df['dim1'] == 'C') & (df['dim3'] == 'y')]['A'].sum())

sns.boxplot(data=df, x='dim1', y='A', hue='dim3')
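
As an alternative sketch (not the approach used above), the same long-form dataframe can probably be built without meshgrid by using pd.MultiIndex.from_product(), whose first level varies slowest and therefore matches the order of A.ravel(). The name df_alt is just illustrative.

```python
import numpy as np
import pandas as pd

A = np.random.normal(0.02, 1, size=(4, 100, 2))
x_names = ['A', 'B', 'C', 'D']
hue_names = ['x', 'y']

# The product iterates with the last level fastest, like C-order ravel()
index = pd.MultiIndex.from_product(
    [x_names, np.arange(A.shape[1]), hue_names],
    names=['dim1', 'dim2', 'dim3'])

# reset_index() turns the three index levels into ordinary columns
df_alt = pd.DataFrame({'A': A.ravel()}, index=index).reset_index()

# same kind of sanity check as for the meshgrid version
mask = (df_alt['dim1'] == 'C') & (df_alt['dim3'] == 'y')
assert np.isclose(A[2, :, 1].sum(), df_alt.loc[mask, 'A'].sum())
```

The resulting df_alt can be passed to sns.boxplot(data=df_alt, x='dim1', y='A', hue='dim3') in the same way.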

If the array or the names are very long, working with purely numeric index arrays uses less memory and can speed things up; the labels are then set on the plot afterwards:

import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np

A = np.random.normal(0.01, 1, size=(10, 1000, 2)).reshape(-1).cumsum().reshape(10, -1, 2)
dim1, dim2, dim3 = np.meshgrid(np.arange(A.shape[0]), np.arange(A.shape[1]), np.arange(A.shape[2]), indexing='ij')
sns.set_style('whitegrid')
ax = sns.boxplot(x=dim1.ravel(), y=A.ravel(), hue=dim3.ravel(), palette='spring')
ax.set_xticklabels(["alpha", "beta", "gamma", "delta", "epsilon", "zeta", "eta", "theta", "iota", "kappa"])
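# get_legend_handles_labels() avoids ax.legend_.legendHandles, which was renamed in newer matplotlib versions
handles, _ = ax.get_legend_handles_labels()
ax.legend(handles=handles, labels=['2019-2020', '2020-2021'], title='Year')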
sns.despine()
plt.tight_layout()
plt.show()

longer example for sns.boxplot from a 3d array
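
A rough way to check the memory argument is to compare nbytes of a string index array against an integer one. The labels below are hypothetical 11-character names. NumPy stores unicode strings with 4 bytes per character, so the saving only shows up for longer labels; single-character names are already small.

```python
import numpy as np

shape = (10, 1000, 2)
# hypothetical long labels, 11 characters each
long_names = [f"category_{i:02d}" for i in range(shape[0])]

# index array for the first dimension, once with strings and once with integers
str_dim1, _, _ = np.meshgrid(long_names, np.arange(shape[1]), np.arange(shape[2]), indexing='ij')
int_dim1, _, _ = np.meshgrid(np.arange(shape[0]), np.arange(shape[1]), np.arange(shape[2]), indexing='ij')

print(str_dim1.dtype, str_dim1.nbytes)
print(int_dim1.dtype, int_dim1.nbytes)
```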

JohanC