I have a dataframe from which I take several groups of data and display each one as a distplot on the same figure (overlaid). I also display a table summarizing some data about each group. I would like each row in the table (= each group) to have the same color as the matching distplot curve. I've tried to define a common colormap for both the table and the distplot, but distplot throws an error:

in distplot
    if kde_color != color:
ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()

Process finished with exit code 1

Here is the code:


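import matplotlib.colors as mcolors
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns

# item_df is a pandas DataFrame with 'item_type' and 'interesting_feature' columns
fig, (ax_plot, ax_table) = plt.subplots(nrows=2, figsize=(11.69, 8.27),
                                        gridspec_kw=dict(height_ratios=[3, 1]))
ax_table.axis("off")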

item_types = item_df['item_type'].unique()
columns = ('item type', 'Average DR', 'Percent DR passed 50%', 'Percent DR passed 60%', 'Percent DR passed 70%',
           'Percent DR passed 80%')
cell_text = []
table_colors = plt.cm.BuPu(np.linspace(0, 0.5, len(item_types)))
i=0
    
for item_type in item_types:
    item_dr = item_df[item_df['item_type'] == item_type]['interesting_feature'].values
    color = table_colors[i, 0:3]
    sns.distplot(item_dr, hist=False, label=item_type, ax=ax_plot, color=mcolors.rgb_to_hsv(color))
    i += 1
    avg_dr = np.mean(item_dr)
    pass50 = len(item_dr[item_dr > 0.5]) / len(item_dr)
    pass60 = len(item_dr[item_dr > 0.6]) / len(item_dr)
    pass70 = len(item_dr[item_dr > 0.7]) / len(item_dr)
    pass80 = len(item_dr[item_dr > 0.8]) / len(item_dr)

    cell_text.append([str(item_type), str(avg_dr), str(pass50), str(pass60), str(pass70), str(pass80)])
item_table = ax_table.table(cellText=cell_text,
                            colLabels=columns,
                            loc='center',
                            fontsize=20,
                            rowColours=table_colors)
    Welcome to Stack Overflow! Please take a moment to read [How do I ask a good question?](https://stackoverflow.com/help/how-to-ask). You need to provide a [Minimal, Complete, and Verifiable example](https://stackoverflow.com/help/mcve) **that includes a toy dataset** (refer to [How to make good reproducible pandas examples](https://stackoverflow.com/questions/20109391/how-to-make-good-reproducible-pandas-examples)) – Diziet Asahi Jul 10 '20 at 11:34

1 Answer

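First off, converting to HSV as in `mcolors.rgb_to_hsv(color)` doesn't look useful: the `color` parameter is read as RGB, so HSV values passed there would be drawn as a different color than the one used in the table.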
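As a rough illustration of why the HSV conversion hurts here, the sketch below uses the standard library's `colorsys` module as a stand-in for `mcolors.rgb_to_hsv` (an assumption made only to keep the example dependency-free) and pure red as an arbitrary sample color.

```python
import colorsys

# Pure red as an RGB triple (arbitrary sample color).
rgb = (1.0, 0.0, 0.0)

# Hue, saturation and value describing the same color.
hsv = colorsys.rgb_to_hsv(*rgb)

# A function that expects RGB reads these three numbers as
# red, green and blue channels, i.e. as some other color.
misread_as_rgb = hsv

print("original RGB:", rgb)
print("HSV values:  ", hsv)
print("same color?  ", misread_as_rgb == rgb)
```

For pure red the HSV triple is (0.0, 1.0, 1.0), which read as RGB is cyan, so the curve would no longer match its table row.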

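Now, the main problem seems to be that passing a color as a list or a NumPy array (e.g. [1, 0, 0]) confuses `sns.distplot(..., color=color)`. Many seaborn functions accept either a single color or a list of colors, and they don't distinguish between one color given as RGB values and an array of several colors. In this case, the check `if kde_color != color:` shown in the traceback is evaluated element-wise when `color` is a NumPy array, and the resulting array cannot be reduced to a single True or False. The workaround is to convert the array to a tuple: `sns.distplot(..., color=tuple(color))`.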
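The comparison behind the error can be reproduced without seaborn. In this sketch, `colors_differ` is a hypothetical helper that only mimics the `if kde_color != color:` line from the traceback; it is not seaborn's actual code, and the color values are arbitrary.

```python
import numpy as np


def colors_differ(kde_color, color):
    """Hypothetical stand-in for the `if kde_color != color:` check."""
    if kde_color != color:
        return True
    return False


color_array = np.array([0.2, 0.4, 0.6])  # e.g. one row of a colormap
color_tuple = tuple(color_array)          # the suggested workaround

# With arrays, `!=` yields an array of booleans, which `if` cannot evaluate.
try:
    colors_differ(color_array, color_array)
    array_error = None
except ValueError as exc:
    array_error = str(exc)

# With tuples, `!=` yields a single boolean.
tuple_result = colors_differ(color_tuple, color_tuple)

print("array comparison error:", array_error)
print("tuple comparison result:", tuple_result)
```

The array case raises the same `ValueError` as in the question, while the tuple case simply reports that the two colors are equal.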

Here is a minimal example:

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns

num_colors = 5
table_colors = plt.cm.BuPu(np.linspace(0, 0.5, num_colors))

fig, (ax_plot, ax_table) = plt.subplots(nrows=2)
for i in range(num_colors):
    color = table_colors[i, 0:3]
    # sns.distplot(np.random.normal(0, 1, 100), hist=False, color=color) # gives an error
    sns.distplot(np.random.normal(0, 1, 100), hist=False, color=tuple(color), ax=ax_plot)

columns = list('abcdef')
num_columns = len(columns)
ax_table.table(cellText=np.random.randint(1, 1000, size=(num_colors, num_columns)) / 100,
               colLabels=columns, loc='center', fontsize=20,
               cellColours=np.repeat(table_colors, num_columns, axis=0).reshape(num_colors, num_columns, -1))
ax_table.axis('off')
plt.tight_layout()
plt.show()

[Figure: example plot]
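The `cellColours` expression in the example is fairly dense, so here is a small sketch of what it produces. The `row_colors` array is a made-up stand-in for `table_colors` (one RGBA row per group), and the `np.broadcast_to` line is just an alternative spelling offered for comparison, not part of the original answer.

```python
import numpy as np

num_colors, num_columns = 5, 6

# Stand-in for table_colors: one RGBA row per group (values are arbitrary).
row_colors = np.linspace(0, 1, num_colors * 4).reshape(num_colors, 4)

# Same construction as in the example: repeat each row's color once per
# column, then fold the result into a (rows, columns, RGBA) grid.
cell_colors = np.repeat(row_colors, num_columns, axis=0).reshape(num_colors, num_columns, -1)

# Alternative spelling of the same grid using broadcasting.
broadcast_colors = np.broadcast_to(row_colors[:, None, :], (num_colors, num_columns, 4))

print(cell_colors.shape)
print(np.array_equal(cell_colors, broadcast_colors))
```

Every cell in row `i` ends up with `row_colors[i]`, which is why the whole table row takes on the color of its group.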

To change the color of the text, you can loop through the cells of the table. As these particular colors are not very visible on a white background, the cell background could be set to black.

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns

num_colors = 5
table_colors = plt.cm.BuPu(np.linspace(0, 0.5, num_colors))

fig, (ax_plot, ax_table) = plt.subplots(nrows=2)
for i in range(num_colors):
    color = table_colors[i, :]
    # sns.distplot(np.random.normal(0, 1, 100), hist=False, color=color) # gives an error
    sns.distplot(np.random.normal(0, 1, 100), hist=False, color=tuple(color), ax=ax_plot)

columns = list('abcdef')
num_columns = len(columns)
table = ax_table.table(cellText=np.random.randint(1, 1000, size=(num_colors, num_columns)) / 100,
                       colLabels=columns, loc='center', fontsize=20)
for i in range(num_colors):
    for j in range(num_columns):
        table[(i+1, j)].set_color('black')   # +1: skip the table header
        table[(i+1, j)].get_text().set_color(table_colors[i, :])
ax_table.axis('off')
plt.tight_layout()
plt.show()

[Figure: changing the text colors]

  • Many thanks! I've converted the list into tuple and indeed it worked! A continuation question - Is it also possible to color the text in each row according to the color-key in the table's left column? – Yael Zahar Jul 10 '20 at 19:23
  • You can use `cellColours` instead of `rowColours`. The colors need to be repeated for every column. – JohanC Jul 10 '20 at 20:27
  • Perfect! Many thanks for your help. One last question... Is it possible to define the color of the fonts themselves (the text), and not the cell's background, to be according to the selected colormap? – Yael Zahar Jul 10 '20 at 21:05