1

I'm working on a script that receives several inputs, parses the data and calls a plotting function several times, according to the number of nodes.

The issue is that calling my plotting function multiple times (see code below) triggers a warning, and I don't know how to get rid of it. I saw this solution, but it doesn't really match my case (or I don't know how to apply it to my case).

import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np

sns.set(style="whitegrid")
fig, (ax1, ax2, ax3, ax4) = plt.subplots(nrows=1, ncols=4, figsize=(16, 4))
plt.tight_layout()


def plot_data(df, nodes):
  global ax1, ax2, ax3, ax4
  if nodes == 10:
    plt.subplot(141)
    ax1 = sns.kdeplot(df['Metric'], cumulative=True, legend=False)
    ax1.set_ylabel('ECDF', fontsize = 16)
    ax1.set_title('10 Nodes')

  elif nodes == 20:
    plt.subplot(142)
    ax2 = sns.kdeplot(df['Metric'], cumulative=True, legend=False)
    plt.setp(ax2.get_yticklabels(), visible=False)
    ax2.set_title('20 Nodes')

  elif nodes == 30:
    plt.subplot(143)
    ax3 = sns.kdeplot(df['Metric'], cumulative=True, legend=False)
    plt.setp(ax3.get_yticklabels(), visible=False)
    ax3.set_title('30 Nodes')

  elif nodes == 40:
    plt.subplot(144)
    ax4 = sns.kdeplot(df['Metric'], cumulative=True, legend=False)
    plt.setp(ax4.get_yticklabels(), visible=False)
    ax4.set_title('40 Nodes')


df1 = pd.DataFrame({'Metric':np.random.randint(0, 15, 1000)})    
df2 = pd.DataFrame({'Metric':np.random.randint(0, 15, 1000)})    
df3 = pd.DataFrame({'Metric':np.random.randint(0, 15, 1000)})    

nodes = [10, 20, 30, 40]
for i in range(4):
  """
  In my real code, the DataFrames are calculated from reading CSV files.
  Since that code would be too long, I'm using dummy data. 
  """
  plot_data(df1, nodes[i])
  # I understand that these calls cause the warning,
  # but I don't know how to solve it
  plot_data(df2, nodes[i])
  plot_data(df3, nodes[i])
plt.show()  
DavidG
  • I don't really understand why you're using multiple `if`s rather than `if`/`elif`, and also why you're not passing the axarr object from subplots into the function (with the nodes list) and carrying out the loop in the function rather than calling the function 3 times? – Andrew Dec 06 '18 at 15:11

2 Answers

1

You need to remove the plt.subplot(nnn) calls. As the warning says, calling it with the same arguments as an existing axes currently reuses that axes instance, but in future matplotlib versions it will create a new axes instance.

The solution is to pass the axes you have created as an argument to your function and use the ax= argument of seaborn.kdeplot:

sns.set(style="whitegrid")
fig, axes = plt.subplots(nrows=1, ncols=4, figsize=(16, 4))
plt.tight_layout()

def plot_data(df, nodes, axes):
    ax1, ax2, ax3, ax4 = axes
    if nodes == 10:
        sns.kdeplot(df['Metric'], cumulative=True, legend=False, ax=ax1)
        ax1.set_ylabel('ECDF', fontsize = 16)
        ax1.set_title('10 Nodes')
    elif nodes == 20:
        sns.kdeplot(df['Metric'], cumulative=True, legend=False, ax=ax2)
        plt.setp(ax2.get_yticklabels(), visible=False)
        ax2.set_title('20 Nodes')
    elif nodes == 30:
        sns.kdeplot(df['Metric'], cumulative=True, legend=False, ax=ax3)
        plt.setp(ax3.get_yticklabels(), visible=False)
        ax3.set_title('30 Nodes')
    else:
        sns.kdeplot(df['Metric'], cumulative=True, legend=False, ax=ax4)
        plt.setp(ax4.get_yticklabels(), visible=False)
        ax4.set_title('40 Nodes')

for i in range(4):
    plot_data(df1, nodes[i], axes)
    plot_data(df2, nodes[i], axes)
    plot_data(df3, nodes[i], axes)
plt.show()


Note that you could simplify the above by passing sharey=True when creating the figure, i.e. fig, axes = plt.subplots(…, sharey=True), and removing the plt.setp(ax.get_yticklabels(), visible=False) calls.

DavidG
  • Note that the warning is there because `pyplot.subplot` currently uses the same code path as `.add_subplot`. The warning only applies to the latter, because in pyplot activating a subplot after creation is of course useful. So I guess the solution here is still the recommended way to do it, but using pyplot.subplot is equally possible. – ImportanceOfBeingErnest Dec 06 '18 at 19:56
0

This should do what you need, I think. It's just a case of passing the axes as arguments and then putting the loop into the function:

import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np

sns.set(style="whitegrid")
fig, axarr = plt.subplots(nrows=1, ncols=4, figsize=(16, 4))
plt.tight_layout()

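nodes = [10, 20, 30, 40]

# Dummy data standing in for the DataFrames read from CSV files in the question
df1 = pd.DataFrame({'Metric': np.random.randint(0, 15, 1000)})
df2 = pd.DataFrame({'Metric': np.random.randint(0, 15, 1000)})
df3 = pd.DataFrame({'Metric': np.random.randint(0, 15, 1000)})

def plot_data(list_of_dfs, axarr, nodes):
    # One panel per node count; every DataFrame is drawn on each panel
    for ax, node in zip(axarr, nodes):
        for df in list_of_dfs:
            # ax=ax draws on the given axes rather than the current one
            sns.kdeplot(df['Metric'], cumulative=True, legend=False, ax=ax)
        ax.set_ylabel('ECDF', fontsize=16)
        ax.set_title('{} Nodes'.format(node))

list_of_dfs = [df1, df2, df3]
plot_data(list_of_dfs, axarr, nodes)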
plt.show()  
Andrew