
I have a training data set and a test data set with the same categorical columns. Currently, I loop over the categorical columns of each data set separately, producing one figure of countplot subplots per data set, as follows:

plt.figure(figsize=(20,20))
for i, col in enumerate(cat_features):
    plt.subplot(5,2,i+1)
    sns.countplot(x=col,data=train, order=('A','B','C','D','E','F','G','H','I','J','K','L','N'))
plt.tight_layout()

which produces a nice figure like this (note: for the sake of space, I cropped it to show only the first four subplots):

What I want to be able to do is a side-by-side comparison between Train and Test: one set of subplots where the countplot for cat0 Train sits next to the countplot for cat0 Test, then the countplot for cat1 Train next to cat1 Test, and so on.

Train data looks like (small subset):

cat0    cat1    cat2    cat3    cat4    cat5    cat6    cat7    cat8
A       B       A       A       B       D       A       E       C
B       A       A       A       B       B       A       E       A
A       A       A       C       B       D       A       B       C
A       A       A       C       B       D       A       E       G
A       B       A       A       B       B       A       E       C

Test data looks like (small subset):

cat0    cat1    cat2    cat3    cat4    cat5    cat6    cat7    cat8
A       B       A       C       B       D       A       E       E
A       B       A       C       B       D       A       E       C
A       B       A       C       B       D       A       E       C
A       A       B       A       B       D       A       E       E
A       B       A       A       B       B       A       E       E
Ncosgove
  • My apologies, I noted a typo. It should be "where catplot for Cat0 Train is side by side with Cat0 Test, then subplot catplot for Cat1 Train is next to Cat1 Test, etc., etc." – Ncosgove Feb 07 '21 at 19:00
  • Please edit your question if you want to add information. – Mr. T Feb 07 '21 at 19:57

1 Answer


It's hard to know without some sample data, but you can create the four plots as below, then loop through the axes together with the datasets in the desired order, plotting each one to the relevant axis.

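import matplotlib.pyplot as plt
import seaborn as sns

# 2x2 grid: one row per feature, train on the left and test on the right
fig, axes = plt.subplots(ncols=2, nrows=2)

# Repeat each of the first two columns so it pairs with train, then test
columns = [col for col in cat_features[:2] for _ in range(2)]

for ax, dataset, col in zip(axes.flatten(), [train, test, train, test], columns):
    sns.countplot(
        data=dataset,
        x=col,  # a single column name, not the whole cat_features list
        order=('A','B','C','D','E','F','G','H','I','J','K','L','N'),
        ax=ax)

plt.show()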
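
To cover every column in cat_features rather than a fixed 2x2 grid, the same idea can be written as one row of subplots per column. This is only a minimal sketch, not checked against the asker's real data. The helper name plot_train_test_counts is made up for illustration, the toy frames are copied from the first three columns of the samples in the question (assuming the second sample is the test set), and the order is shortened to the two levels those columns contain.

```python
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns


def plot_train_test_counts(train, test, cat_features, order=None):
    """Draw one row per column: train counts on the left, test on the right."""
    fig, axes = plt.subplots(
        nrows=len(cat_features),
        ncols=2,
        figsize=(12, 3 * len(cat_features)),
        squeeze=False,  # keep axes 2-D even when there is a single feature
    )
    for row_axes, col in zip(axes, cat_features):
        pairs = (("Train", train), ("Test", test))
        for ax, (name, dataset) in zip(row_axes, pairs):
            sns.countplot(data=dataset, x=col, order=order, ax=ax)
            ax.set_title(f"{col} ({name})")
    fig.tight_layout()
    return fig, axes


# Toy frames (assumption: first three columns of the question's samples)
train = pd.DataFrame({
    "cat0": ["A", "B", "A", "A", "A"],
    "cat1": ["B", "A", "A", "A", "B"],
    "cat2": ["A", "A", "A", "A", "A"],
})
test = pd.DataFrame({
    "cat0": ["A", "A", "A", "A", "A"],
    "cat1": ["B", "B", "B", "A", "B"],
    "cat2": ["A", "A", "A", "B", "A"],
})
cat_features = ["cat0", "cat1", "cat2"]

fig, axes = plot_train_test_counts(train, test, cat_features, order=["A", "B"])
```

Call plt.show() afterwards to display the figure. With the real data you would pass the full order tuple from the question. If train and test differ a lot in size, passing sharey="row" to plt.subplots may make the paired bars easier to compare, though that is a matter of taste.
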
mullinscr