I am trying to add a colorbar and cannot figure out how to add it anywhere. Ideally it would be a small ax object added to the right of the right-most ax but I can't figure out how to add ax objects to ax objects that I've already created with make_axes_locatable
.
There are 2 modes, one that is multi-axes and one that is not. The one that isn't is straightforward as there are a lot of tutorials on StackOverflow but the multi-axes figure is the complicated bit.
How can I add a colorbar to the multi-axes figure? Preferably on the right but internal on the main scatter plot could work too (top left).
Here's my code:
# Plot compositional data
#!@check_packages(["matplotlib", "seaborn"])
def plot_compositions(
X:pd.DataFrame,
c:pd.Series="black", #"evenness"
s:pd.Series=28,
classes:pd.Series=None,
class_colors:dict=None,
edgecolor="white",
cbar=True,
cmap=plt.cm.gist_heat_r,
figsize=(13,8),
title=None,
style="seaborn-white",
ax=None,
show_xgrid=False,
show_ygrid=True,
show_density_1d=True,
show_density_2d=True,
show_legend=True,
xlabel=None,
ylabel=None,
legend_kws=dict(),
legend_title=None,
title_kws=dict(),
legend_title_kws=dict(),
axis_label_kws=dict(),
annot_kws=dict(),
line_kws=dict(),
hist_1d_kws=dict(),
# rug_kws=dict(),
kde_2d_kws=dict(),
cbar_kws=dict(),
panel_pad=0.1,
panel_size=0.618,
cax_panel_pad="5%",
cax_panel_size=0.618,
background_color="white",
sample_labels:dict=None,
logscale=True,
xmin=0,
ymin=0,
vmin=None,
vmax=None,
**scatter_kws,
):
"""
Plot compositions of total counts (x-axis) vs. number of detected components (y-axis)
"""
from collections.abc import Mapping
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import make_axes_locatable
from matplotlib.colors import to_hex
from matplotlib.scale import LogScale
import seaborn as sns
# Defaults
_title_kws = {"fontsize":15, "fontweight":"bold"}
_title_kws.update(title_kws)
_legend_kws = {'fontsize': 12, 'frameon': True, 'facecolor': 'white', 'edgecolor': 'black'}#, 'loc': 'center left', 'bbox_to_anchor': (1, 0.5)}
_legend_kws.update(legend_kws)
_legend_title_kws = {"size":15, "weight":"bold"}
_legend_title_kws.update(legend_title_kws)
_axis_label_kws = {"fontsize":15}
_axis_label_kws.update(axis_label_kws)
_hist_1d_kws = {"alpha":0.0618}
_hist_1d_kws.update(hist_1d_kws)
# _kde_1d_kws = {"alpha":0.618} # Rug takes too long when theres a lot points
# _kde_1d_kws.update(kde_1d_kws)
# _rug_kws = {"height":0.5}
# _rug_kws.update(rug_kws)
_kde_2d_kws = {"shade":True, "alpha":0.618}
_kde_2d_kws.update(kde_2d_kws)
_line_kws = {"linewidth":1.618, "linestyle":":", "alpha":1.0}
_line_kws.update(line_kws)
_annot_kws = {}
_annot_kws.update(annot_kws)
_scatter_kws={"edgecolor":edgecolor, "linewidths":0.618}
_scatter_kws.update(scatter_kws)
# Data
X = X.fillna(0)
# check_compositional(X, acceptable_dimensions={2})
assert np.all(X == X.astype(int)), "X must be integer data and should not be closure transformed"
# Total number of counts
sample_to_totalcounts = X.sum(axis=1)
remove_samples = sample_to_totalcounts == 0
if np.any(remove_samples):
warnings.warn("Removing the following observations because depth = 0: {}".format(remove_samples.index[remove_samples]))
sample_to_totalcounts = sample_to_totalcounts[~remove_samples]
samples = sample_to_totalcounts.index
# Number of detected components
sample_to_ncomponents = (X > 0).sum(axis=1) #! sample_to_ncomponents = number_of_components(X.loc[samples], checks=False)
number_of_samples = sample_to_ncomponents.size
# Colors
if classes is not None:
assert class_colors is not None, "`class_colors` is required for using `classes`"
classes = pd.Series(classes)
assert np.all(classes.index == samples), "`classes` must be a pd.Series with the same index ordering as `X.index`"
assert np.all(classes.map(lambda x: x in class_colors)), "Classes in `class` must have a color in `class_colors`"
if c is not None:
warnings.warn("c will be ignored and superceded by class_colors and classes")
c = classes.map(lambda x: class_colors[x])
if legend_title is None:
legend_title = c.name
else:
classes = 0
if c is None:
c = "black"
if not isinstance(c, pd.Series):
c = pd.Series([c]*number_of_samples, index=samples)
assert np.all(c.notnull())
c_is_continuous = False
try:
c = c.map(to_hex)
except ValueError:
c_is_continuous = True
if vmin is None:
vmin = c.min()
if vmax is None:
vmax = c.max()
_scatter_kws["cmap"] = cmap
_scatter_kws["vmin"] = vmin
_scatter_kws["vmax"] = vmax
if not isinstance(classes, pd.Series):
classes = pd.Series([classes]*number_of_samples, index=samples)
# Marker size
if not isinstance(s, pd.Series):
s = pd.Series([s]*number_of_samples, index=samples)
assert np.all(s.notnull())
# Data
df_data = pd.DataFrame([
sample_to_totalcounts,
sample_to_ncomponents,
c,
s,
classes,
], index=["x","y","c","s","class"],
).T
for field in ["x","s"]:
df_data[field] = df_data[field].astype(float)
for field in ["y"]:
df_data[field] = df_data[field].astype(int)
# Plotting
number_of_classes = df_data["class"].nunique()
axes = list()
with plt.style.context(style):
# Set up axes
if ax is None:
fig, ax = plt.subplots(figsize=figsize)
else:
fig = plt.gcf()
axes.append(ax)
# Scatter plot
for id_class in df_data["class"].unique():
index = df_data["class"][lambda x: x == id_class].index
df = df_data.loc[index]
ax.scatter(data=df, x="x", y="y", c="c", s="s", label=id_class if number_of_classes > 1 else None, **_scatter_kws)
# Limits
xlim = ax.get_xlim()
ylim = ax.get_ylim()
if logscale:
ax.set_xscale(LogScale(axis=0,base=10))
# Density (1D)
divider = make_axes_locatable(ax)
if show_density_1d:
ax_right = divider.append_axes("right", pad=panel_pad, size=panel_size)
ax_top = divider.append_axes("top", pad=panel_pad, size=panel_size)
for id_class in df_data["class"].unique():
index = df_data["class"][lambda x: x == id_class].index
df = df_data.loc[index]
color = df.loc[index,"c"].values[0]
if c_is_continuous:
color = to_hex("black")
# Histogram
sns.histplot(data=df, x="x", color=color, ax=ax_top, kde=True, **_hist_1d_kws) # "rug":max(X.shape) < 5000, "hist_kws":{"alpha":0.382}, "kde_kws":{
sns.histplot(data=df, y="y", color=color, ax=ax_right, kde=True, **_hist_1d_kws) # "rug":max(X.shape) < 5000, "hist_kws":{"alpha":0.382}, "kde_kws":{
# KDE
# sns.rugplot(data=df, x="x", color=color, ax=ax_top,) # "rug":max(X.shape) < 5000, "hist_kws":{"alpha":0.382}, "kde_kws":{
# sns.rugplot(data=df, y="y", color=color, ax=ax_right) # "rug":max(X.shape) < 5000, "hist_kws":{"alpha":0.382}, "kde_kws":{(data=df, y="y", color=color, ax=ax_right, **_kde_1d_kws, zorder=0) # "rug":max(X.shape) < 5000, "hist_kws":{"alpha":0.382}, "kde_kws":{
if logscale:
ax_top.set_xscale(LogScale(axis=0,base=10))
with warnings.catch_warnings():
warnings.simplefilter("ignore")
ax_right.set(ylim=ylim, xticklabels=[],yticklabels=[],yticks=ax.get_yticks())
ax_top.set(xlim=xlim, xticklabels=[],yticklabels=[],xticks=ax.get_xticks())
ax_right.set_xlabel(None)
ax_right.set_ylabel(None)
ax_top.set_xlabel(None)
ax_top.set_ylabel(None)
axes.append(ax_right)
axes.append(ax_top)
# Density (2D)
if show_density_2d:
for id_class in df_data["class"].unique():
index = df_data["class"][lambda x: x == id_class].index
df = df_data.loc[index]
color = df.loc[index,"c"].values[0]
if c_is_continuous:
color = to_hex("black")
try:
sns.kdeplot(data=df, x="x", y="y", color=color, zorder=0, ax=ax, **_kde_2d_kws)
except Exception as e:
warnings.warn("({}) Could not compute the 2-dimensional KDE plot for the following class: {}".format(id_class))
# Annotations
if sample_labels is not None:
assert hasattr(sample_labels, "__iter__"), "sample_labels must be an iterable or a mapping between sample and label"
if isinstance(sample_labels, (Mapping, pd.Series)):
sample_labels = dict(sample_labels)
else:
sample_labels = dict(zip(sample_labels, sample_labels))
for k,v in sample_labels.items():
if k not in df_data.index:
assert k in X.index, ("{} is not in X.index".format(k))
warnings.warn("{} is not in X.index after removing empty compositions".format(k))
else:
x, y = df_data.loc[k,["x","y"]]
ax.text(x=x, y=y, s=v, **_annot_kws)
# Labels
if xlabel is None:
xlabel = "Total Counts"
if logscale:
xlabel = "log$_{10}$(%s)"%(xlabel)
if ylabel is None:
ylabel = "Number of Components"
if xmin is not None:
with warnings.catch_warnings():
warnings.simplefilter("ignore")
ax.set_xlim(xmin, max(ax.get_xlim()))
if ymin is not None:
ax.set_ylim(xmin, max(ax.get_ylim()))
ax.set_xlabel(xlabel, **_axis_label_kws)
ax.set_ylabel(ylabel, **_axis_label_kws)
ax.xaxis.grid(show_xgrid)
ax.yaxis.grid(show_ygrid)
# Cbar
# if c_is_continuous:
# if cbar:
# divider_cax = make_axes_locatable(axes[-1])
# cax = fig.add_axes([0.27, 0.8, 0.5, 0.05])
# im = ax.imshow(df_data["c"].values.reshape(-1,1), cmap=cmap)
# fig.colorbar(im, orientation='horizontal')
# plt.show()
# cax = divider_cax.append_axes("right", size="5%", pad="2%") # Adjust the size and pad as needed
# im = ax.imshow(df_data["c"].values.reshape(-1, 1), cmap=cmap, vmin=vmin, vmax=vmax)
# plt.colorbar(im, cax=cax)
# Legend
if show_legend:
if number_of_classes > 1:
ax.legend(**_legend_kws)
if bool(legend_title):
ax.legend_.set_title(legend_title, prop=_legend_title_kws)
# Title
if title is not None:
axes[-1].set_title(title, **_title_kws)
# Background color
for ax_query in axes:
ax_query.set_facecolor(background_color)
return fig, axes
# Load abundances (Gomez and Espinoza et al. 2017)
X = pd.read_csv("https://github.com/jolespin/projects/raw/main/supragingival_plaque_microbiome/16S_amplicons/Data/X.tsv.gz",
sep="\t",
index_col=0,
compression="gzip",
)
Y = pd.read_csv("https://github.com/jolespin/projects/raw/main/supragingival_plaque_microbiome/16S_amplicons/Data/Y.tsv.gz",
sep="\t",
index_col=0,
compression="gzip",
)
classes = Y["Caries_enamel"].loc[X.index]
c = pd.Series(classes.map(lambda x: {True:"blue", False:"green"}[x == "NO"]))
sample_labels = pd.Index(X.sum(axis=1).sort_values().index[:4].tolist())
sample_labels = pd.Series(sample_labels.map(lambda x: x.split("_")[0]), sample_labels)
# fig, axes = plot_compositions(X, s=28,logscale=False, sample_labels=sample_labels, class_colors={"NO":"black", "YES":"red"}, classes=classes, title="Caries")
fig, axes = plot_compositions(X,c=Y.loc[X.index,"age (months)"], s=28,logscale=False, sample_labels=sample_labels, title="Caries")