In a Jupyter Notebook I am visualizing the Iris dataset with seaborn in combination with ipywidgets. That works fine, except that it is not very fast, because the plots have to be re-rendered every time you select a new combination of the species 'versicolor', 'virginica' and 'setosa'. See the first code block.

So I tried to speed up the interaction by building the plot for each combination of species in advance and storing the plots in a dictionary. See the second code block. The dictionary seems to contain all the plots, but they are not displayed.

Any suggestions on how to fix this?

First code block:

import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from ipywidgets import *

sns.set(style="white")
iris = sns.load_dataset("iris")

def iris_pg(species):
    g = sns.PairGrid(iris[iris.species.isin(species)], diag_sharey=False)
    g.map_lower(sns.kdeplot)
    g.map_upper(sns.scatterplot)
    g.map_diag(sns.kdeplot, lw=3)
    return plt.show()

interact(iris_pg,
         species = widgets.SelectMultiple(options=iris.species.unique(),
                                          value=tuple(iris.species.unique()[-2:]),
                                          rows=len(iris.species.unique()),
                                          description='species',
                                          disabled=False))

Second code block:

import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from ipywidgets import *
from itertools import combinations

sns.set(style="white")
iris = sns.load_dataset("iris")

all_combinations = list()
for i in range(1, len(iris.species.unique()) + 1):
    for combi in combinations(iris.species.unique(), i):
        all_combinations.append(combi)

all_plots = dict()
for i in all_combinations:
    all_plots[i] = sns.PairGrid(iris[iris.species.isin(i)], diag_sharey=False)
    all_plots[i].map_lower(sns.kdeplot)
    all_plots[i].map_upper(sns.scatterplot)
    all_plots[i].map_diag(sns.kdeplot, lw=3)

def iris_pg(species):
    all_plots[species]
    return plt.show()

options = iris.species.unique()
value = tuple(iris.species.unique()[-2:])
rows = len(iris.species.unique())

interact(iris_pg,
         species = widgets.SelectMultiple(options=options,
                                          value=value,
                                          rows=rows,
                                          description='species',
                                          disabled=False))
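
One detail of the lookup in the second block: the dictionary keys are tuples in the order produced by `itertools.combinations`, so `all_plots[species]` only finds a plot when the widget hands back the selection as a tuple in that same order, and an empty selection has no key at all. Below is a minimal sketch, using only the standard library, of how the keys are built and how a selection could be normalized before the lookup. The helpers `all_subsets` and `lookup_key` are hypothetical names, not part of the original code, and whether the widget can actually return a differently ordered tuple is an assumption I have not verified.

```python
from itertools import combinations

def all_subsets(labels):
    """Return every non-empty combination of labels, as tuples in label order."""
    return [combo
            for size in range(1, len(labels) + 1)
            for combo in combinations(labels, size)]

def lookup_key(selection, labels):
    """Reorder a selection so it matches the tuple keys from all_subsets."""
    chosen = set(selection)
    return tuple(label for label in labels if label in chosen)

species = ["setosa", "versicolor", "virginica"]
keys = all_subsets(species)

print(len(keys))                                     # 7 keys for 3 species
print(lookup_key(("virginica", "setosa"), species))  # ('setosa', 'virginica')
print(lookup_key((), species) in keys)               # False: empty selection has no plot
```

With something like this, the callback could look up `all_plots[lookup_key(species, options)]` and return nothing when the key is missing, instead of raising a `KeyError`.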
René
  • I suppose the bottleneck is not the data aggregation, but the plot creation. So even if the second solution worked, it would probably not be significantly faster. Unfortunately, `PairGrid` needs to create new figures, so the only option would then be to replicate the figure with pure matplotlib and fill it with kdeplots/scatterplots on the fly. This would also require an interactive backend to be used, like `%matplotlib notebook`. – ImportanceOfBeingErnest Jul 26 '18 at 21:01
  • Thanks for your comment. The answer below gives me the performance improvement I was looking for. – René Jul 27 '18 at 09:13

1 Answer

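Based on the answer to this question, here is a solution that makes the interaction faster. Compared with the second code block in the question, there are two changes: each figure is closed right after it is built, and the callback returns the stored figure instead of calling `plt.show()`.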

import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from ipywidgets import *
from itertools import combinations

sns.set(style="white")
iris = sns.load_dataset("iris")

all_combinations = list()
for i in range(1, len(iris.species.unique()) + 1):
    for combi in combinations(iris.species.unique(), i):
        all_combinations.append(combi)

all_plots = dict()
for i in all_combinations:
    all_plots[i] = sns.PairGrid(iris[iris.species.isin(i)], diag_sharey=False)
    all_plots[i].map_lower(sns.kdeplot)
    all_plots[i].map_upper(sns.scatterplot)
    all_plots[i].map_diag(sns.kdeplot, lw=3)
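    plt.close()  # <-- added: close the figure so it is not rendered when this cell finishes

def iris_pairgrid(species):
    return all_plots[species].fig  # <-- added .fig: return the stored Figure so it gets displayed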

o = iris.species.unique()
v = tuple(iris.species.unique()[-2:])
r = len(iris.species.unique())

interact(iris_pairgrid,
         species = widgets.SelectMultiple(options=o,
                                          value=v,
                                          rows=r,
                                          description='species',
                                          disabled=False))
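
The same pattern can be tried without seaborn or ipywidgets. The sketch below is an illustration under assumptions, not the code from the answer: `cache` and `get_figure` are hypothetical names, the plotted data is made up, and the `Agg` backend line is only there so the snippet also runs outside a notebook (drop it in Jupyter). It shows that a figure closed with `plt.close` is no longer tracked by pyplot, yet the `Figure` object kept in the dictionary can still be rendered later.

```python
import io

import matplotlib
matplotlib.use("Agg")  # non-interactive backend; only needed outside a notebook
import matplotlib.pyplot as plt

cache = {}
for power in (1, 2, 3):
    fig, ax = plt.subplots()
    xs = range(10)
    ax.plot(xs, [x ** power for x in xs])
    ax.set_title(f"power {power}")
    cache[power] = fig
    plt.close(fig)  # pyplot forgets the figure; the object in cache stays usable

def get_figure(power):
    """Return the pre-built Figure instead of drawing a new one."""
    return cache[power]

print(plt.get_fignums())  # [] because every figure was closed after creation

# The cached figure can still be rendered on demand, here into an in-memory PNG.
buffer = io.BytesIO()
get_figure(2).savefig(buffer, format="png")
```

As far as I understand, this is also why returning the figure from the `interact` callback works in a notebook: the returned `Figure` object is displayed as the function's result, so no call to `plt.show()` is needed.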
René