
I am following this answer, which works for sns.heatmap:

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

# 2x2 grid: heatmap (top left) with its colorbar (top right),
# bar plot (bottom left) and an unused placeholder cell (bottom right).
# sharex='col' keeps the bar plot's x-axis aligned with the heatmap's.
fig, ((ax1, cbar_ax), (ax2, dummy_ax)) = plt.subplots(nrows=2, ncols=2, figsize=(26, 16), sharex='col',
                                                      gridspec_kw={'height_ratios': [5, 1], 'width_ratios': [20, 1]})
missings_df = np.random.rand(3, 3)
sns.heatmap(missings_df.T, cmap="Blues", cbar_ax=cbar_ax, xticklabels=False, linewidths=2, ax=ax1)

ax2.set_xlabel('Time (hours)')

patient_counts = np.random.randint(10, 50, 3)
x_ticks = ['Time1', 'Time2', 'Time3']
# Heatmap column i spans x in [i, i+1], so its center is at i + 0.5.
x_tick_pos = [i + 0.5 for i in range(len(x_ticks))]
ax2.bar(x_tick_pos, patient_counts, align='center')
ax2.set_xticks(x_tick_pos)
ax2.set_xticklabels(x_ticks)
dummy_ax.axis('off')  # hide the unused bottom-right cell

plt.tight_layout()
plt.show()
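For the plain heatmap case, the layout above can be adapted to the stated goal (column averages drawn above the map) by swapping the two grid rows and plotting `data.mean(axis=0)` instead of counts. A minimal sketch, assuming a random `data` array as a hypothetical stand-in for `missings_df.T`; bars are centered at `j + 0.5` because `sns.heatmap` draws column `j` over the x-interval [j, j+1].

```python
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

rng = np.random.default_rng(0)
data = rng.random((3, 3))  # hypothetical stand-in for missings_df.T

# Bar plot on top (height ratio 1), heatmap below (ratio 5);
# the narrow right-hand column holds an empty cell and the colorbar.
fig, ((bar_ax, dummy_ax), (heat_ax, cbar_ax)) = plt.subplots(
    nrows=2, ncols=2, figsize=(10, 6), sharex='col',
    gridspec_kw={'height_ratios': [1, 5], 'width_ratios': [20, 1]})

sns.heatmap(data, cmap="Blues", cbar_ax=cbar_ax, linewidths=2, ax=heat_ax)

# Average of each heatmap column; column j is centered at x = j + 0.5.
col_means = data.mean(axis=0)
x_pos = np.arange(data.shape[1]) + 0.5
bar_ax.bar(x_pos, col_means, width=0.8)
bar_ax.set_ylabel('Column mean')
dummy_ax.axis('off')  # hide the unused top-right cell

plt.tight_layout()
```

Call `plt.show()` afterwards to display the figure interactively.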


Then, I tried to use sns.clustermap instead of sns.heatmap by just changing

sns.heatmap(missings_df.T, cmap="Blues", cbar_ax=cbar_ax, xticklabels=False, linewidths=2, ax=ax1)

to

sns.clustermap(missings_df.T, cmap="Blues", cbar_ax=cbar_ax, xticklabels=False, linewidths=2, ax=ax1)

which raised

TypeError: seaborn.matrix.heatmap() got multiple values for keyword argument 'cbar_ax'

Any idea how to make this work with sns.clustermap as well?

My goal is to show the average of each column of the cluster map as a bar plot placed directly above it.
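The TypeError above is consistent with how sns.clustermap works: it creates its own figure, including a heatmap axes and a colorbar axes, and then calls sns.heatmap internally with its own `cbar_ax`, forwarding the remaining keyword arguments. A second `cbar_ax` from the caller therefore collides. A minimal sketch of a call that runs drops both `ax` and `cbar_ax` and positions the colorbar with `cbar_pos`, given as (left, bottom, width, height) in figure fractions; the random data and the `cbar_pos` values are illustrative assumptions.

```python
import numpy as np
import seaborn as sns

rng = np.random.default_rng(0)
data = rng.random((3, 3))  # hypothetical stand-in for missings_df.T

# No ax= and no cbar_ax=: clustermap builds its own figure and axes.
# cbar_pos is (left, bottom, width, height) in figure coordinates.
g = sns.clustermap(data, cmap="Blues", xticklabels=False, linewidths=2,
                   figsize=(8, 8), cbar_pos=(0.02, 0.8, 0.03, 0.15))

# The returned ClusterGrid exposes the axes it created.
fig = g.ax_heatmap.figure
print(type(g).__name__)
print(g.ax_cbar.get_position().bounds)
```

Any further customization then goes through `g.ax_heatmap`, `g.ax_cbar`, `g.ax_row_dendrogram` and `g.ax_col_dendrogram` rather than through subplots created beforehand.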

Johnny Tam
  • `sns.heatmap()` is an "axes level" function, and adapts itself to a given subplot. `sns.clustermap` is a "figure level" function, which creates its own figure with fixed subplots. It's quite hard to add more subplots to that figure or to reposition the colorbar. – JohanC Aug 30 '23 at 21:35
  • For some inspiration, you might take a look at [Add bar-plot along a particular axis of clustermap](https://stackoverflow.com/questions/54788526/add-bar-plot-along-a-particular-axis-of-clustermap-with-index-specific-data) – JohanC Aug 30 '23 at 21:56
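Along the lines of the comments, one option (an assumption, not the approach of the linked question) is to reuse the slot that clustermap reserves above the heatmap for the column dendrogram. The sketch below hides the column dendrogram, adds a new axes at the same position with `fig.add_axes`, and orders the column means with `g.dendrogram_col.reordered_ind` so each bar sits over its heatmap column. The data shape and the `dendrogram_ratio` values are hypothetical.

```python
import numpy as np
import seaborn as sns

rng = np.random.default_rng(0)
data = rng.random((6, 8))  # hypothetical data: 6 rows x 8 columns

# dendrogram_ratio=(row, col): give the top slot extra height for the bars.
g = sns.clustermap(data, cmap="Blues", linewidths=2, figsize=(10, 8),
                   dendrogram_ratio=(0.1, 0.25))

# Column order after clustering, so bars line up with the heatmap columns.
order = g.dendrogram_col.reordered_ind
col_means = data.mean(axis=0)[order]

# Hide the column dendrogram and put a new axes exactly where it was;
# it shares the heatmap's gridspec column, so the x-extent matches.
g.ax_col_dendrogram.set_visible(False)
fig = g.ax_heatmap.figure
bar_ax = fig.add_axes(g.ax_col_dendrogram.get_position().bounds)

# Heatmap column j spans x in [j, j+1]; use the same coordinates.
n_cols = data.shape[1]
bar_ax.bar(np.arange(n_cols) + 0.5, col_means, width=0.8)
bar_ax.set_xlim(0, n_cols)
bar_ax.set_xticks([])
bar_ax.set_ylabel("Column mean")
```

This trades the column dendrogram for the bars while keeping the clustered column order and the row dendrogram.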
