1

I am trying to add a colorbar and cannot figure out how to add it anywhere. Ideally it would be a small ax object added to the right of the right-most ax but I can't figure out how to add ax objects to ax objects that I've already created with make_axes_locatable.

There are 2 modes, one that is multi-axes and one that is not. The one that isn't is straightforward as there are a lot of tutorials on StackOverflow but the multi-axes figure is the complicated bit.

How can I add a colorbar to the multi-axes figure? Preferably on the right but internal on the main scatter plot could work too (top left).

Here's my code:

# Plot compositional data
#!@check_packages(["matplotlib", "seaborn"])
def plot_compositions(
    X:pd.DataFrame,
    c:pd.Series="black", #"evenness"
    s:pd.Series=28,
    classes:pd.Series=None,
    class_colors:dict=None,

    edgecolor="white",
    cbar=True,
    cmap=plt.cm.gist_heat_r,
    figsize=(13,8),
    title=None,
    style="seaborn-white",
    ax=None,

    show_xgrid=False,
    show_ygrid=True,
    show_density_1d=True,
    show_density_2d=True,
    show_legend=True,

    xlabel=None,
    ylabel=None,
    legend_kws=dict(),
    legend_title=None,
    
    title_kws=dict(),
    legend_title_kws=dict(),
    axis_label_kws=dict(),
    annot_kws=dict(),
    line_kws=dict(),
    hist_1d_kws=dict(),
    # rug_kws=dict(),
    kde_2d_kws=dict(),
    cbar_kws=dict(),
    
    panel_pad=0.1,
    panel_size=0.618,
    cax_panel_pad="5%",
    cax_panel_size=0.618,
    
    background_color="white",
    sample_labels:dict=None,
    logscale=True,
    xmin=0, 
    ymin=0, 
    vmin=None,
    vmax=None,
    **scatter_kws,

    ):
    """
    Plot compositions of total counts (x-axis) vs. number of detected components (y-axis)
    """
    from collections.abc import Mapping
    import matplotlib.pyplot as plt
    from mpl_toolkits.axes_grid1 import make_axes_locatable
    from matplotlib.colors import to_hex
    from matplotlib.scale import LogScale

    import seaborn as sns
        
    # Defaults
    _title_kws = {"fontsize":15, "fontweight":"bold"}
    _title_kws.update(title_kws)
    _legend_kws = {'fontsize': 12, 'frameon': True, 'facecolor': 'white', 'edgecolor': 'black'}#, 'loc': 'center left', 'bbox_to_anchor': (1, 0.5)}
    _legend_kws.update(legend_kws)
    _legend_title_kws = {"size":15, "weight":"bold"}
    _legend_title_kws.update(legend_title_kws)
    _axis_label_kws = {"fontsize":15}
    _axis_label_kws.update(axis_label_kws)
    _hist_1d_kws = {"alpha":0.0618}
    _hist_1d_kws.update(hist_1d_kws)
    
    # _kde_1d_kws = {"alpha":0.618} # Rug takes too long when theres a lot points
    # _kde_1d_kws.update(kde_1d_kws)
    # _rug_kws = {"height":0.5}
    # _rug_kws.update(rug_kws)
        
    _kde_2d_kws = {"shade":True,  "alpha":0.618}
    _kde_2d_kws.update(kde_2d_kws)
    _line_kws = {"linewidth":1.618, "linestyle":":", "alpha":1.0}
    _line_kws.update(line_kws)
    _annot_kws = {}
    _annot_kws.update(annot_kws)
    _scatter_kws={"edgecolor":edgecolor, "linewidths":0.618}
    _scatter_kws.update(scatter_kws)
        
    # Data
    X = X.fillna(0)
    # check_compositional(X, acceptable_dimensions={2})
    assert np.all(X == X.astype(int)), "X must be integer data and should not be closure transformed"
        
    # Total number of counts
    sample_to_totalcounts = X.sum(axis=1)
        
    remove_samples = sample_to_totalcounts == 0
    if np.any(remove_samples):
        warnings.warn("Removing the following observations because depth = 0: {}".format(remove_samples.index[remove_samples]))
        sample_to_totalcounts = sample_to_totalcounts[~remove_samples]

    samples = sample_to_totalcounts.index
        
    # Number of detected components
    sample_to_ncomponents = (X > 0).sum(axis=1) #! sample_to_ncomponents = number_of_components(X.loc[samples], checks=False)
    
    number_of_samples = sample_to_ncomponents.size
        
    # Colors
    if classes is not None:
        assert class_colors is not None, "`class_colors` is required for using `classes`"
        classes = pd.Series(classes)
        assert np.all(classes.index == samples), "`classes` must be a pd.Series with the same index ordering as `X.index`"
        assert np.all(classes.map(lambda x: x in class_colors)), "Classes in `class` must have a color in `class_colors`"
        if c is not None:
            warnings.warn("c will be ignored and superceded by class_colors and classes")
        c = classes.map(lambda x: class_colors[x])
        if legend_title is None:
            legend_title = c.name
    else:
        classes = 0
        if c is None:
            c = "black"

    if not isinstance(c, pd.Series):
        c = pd.Series([c]*number_of_samples, index=samples)
    assert np.all(c.notnull())
        
    c_is_continuous = False
    try:        
        c = c.map(to_hex)
    except ValueError:
        c_is_continuous = True
        if vmin is None:
            vmin = c.min()
        if vmax is None:
            vmax = c.max()
    _scatter_kws["cmap"] = cmap
    _scatter_kws["vmin"] = vmin
    _scatter_kws["vmax"] = vmax

    if not isinstance(classes, pd.Series):
        classes = pd.Series([classes]*number_of_samples, index=samples)
        
    # Marker size
    if not isinstance(s, pd.Series):
        s = pd.Series([s]*number_of_samples, index=samples)
    assert np.all(s.notnull())

    # Data
    df_data = pd.DataFrame([
        sample_to_totalcounts,
        sample_to_ncomponents,
        c,
        s,
        classes,
        ], index=["x","y","c","s","class"],
     ).T
    for field in ["x","s"]:
        df_data[field] = df_data[field].astype(float)
    for field in ["y"]:
        df_data[field] = df_data[field].astype(int)

    # Plotting
    number_of_classes = df_data["class"].nunique()
    axes = list()
    with plt.style.context(style):
        # Set up axes
        if ax is None:
            fig, ax = plt.subplots(figsize=figsize)
        else:
            fig = plt.gcf()
        axes.append(ax)

        # Scatter plot
        for id_class in df_data["class"].unique():
            index = df_data["class"][lambda x: x == id_class].index
            df = df_data.loc[index]
            ax.scatter(data=df, x="x", y="y", c="c", s="s", label=id_class if number_of_classes > 1 else None, **_scatter_kws)


        # Limits
        xlim = ax.get_xlim()
        ylim = ax.get_ylim()

        if logscale:
            ax.set_xscale(LogScale(axis=0,base=10))
            
        # Density (1D)
        divider = make_axes_locatable(ax)
        if show_density_1d:
            ax_right = divider.append_axes("right", pad=panel_pad, size=panel_size)
            ax_top = divider.append_axes("top", pad=panel_pad, size=panel_size)

            for id_class in df_data["class"].unique():
                index = df_data["class"][lambda x: x == id_class].index
                df = df_data.loc[index]
                color = df.loc[index,"c"].values[0]
                if c_is_continuous:
                    color = to_hex("black")
                    
                # Histogram
                sns.histplot(data=df, x="x", color=color,  ax=ax_top, kde=True, **_hist_1d_kws) # "rug":max(X.shape) < 5000, "hist_kws":{"alpha":0.382}, "kde_kws":{
                sns.histplot(data=df, y="y", color=color,  ax=ax_right, kde=True, **_hist_1d_kws) # "rug":max(X.shape) < 5000, "hist_kws":{"alpha":0.382}, "kde_kws":{
                
                # KDE
                # sns.rugplot(data=df, x="x", color=color,  ax=ax_top,) # "rug":max(X.shape) < 5000, "hist_kws":{"alpha":0.382}, "kde_kws":{
                # sns.rugplot(data=df, y="y", color=color,  ax=ax_right) # "rug":max(X.shape) < 5000, "hist_kws":{"alpha":0.382}, "kde_kws":{(data=df, y="y", color=color,  ax=ax_right, **_kde_1d_kws, zorder=0) # "rug":max(X.shape) < 5000, "hist_kws":{"alpha":0.382}, "kde_kws":{

            if logscale:
                ax_top.set_xscale(LogScale(axis=0,base=10))

            with warnings.catch_warnings():
                warnings.simplefilter("ignore")
                ax_right.set(ylim=ylim, xticklabels=[],yticklabels=[],yticks=ax.get_yticks())
                ax_top.set(xlim=xlim, xticklabels=[],yticklabels=[],xticks=ax.get_xticks())
            ax_right.set_xlabel(None)
            ax_right.set_ylabel(None)
            ax_top.set_xlabel(None)
            ax_top.set_ylabel(None)
            axes.append(ax_right)
            axes.append(ax_top)

        # Density (2D)
        if show_density_2d:
            for id_class in df_data["class"].unique():
                index = df_data["class"][lambda x: x == id_class].index
                df = df_data.loc[index]
                color = df.loc[index,"c"].values[0]
                if c_is_continuous:
                    color = to_hex("black")
                try:
                    sns.kdeplot(data=df, x="x", y="y", color=color, zorder=0, ax=ax, **_kde_2d_kws)
                except Exception as e:
                    warnings.warn("({}) Could not compute the 2-dimensional KDE plot for the following class: {}".format(id_class))


        # Annotations
        if sample_labels is not None:
            assert hasattr(sample_labels, "__iter__"), "sample_labels must be an iterable or a mapping between sample and label"
            
            if isinstance(sample_labels, (Mapping, pd.Series)):
                sample_labels = dict(sample_labels)
            else:
                sample_labels = dict(zip(sample_labels, sample_labels))
                
            for k,v in sample_labels.items():
                if k not in df_data.index:
                    assert k in X.index, ("{} is not in X.index".format(k))
                    warnings.warn("{} is not in X.index after removing empty compositions".format(k))
                else:
                    x, y = df_data.loc[k,["x","y"]]
                    ax.text(x=x, y=y, s=v, **_annot_kws)

        # Labels
        if xlabel is None:
            xlabel = "Total Counts"
            if logscale:
                xlabel = "log$_{10}$(%s)"%(xlabel)
        if ylabel is None:
            ylabel = "Number of Components"

        if xmin is not None:
            with warnings.catch_warnings():
                warnings.simplefilter("ignore")
                ax.set_xlim(xmin, max(ax.get_xlim()))
       
        if ymin is not None:
            ax.set_ylim(xmin, max(ax.get_ylim()))

        ax.set_xlabel(xlabel, **_axis_label_kws)
        ax.set_ylabel(ylabel, **_axis_label_kws)
        ax.xaxis.grid(show_xgrid)
        ax.yaxis.grid(show_ygrid)

        # Cbar
        # if c_is_continuous:
            # if cbar:
                # divider_cax = make_axes_locatable(axes[-1])
                # cax = fig.add_axes([0.27, 0.8, 0.5, 0.05])
                # im = ax.imshow(df_data["c"].values.reshape(-1,1), cmap=cmap)
                # fig.colorbar(im,  orientation='horizontal')
                # plt.show()
            
            # cax = divider_cax.append_axes("right", size="5%", pad="2%")  # Adjust the size and pad as needed
            # im = ax.imshow(df_data["c"].values.reshape(-1, 1), cmap=cmap, vmin=vmin, vmax=vmax)
            # plt.colorbar(im, cax=cax)

        # Legend
        if show_legend:
            if number_of_classes > 1:
                ax.legend(**_legend_kws)
                if bool(legend_title):
                    ax.legend_.set_title(legend_title, prop=_legend_title_kws)
        # Title
        if title is not None:
            axes[-1].set_title(title, **_title_kws)

        # Background color
        for ax_query in axes:
            ax_query.set_facecolor(background_color)

        return fig, axes

# Load abundances (Gomez and Espinoza et al. 2017)
X = pd.read_csv("https://github.com/jolespin/projects/raw/main/supragingival_plaque_microbiome/16S_amplicons/Data/X.tsv.gz", 
                sep="\t",
                index_col=0,
                compression="gzip",
)


Y = pd.read_csv("https://github.com/jolespin/projects/raw/main/supragingival_plaque_microbiome/16S_amplicons/Data/Y.tsv.gz", 
                sep="\t",
                index_col=0,
                compression="gzip",
)

classes = Y["Caries_enamel"].loc[X.index]
c = pd.Series(classes.map(lambda x: {True:"blue", False:"green"}[x == "NO"]))
sample_labels = pd.Index(X.sum(axis=1).sort_values().index[:4].tolist())
sample_labels = pd.Series(sample_labels.map(lambda x: x.split("_")[0]), sample_labels)
# fig, axes = plot_compositions(X, s=28,logscale=False, sample_labels=sample_labels, class_colors={"NO":"black", "YES":"red"}, classes=classes, title="Caries")
fig, axes = plot_compositions(X,c=Y.loc[X.index,"age (months)"], s=28,logscale=False, sample_labels=sample_labels, title="Caries")

Here's the resulting figure: enter image description here

Trenton McKinney
  • 56,955
  • 33
  • 144
  • 158
O.rka
  • 29,847
  • 68
  • 194
  • 309
  • 1
    You can do all of what you did, or you can just use [`sns.jointplot`](https://seaborn.pydata.org/generated/seaborn.jointplot.html), which also adds a legend. `seaborn` is a high-level API for `matplotlib`. – Trenton McKinney Aug 02 '23 at 02:53
  • @TrentonMcKinney yea good call on jointplot. How can I add a colorbar tho? – O.rka Aug 02 '23 at 06:59
  • 1
    [How to add a colorbar to the side of a kde jointplot](https://stackoverflow.com/q/60845764/7758804) – Trenton McKinney Aug 02 '23 at 08:06
  • 1
    @TrentonMcKinney that looks like the one. I'll clean up the implementation. – O.rka Aug 02 '23 at 15:35

0 Answers0