I am trying to create subplots inside a for loop for several columns of a dataset. I am using the California housing dataset from sklearn. I have selected 4 columns, and for each one I want to display three plots (scatter plot, histogram, and Q-Q plot) in a grid of subplots. The code I have tried is below. Can somebody help me with this? Can it also be made dynamic, so that more plots, each with its own title, can be added easily?

from sklearn.datasets import fetch_california_housing
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

california_housing = fetch_california_housing(as_frame=True)
# california_housing.frame.head()
features_of_interest = ["AveRooms", "AveBedrms", "AveOccup", "Population"]
california_housing.frame[features_of_interest]

fig, axes = plt.subplots(4, 3)

for cols in features_of_interest:
    # scatterplot
    sns.scatterplot(x=california_housing.frame[cols], y=california_housing.target)
    # histogram
    sns.histplot(x=california_housing.frame[cols], y=california_housing.target)
    #qqplot
    sm.qqplot(california_housing.frame[cols], line='45')
    plt.show()

Bad Coder
  • If my post did not answer your question, let me know what you are looking for! I can make some improvements and you can accept it! :-) – GregOliveira Aug 20 '22 at 16:35

1 Answer

There are a few problems with your code:

  • You need to import statsmodels.api as sm; sm is used for the Q-Q plot but never imported.

  • You need to pass the ax parameter to scatterplot, histplot, and qqplot to tell each call which subplot to draw on. Without it, the axes created by plt.subplots are never used.

  • I changed how the data is loaded: .frame is taken once, so the features and the target column (MedHouseVal) are read from the same DataFrame.

  • You do not need to call plt.show() on each iteration; call it once, after the loop.

from sklearn.datasets import fetch_california_housing
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import statsmodels.api as sm

california_housing = fetch_california_housing(as_frame=True).frame
features_of_interest = ["AveRooms", "AveBedrms", "AveOccup", "Population"]

fig, axes = plt.subplots(len(features_of_interest), 3)

for i, cols in enumerate(features_of_interest):
    # scatterplot
    sns.scatterplot(x=california_housing[cols], y=california_housing['MedHouseVal'], ax=axes[i,0])
    # histogram
    sns.histplot(x=california_housing[cols], y=california_housing['MedHouseVal'], ax=axes[i,1])
    # Q-Q plot
    sm.qqplot(california_housing[cols], line='45', ax=axes[i,2])

plt.show()

PS: I used len(features_of_interest) for the number of rows, so the grid adapts automatically when features are added or removed.
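The question also asks how to add more plot types, each with a title, without rewriting the loop. One possible approach, sketched here under stated assumptions, is to keep a list of (title, function) pairs and let its length set the number of columns. The names PLOTS, plot_grid, scatter, histogram, and normal_qq are hypothetical and not part of seaborn or statsmodels. The plotting functions use plain matplotlib, with a hand-rolled normal Q-Q plot, and the DataFrame is synthetic stand-in data rather than the real California housing values, so that the sketch runs without extra dependencies or a download.

```python
from statistics import NormalDist

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd


# Each plotter takes (ax, x, y) and draws on the given Axes.
def scatter(ax, x, y):
    ax.scatter(x, y, s=5)


def histogram(ax, x, y):
    ax.hist(x, bins=30)  # y is unused; kept so all plotters share one signature


def normal_qq(ax, x, y):
    # Standardized sample quantiles against standard normal quantiles.
    z = np.sort((x - x.mean()) / x.std())
    n = len(z)
    probs = (np.arange(1, n + 1) - 0.5) / n
    theo = np.array([NormalDist().inv_cdf(p) for p in probs])
    ax.scatter(theo, z, s=5)
    ax.plot(theo, theo, color="red")  # 45-degree reference line


# Hypothetical registry: add a (title, function) pair to get a new column.
PLOTS = [("Scatter", scatter), ("Histogram", histogram), ("Normal Q-Q", normal_qq)]


def plot_grid(df, features, target, plots=PLOTS):
    # One row per feature, one column per registered plot type.
    fig, axes = plt.subplots(
        len(features), len(plots),
        figsize=(4 * len(plots), 3 * len(features)),
        squeeze=False,  # keep axes 2-D even for a single row or column
    )
    for i, col in enumerate(features):
        for j, (title, draw) in enumerate(plots):
            draw(axes[i, j], df[col], df[target])
            axes[i, j].set_title(f"{col}: {title}")
    fig.tight_layout()
    return fig, axes


# Synthetic stand-in data (an assumption, not the real dataset).
rng = np.random.default_rng(0)
demo = pd.DataFrame({
    "AveRooms": rng.normal(5, 1, 200),
    "Population": rng.exponential(1000, 200),
    "MedHouseVal": rng.uniform(0.5, 5, 200),
})
fig, axes = plot_grid(demo, ["AveRooms", "Population"], "MedHouseVal")
```

To use it with the real data, pass the California housing DataFrame and features_of_interest instead of the stand-in, then call plt.show() once at the end. The seaborn and statsmodels calls from the code above should fit the same pattern if wrapped in a function that forwards ax, for example one that calls sns.scatterplot(x=x, y=y, ax=ax).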

GregOliveira