
I am trying to customize the y-labels of a seaborn clustermap built from a MultiIndex dataframe. My dataframe looks like this:

                    Col1    Col2    ...
Idx1.A    Idx2.a    1.05    1.51    ...
          Idx2.b    0.94    0.88    ...
Idx1.B    Idx2.c    1.09    1.20    ...
          Idx2.d    0.90    0.79    ...
   ...       ...     ...     ...    ...

The goal is to get y-labels like the ones in the example below, where in my case Idx1 would be the seasons, Idx2 the months and the columns the years. The difference is that mine is a clustermap, not a heatmap, so I think the seaborn functions for customizing the ticks are different, even though clustermap just adds hierarchical clustering over the rows or columns of a heatmap.

[Image: example heatmap with two-level grouped y-labels]

My code:

import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

def do_clustermap():
    with open('/home/Documents/myfile.csv', 'r') as f:
        # The first two columns form the (Idx1, Idx2) MultiIndex
        df = pd.read_csv(f, index_col=[0, 1], sep='\t')

    g = sns.clustermap(df, center=1, row_cluster=False, cmap="YlGnBu", yticklabels=True, xticklabels=True, linewidths=0.004)
    g.ax_heatmap.yaxis.set_ticks_position("left")

    plt.setp(g.ax_heatmap.xaxis.get_majorticklabels(), fontsize=4)
    plt.setp(g.ax_heatmap.yaxis.get_majorticklabels(), fontsize=4)
    plt.show()

do_clustermap()

I tried to follow the answers from this thread, but I get this error:

UserWarning: Clustering large matrix with scipy. Installing `fastcluster` may give better performance.
  warnings.warn(msg)
Traceback (most recent call last):
  File "/home/ju/PycharmProjects/stage/figures.py", line 24, in <module>
    do_heatmap()
  File "/home/ju/PycharmProjects/stage/figures.py", line 13, in do_heatmap
    ax = sns.clustermap(df, center=1, row_cluster=False, cmap="YlGnBu", yticklabels=True, xticklabels=True, linewidths=0.004)
  File "/home/ju/PycharmProjects/stage/venv/lib/python3.6/site-packages/seaborn/_decorators.py", line 46, in inner_f
    return f(**kwargs)
  File "/home/ju/PycharmProjects/stage/venv/lib/python3.6/site-packages/seaborn/matrix.py", line 1412, in clustermap
    tree_kws=tree_kws, **kwargs)
  File "/home/ju/PycharmProjects/stage/venv/lib/python3.6/site-packages/seaborn/matrix.py", line 1223, in plot
    tree_kws=tree_kws)
  File "/home/ju/PycharmProjects/stage/venv/lib/python3.6/site-packages/seaborn/matrix.py", line 1079, in plot_dendrograms
    tree_kws=tree_kws
  File "/home/ju/PycharmProjects/stage/venv/lib/python3.6/site-packages/seaborn/_decorators.py", line 46, in inner_f
    return f(**kwargs)
  File "/home/ju/PycharmProjects/stage/venv/lib/python3.6/site-packages/seaborn/matrix.py", line 776, in dendrogram
    label=label, rotate=rotate)
  File "/home/ju/PycharmProjects/stage/venv/lib/python3.6/site-packages/seaborn/matrix.py", line 584, in __init__
    self.linkage = self.calculated_linkage
  File "/home/ju/PycharmProjects/stage/venv/lib/python3.6/site-packages/seaborn/matrix.py", line 651, in calculated_linkage
    return self._calculate_linkage_scipy()
  File "/home/ju/PycharmProjects/stage/venv/lib/python3.6/site-packages/seaborn/matrix.py", line 620, in _calculate_linkage_scipy
    metric=self.metric)
  File "/home/ju/PycharmProjects/stage/venv/lib/python3.6/site-packages/scipy/cluster/hierarchy.py", line 1038, in linkage
    y = _convert_to_double(np.asarray(y, order='c'))
  File "/home/ju/PycharmProjects/stage/venv/lib/python3.6/site-packages/scipy/cluster/hierarchy.py", line 1560, in _convert_to_double
    X = X.astype(np.double)
ValueError: could not convert string to float: 'Col1'

Does anyone have an idea? Here is a small example of the file I'm working with:

        Robert  Jean    Lulu
Bar a   1.05    1.52    1.16
Bar b   0.94    0.49    0.83
Foo c   1.09    1.22    1.44
Foo d   0.92    0.79    0.55
Hop e   0.62    0.82    0.68
Hop f   0.52    0.18    0.31
Hop g   0.93    1.15    1.11
user15775608
  • Please explain what you mean by *"it doesn't work"*. Does it crash? Does it give an unexpected plot? ... ? Which error messages did you get? What did you try to solve them? Note that `plt.tight_layout()` might help fit the labels. – JohanC May 23 '21 at 17:13
  • Sorry, I edited my message so you can see the error message. Actually, I don't understand why it wants to convert my column names to float. I don't get this error (and it works fine) when I remove the super index (first column) from the dataframe. Okay, I'll look into how plt.tight_layout() works, thank you – user15775608 May 23 '21 at 20:28
  • Did you try this with the latest matplotlib/pandas/seaborn versions? Also note that `sns.clustermap()` doesn't return an `ax`, but a `ClusterGrid`. Often, the return value of such a function is named `g`, as calling it `ax` is extremely confusing. – JohanC May 23 '21 at 20:57
  • Yes, I have them all updated (I don't know if it's relevant, but I use the latest version of PyCharm CE to run this script). Alright, thank you for this information, I'll rename the return value – user15775608 May 23 '21 at 22:10
  • Thank you very much for your efforts. I added the type of file I'm working with at the end of my first post – user15775608 May 24 '21 at 21:50
  • Could you create a full minimal reproducible example? – JohanC May 24 '21 at 22:06
  • Could you create a minimal reproducible example? My attempts couldn't reproduce the error. – JohanC May 24 '21 at 22:16
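Since the error names a column header ('Col1') as a value that scipy tried to convert, one hedged way to narrow it down is to check that the file parses into a purely numeric frame with a two-level index before calling clustermap. The sketch below assumes the file is tab-separated and that its header row starts with one empty field per index column; `SAMPLE` and `load_matrix` are hypothetical names, and this is a debugging aid rather than a confirmed diagnosis of the original error.

```python
import io

import pandas as pd

# Hypothetical tab-separated copy of the sample file from the question.
# Assumption: the header row begins with two empty fields (one per index column).
SAMPLE = (
    "\t\tRobert\tJean\tLulu\n"
    "Bar\ta\t1.05\t1.52\t1.16\n"
    "Bar\tb\t0.94\t0.49\t0.83\n"
    "Foo\tc\t1.09\t1.22\t1.44\n"
    "Foo\td\t0.92\t0.79\t0.55\n"
    "Hop\te\t0.62\t0.82\t0.68\n"
    "Hop\tf\t0.52\t0.18\t0.31\n"
    "Hop\tg\t0.93\t1.15\t1.11\n"
)


def load_matrix(buffer):
    """Read a two-level-indexed, tab-separated table and check it is all numeric."""
    df = pd.read_csv(buffer, sep="\t", index_col=[0, 1])
    # clustermap (via scipy) needs every value column to be numeric; report any that are not.
    bad = df.select_dtypes(exclude="number").columns.tolist()
    if bad:
        raise ValueError(f"non-numeric columns: {bad}")
    return df


df = load_matrix(io.StringIO(SAMPLE))
print(df.index.nlevels, list(df.columns))  # 2 ['Robert', 'Jean', 'Lulu']
```

If `load_matrix` raises on the real file, the listed columns point at where text ended up among the values (for example, a header row whose field count does not match the data rows).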

1 Answer


Here is some code creating a minimal example similar to the given data.

import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import numpy as np

df = pd.DataFrame({'Idx1': ['Bar', 'Bar', 'Foo', 'Foo', 'Hop', 'Hop', 'Hop'],
                   'Idx2': ['a', 'b', 'c', 'd', 'e', 'f', 'g'],
                   'Col1': np.random.rand(7),
                   'Col2': np.random.rand(7)})
df = df.set_index(['Idx1', 'Idx2'])

g = sns.clustermap(df, center=1, row_cluster=False, cmap="YlGnBu", yticklabels=True, xticklabels=True, linewidths=0.004)
g.ax_heatmap.yaxis.set_ticks_position("left")

plt.setp(g.ax_heatmap.xaxis.get_majorticklabels(), fontsize=10)
plt.setp(g.ax_heatmap.yaxis.get_majorticklabels(), fontsize=10)
plt.show()

The dataframe looks like:

               Col1      Col2
Idx1 Idx2                    
Bar  a     0.366961  0.253956
     b     0.320457  0.807694
Foo  c     0.293184  0.337154
     d     0.868155  0.661968
Hop  e     0.908930  0.406291
     f     0.670220  0.668903
     g     0.683821  0.476246

With seaborn 0.11.1, matplotlib 3.4.2, pandas 1.2.4 and scipy 1.6.3, the following plot is generated:

[Figure: clustermap example with two indices]
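In the default plot each row carries both index levels in a single tick label. To my understanding, seaborn 0.11's heatmap builds these labels by joining each MultiIndex tuple with "-" (giving e.g. "Bar-a") and the y-axis label by joining the level names ("Idx1-Idx2"), which is why grouped labels have to be drawn by hand. The hypothetical helper `joined_ticklabels` below only mimics that behavior in plain pandas; it does not call seaborn.

```python
import pandas as pd

df = pd.DataFrame(
    {"Col1": [0.1, 0.2, 0.3], "Col2": [0.4, 0.5, 0.6]},
    index=pd.MultiIndex.from_tuples(
        [("Bar", "a"), ("Bar", "b"), ("Foo", "c")], names=["Idx1", "Idx2"]
    ),
)


def joined_ticklabels(index):
    """Mimic seaborn's default tick labels: MultiIndex tuples joined with '-'."""
    if isinstance(index, pd.MultiIndex):
        return ["-".join(map(str, tup)) for tup in index]
    return list(index)


print(joined_ticklabels(df.index))  # ['Bar-a', 'Bar-b', 'Foo-c']
print("-".join(df.index.names))     # Idx1-Idx2
```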

An integration with the linked code could look like the following. Some distances will need to be adjusted depending on the

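import matplotlib.pyplot as plt
from itertools import groupby
import seaborn as sns
import pandas as pd
import numpy as np

def add_line(ax, ypos, xpos):
    # Short horizontal separator (0.2 axes-widths long) starting at xpos, in axes
    # coordinates; clipping is off so it can be drawn left of the heatmap
    line = plt.Line2D([xpos, xpos + .2], [ypos, ypos], color='black', transform=ax.transAxes)
    line.set_clip_on(False)
    ax.add_line(line)

def label_len(my_index, level):
    # Run-length encode one index level: [(label, number of consecutive rows), ...]
    labels = my_index.get_level_values(level)
    return [(k, sum(1 for _ in grp)) for k, grp in groupby(labels)]

def label_group_bar_table(ax, df):
    # Level 0 is drawn closest to the heatmap, each further level 0.2 axes-widths to the left
    xpos = -.2
    scale = 1. / df.index.size
    for level in range(df.index.nlevels):
        pos = df.index.size
        for label, rpos in label_len(df.index, level):
            add_line(ax, pos * scale, xpos)      # top boundary of this group
            pos -= rpos
            lypos = (pos + .5 * rpos) * scale    # vertical middle of the group
            ax.text(xpos + .1, lypos, label, ha='center', transform=ax.transAxes)
        add_line(ax, pos * scale, xpos)          # bottom boundary of the last group
        xpos -= .2

df = pd.DataFrame({'Idx1': ['Bar', 'Bar', 'Foo', 'Foo', 'Hop', 'Hop', 'Hop'],
                   'Idx2': ['a', 'b', 'c', 'd', 'e', 'f', 'g'],
                   'Col1': np.random.rand(7),
                   'Col2': np.random.rand(7)})
# Fine level first, so it is drawn next to the heatmap and the coarse level outside it
df = df.set_index(['Idx2', 'Idx1'])

# row_cluster=False keeps the rows in index order, which the label positions rely on
g = sns.clustermap(df, center=1, row_cluster=False, cmap="YlGnBu", yticklabels=True, xticklabels=True, linewidths=0.004, figsize=(10,5))
g.ax_heatmap.yaxis.set_ticks_position("left")

plt.setp(g.ax_heatmap.xaxis.get_majorticklabels(), fontsize=10)
g.ax_heatmap.set_yticks([])                # hide the default joined tick labels
label_group_bar_table(g.ax_heatmap, df)
g.fig.subplots_adjust(left=0.15)           # make room for the custom labels
plt.show()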

[Figure: sns.clustermap with double indices, custom labeling]
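The label placement in `label_group_bar_table` boils down to a run-length count per index level plus some axes-fraction arithmetic: with n rows drawn top to bottom, each row takes 1/n of the height, group boundaries fall at the cumulative run lengths, and each label sits at the middle of its run. The sketch below isolates that arithmetic without plotting; `group_positions` is a hypothetical helper that only reproduces the numbers the answer's loop computes.

```python
from itertools import groupby

import pandas as pd


def label_len(my_index, level):
    """Run-length encode one index level: [(label, number of consecutive rows), ...]."""
    labels = my_index.get_level_values(level)
    return [(k, sum(1 for _ in grp)) for k, grp in groupby(labels)]


def group_positions(my_index, level):
    """Group boundaries and label centers for one level, in axes fraction (1 = top)."""
    n = my_index.size
    scale = 1.0 / n
    pos = n
    boundaries, centers = [pos * scale], []
    for label, run in label_len(my_index, level):
        pos -= run
        boundaries.append(pos * scale)
        centers.append((label, (pos + 0.5 * run) * scale))
    return boundaries, centers


idx = pd.MultiIndex.from_arrays(
    [list("abcdefg"), ["Bar", "Bar", "Foo", "Foo", "Hop", "Hop", "Hop"]],
    names=["Idx2", "Idx1"],
)
print(label_len(idx, 1))  # [('Bar', 2), ('Foo', 2), ('Hop', 3)]
bounds, centers = group_positions(idx, 1)
```

Because `itertools.groupby` only merges consecutive equal labels, a group whose rows are not contiguous would be split into several runs. That is one reason the example keeps `row_cluster=False`: clustering the rows could reorder them and break up the groups.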

JohanC