The issue
I have a dataframe with two columns: a color (named category in the code below) and values.
The color can be red, yellow, green or black. For each of these 4 colors, I need to plot a histogram with the distribution of "values". I would like the plots to be arranged in a 2x2 grid.
What I would like to do
I would like to automate the creation of the plots - with the FacetGrid function or some equivalent.
In the examples I have seen, histograms are faceted by subsets of the data, with one variable over the columns and one over the rows. E.g. here: https://seaborn.pydata.org/examples/faceted_histogram.html there are 3 species, 2 sexes, and 6 charts.
What I have tried that doesn't work
I have tried the FacetGrid function, but it produces 4 identical charts. It is clear I am doing something wrong when I define the FacetGrid or call the map() method, but I'm not sure what.
import numpy as np
import pandas as pd
import matplotlib
import matplotlib.pyplot as plt
import seaborn as sns
df=pd.DataFrame()
df['values'] = np.hstack([np.ones(25),
                          np.arange(200,225),
                          np.ones(25) *5,
                          np.linspace(500,550,25)])
df['category'] = np.repeat(['red','yellow','green','black'],25)
fig = sns.FacetGrid(df, col='category', col_wrap=2)
fig.map(sns.histplot, data = df, x='values', stat='density' )
In my data, there is only one categorical variable, not 2. This categorical variable can take 4 values and I'd like the plots in a 2x2 grid.
What I have got to work
I can create a figure with 2x2 subplots and manually populate each of the 4 axes. It works, but it's neither elegant nor efficient.
import numpy as np
import pandas as pd
import matplotlib
import matplotlib.pyplot as plt
import seaborn as sns
df=pd.DataFrame()
df['values'] = np.hstack([np.ones(25),
                          np.arange(200,225),
                          np.ones(25) *5,
                          np.linspace(500,550,25)])
df['category'] = np.repeat(['red','yellow','green','black'],25)
fig2, ax = plt.subplots(2,2)
categories = np.array([['red','yellow'],['green', 'black']])
for r in range(2):
    for c in range(2):
        cat = categories[r,c]
        print(cat)
        sns.histplot(df.query("category == @cat"), x="values", kde= True, stat='density', bins =10, ax = ax[r,c])
plt.tight_layout()
plt.show()