
Is there a way to draw a grid of scatter plots, one per DataFrame column, where Y is always the same column of that DataFrame?

I can do it with a for loop in either matplotlib or seaborn (see the code below), but I can't make the plots show up in a grid.

I want them displayed in a grid so that they are easier to compare.

This is what I CAN do:

for col in boston_df:
    plt.scatter(boston_df[col], boston_df["MEDV"], c="red", label=col)
    plt.ylabel("medv")
    plt.legend()
    plt.show()

or

for col in boston_df:
    sns.regplot(x=boston_df[col], y=boston_df["MEDV"])
    plt.show()

But if I create subplots, for example, and call ax.scatter() in my loop like this:

fig, ax = plt.subplots(3, 5,figsize=(16,6))
for col in boston_df:
    ax.scatter(boston_df[col], boston_df["MEDV"], c="red", label=col)
    plt.ylabel("medv")
    plt.legend()
    plt.show()

I get the error AttributeError: 'numpy.ndarray' object has no attribute 'scatter'.
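For context, the error arises because `plt.subplots(3, 5)` returns the Axes as a 3 x 5 NumPy array rather than a single Axes object, and the array itself has no `scatter` method. A minimal sketch that shows this without any data (the `matplotlib.use("Agg")` line is only there so it runs without a display):

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so this sketch runs headless
import matplotlib.pyplot as plt
import numpy as np

# With more than one row or column, plt.subplots returns an array of Axes
fig, ax = plt.subplots(3, 5, figsize=(16, 6))

print(isinstance(ax, np.ndarray))    # True
print(ax.shape)                      # (3, 5)

# The array has no plotting methods; each element (a single Axes) does
print(hasattr(ax, "scatter"))        # False
print(hasattr(ax[0, 0], "scatter"))  # True
plt.close(fig)
```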

Ideally I'd find a solution as simple as this:

df.hist(figsize=(18,10), density=True, label=df.columns)
plt.show()
Matthew Barlowe

1 Answer


Consider using the ax argument of pandas' DataFrame.plot or seaborn's regplot to draw each plot onto a specific subplot:

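import matplotlib.pyplot as plt
import seaborn as sns

# one subplot per plotted column: ncols (here 5) must cover every column in the loop
fig, ax = plt.subplots(1, 5, figsize=(16,6))

# columns[1:] assumes MEDV is the FIRST column and skips it
for i, col in enumerate(boston_df.columns[1:]):
    #boston_df.plot(kind='scatter', x=col, y='MEDV', ax=ax[i])
    sns.regplot(x=boston_df[col], y=boston_df["MEDV"], ax=ax[i])

fig.suptitle('My Scatter Plots')
fig.tight_layout()
fig.subplots_adjust(top=0.95)      # TO ACCOMMODATE TITLE

plt.show()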
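The commented-out line in the loop above is the pandas route. A minimal sketch of it, assuming a hypothetical `demo_df` of random numbers with `MEDV` plus five feature columns (the `matplotlib.use("Agg")` line only makes it run headless; drop it for an interactive window):

```python
import matplotlib
matplotlib.use("Agg")  # headless backend for this sketch only
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

# Hypothetical stand-in for boston_df: MEDV plus five feature columns
np.random.seed(6012019)
demo_df = pd.DataFrame(np.random.randn(50, 6),
                       columns=["MEDV", "COL1", "COL2", "COL3", "COL4", "COL5"])

features = demo_df.columns.drop("MEDV")   # every column except the target
fig, ax = plt.subplots(1, len(features), figsize=(16, 6))

for i, col in enumerate(features):
    # pandas draws onto whichever Axes object is passed via ax=
    demo_df.plot(kind="scatter", x=col, y="MEDV", ax=ax[i])

fig.suptitle("My Scatter Plots")
fig.tight_layout()
fig.subplots_adjust(top=0.95)
plt.show()
```

Dropping `MEDV` by name instead of slicing `columns[1:]` means the target does not have to be the first column.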

To demonstrate with random data:

import numpy as np
import pandas as pd

import matplotlib.pyplot as plt
import seaborn as sns

### DATA BUILD
np.random.seed(6012019)
random_df = pd.DataFrame(np.random.randn(50,6), 
                         columns = ['MEDV', 'COL1', 'COL2', 'COL3', 'COL4', 'COL5'])

### PLOT BUILD
fig, ax = plt.subplots(1, 5, figsize=(16,6))

for i, col in enumerate(random_df.columns[1:]):
    #random_df.plot(kind='scatter', x=col, y='MEDV', ax=ax[i])
    sns.regplot(x=random_df[col], y=random_df["MEDV"], ax=ax[i])

fig.suptitle('My Scatter Plots')
fig.tight_layout()
fig.subplots_adjust(top=0.95)

plt.show()
plt.clf()
plt.close()

[Image: Plot Output]

For a grid with multiple rows and multiple columns, plt.subplots returns ax as a 2-D NumPy array of Axes, so each plot must be assigned with two indexes: ax[row_idx, col_idx].
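Instead of one if/else branch per row, a flat loop counter can be mapped to a grid position with Python's built-in `divmod`, which works for any number of rows and columns (for example the 3 x 5 layout asked about in the comments). A small sketch; `grid_position` is a hypothetical helper name, not part of matplotlib:

```python
def grid_position(i, ncols):
    """Return (row_idx, col_idx) for the i-th plot in a grid with ncols columns."""
    return divmod(i, ncols)   # same as (i // ncols, i % ncols)

# For a 3 x 5 grid, the 13 feature columns land in these slots:
positions = [grid_position(i, 5) for i in range(13)]
print(positions[:6])          # [(0, 0), (0, 1), (0, 2), (0, 3), (0, 4), (1, 0)]
print(positions[-1])          # (2, 2)

# In the 2 x 7 example, i = 7 maps to row 1, col 0, i.e. ax[1, i-7]
print(grid_position(7, 7))    # (1, 0)
```

Inside the plotting loop this becomes `r, c = divmod(i, ncols)` followed by `ax=ax[r, c]`, with no per-row branches.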

import numpy as np
import pandas as pd

import matplotlib.pyplot as plt
import seaborn as sns

### DATA BUILD
np.random.seed(6012019)
random_df = pd.DataFrame(np.random.randn(50,14),
                         columns = ['MEDV', 'COL1', 'COL2', 'COL3', 'COL4',
                                    'COL5', 'COL6', 'COL7', 'COL8', 'COL9',
                                    'COL10', 'COL11', 'COL12', 'COL13'])

### PLOT BUILD
fig, ax = plt.subplots(2, 7, figsize=(16,6))

for i, col in enumerate(random_df.columns[1:]):
    #random_df.plot(kind='scatter', x=col, y='MEDV', ax=ax[i // 7, i % 7])
    if i <= 6:                                  # FIRST 7 FEATURES -> ROW 0
        sns.regplot(x=random_df[col], y=random_df["MEDV"], ax=ax[0,i])
    else:                                       # REMAINING 6 FEATURES -> ROW 1
        sns.regplot(x=random_df[col], y=random_df["MEDV"], ax=ax[1,i-7])

ax[1,6].axis('off')                  # HIDES THE UNUSED 14TH SUBPLOT (LAST ROW, LAST COL)

fig.suptitle('My Scatter Plots')
fig.tight_layout()
fig.subplots_adjust(top=0.95)

plt.show()
plt.clf()
plt.close()

[Image: Multiple Rows Subplots]
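A more general variant (a sketch, not the answer's original code) flattens the 2-D `ax` array with NumPy's `ndarray.flatten()`, so a single loop fills any `nrows x ncols` grid, and then hides whatever slots are left over. The 3 x 5 shape and the figure size below are arbitrary choices; `matplotlib.use("Agg")` only makes it run headless:

```python
import matplotlib
matplotlib.use("Agg")  # headless backend for this sketch; remove for an interactive window
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

# Random stand-in data: MEDV plus 13 feature columns, as in the example above
np.random.seed(6012019)
cols = ["MEDV"] + [f"COL{n}" for n in range(1, 14)]
random_df = pd.DataFrame(np.random.randn(50, 14), columns=cols)

features = random_df.columns.drop("MEDV")
nrows, ncols = 3, 5                  # any shape with nrows * ncols >= len(features)
fig, ax = plt.subplots(nrows, ncols, figsize=(16, 9))
axes = ax.flatten()                  # 2-D array of Axes -> 1-D, row by row

# zip stops at the shorter sequence, so extra grid slots are simply skipped
for col, a in zip(features, axes):
    sns.regplot(x=random_df[col], y=random_df["MEDV"], ax=a)

# Hide the slots that received no plot (15 slots - 13 features = 2 here)
for a in axes[len(features):]:
    a.axis("off")

fig.suptitle("My Scatter Plots")
fig.tight_layout()
fig.subplots_adjust(top=0.95)
plt.show()
```

Changing only `nrows, ncols` (for example to 2, 7 or 4, 4) is enough; no index arithmetic or extra `elif` branches are needed.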

Parfait
  • This code gives me scatter plots only for the first 5 columns of the data (skipping the first one, which is the target), but I have 13 columns. (It also throws `IndexError: index 5 is out of bounds for axis 0 with size 5`, although it still displays the plots.) If I try 2 rows by changing the call to `fig, ax = plt.subplots(2, 5, figsize=(16,6))`, I get the error `AttributeError: 'numpy.ndarray' object has no attribute 'scatter'` – Giovanna Fernandes Jun 02 '19 at 16:07
  • You need to size `subplots` to accommodate all your columns; that is what causes the `IndexError`. This solution assumes `MEDV` is the first column and that you do NOT want to scatter-plot it against itself. Simply change *ncol* from 5 to 13 in `plt.subplots()`. For multiple rows, see the extended answer, where the assignment to `ax` must be adjusted. – Parfait Jun 02 '19 at 16:33
  • It works with 2, 7, but I can't make it work with 3, 5. Being a beginner, I'm sure I just don't understand the code well enough to make the adjustments myself. I'll keep trying; I probably have to add an elif if I want 3 rows? In any case, the solution above with 2, 7 did work! – Giovanna Fernandes Jun 03 '19 at 17:41
  • Really? There is nothing preventing this from working in either version. I never worked in Python 2! Maybe your libraries or environments are mixed. In future, post actual data sample for us to help. Happy coding! – Parfait Jun 03 '19 at 17:48
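Another option, closer to the one-liner the question hoped for, is seaborn's figure-level `lmplot`, which wraps facets into a grid by itself once the data is reshaped to long format with `DataFrame.melt`. This swaps in a different technique from the answer above; `facet_kws` assumes seaborn 0.11 or newer, and the random data is a stand-in for the real DataFrame:

```python
import matplotlib
matplotlib.use("Agg")  # headless backend for this sketch only
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

# Random stand-in data: MEDV plus 13 feature columns
np.random.seed(6012019)
cols = ["MEDV"] + [f"COL{n}" for n in range(1, 14)]
random_df = pd.DataFrame(np.random.randn(50, 14), columns=cols)

# Long format: one row per (MEDV, feature name, feature value)
long_df = random_df.melt(id_vars="MEDV", var_name="feature", value_name="value")

# One regression panel per feature, wrapped into rows of 5;
# sharex=False gives each panel its own x-axis range
g = sns.lmplot(data=long_df, x="value", y="MEDV", col="feature", col_wrap=5,
               height=3, facet_kws={"sharex": False})
plt.show()
```

`col_wrap` plays the role of the grid shape here, so no manual indexing into `ax` is needed.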