
How to set a shared secondary axis using subplots in matplotlib

Here is the minimal code to display the issue:

import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt


def countour_every(ax, every, x_data, y_data,
                   color='black', linestyle='-', marker='o', **kwargs):
    """Draw a line with contour marks every `every` points"""
    line, = ax.plot(x_data, y_data, linestyle)
    return line


def prettify_axes(ax, data):
    """Makes my plot pretty"""

    if 'title' in data:
        ax.set_title(data['title'])

    if 'y_lim' in data:
        ax.set_ylim(data['y_lim'])

    if 'x_lim' in data:
        ax.set_xlim(data['x_lim'])

    # Draw the legend only if at least one artist on this axes has a label
    handles, labels = ax.get_legend_handles_labels()
    if labels:
        ax.legend(loc='upper right', prop={'size': 6})

    ax.title.set_fontsize(7)
    ax.xaxis.set_tick_params(labelsize=6)
    ax.xaxis.set_tick_params(direction='in')
    ax.xaxis.label.set_size(7)

    ax.yaxis.set_tick_params(labelsize=6)
    ax.yaxis.set_tick_params(direction='in')
    ax.yaxis.label.set_size(7)


def prettify_second_axes(ax):
    ax.yaxis.set_tick_params(labelsize=7)
    ax.yaxis.set_tick_params(labelcolor='red')
    ax.yaxis.label.set_size(7)


def compare_plot(ax, data):
    line1 = countour_every(ax, 10, **data[0])
    if 'label' in data[0]:
        line1.set_label(data[0]['label'])

    line2 = countour_every(ax, 10, **data[1])
    if 'label' in data[1]:
        line2.set_label(data[1]['label'])

    ax2 = ax.twinx()
    line3 = ax.plot(
            data[0]['x_data'],
            data[0]['y_data']-data[1]['y_data'], '-',
            color='red', alpha=.2, zorder=1)

    prettify_axes(ax, data[0])
    prettify_second_axes(ax2)


d0 = {'x_data': np.arange(0, 10), 'y_data': abs(np.random.random(10)), 'y_lim': [-1, 1], 'color': '.7', 'linestyle': '-', 'label': 'd0'}
d1 = {'x_data': np.arange(0, 10), 'y_data': -abs(np.random.random(10)), 'y_lim': [-1, 1], 'color': '.7', 'linestyle': '--', 'label': 'd1'}
d2 = {'x_data': np.arange(0, 10), 'y_data': np.random.random(10), 'y_lim': [-1, 1], 'color': '.7', 'linestyle': '-.'}
d3 = {'x_data': np.arange(0, 10), 'y_data': -np.ones(10), 'y_lim': [-1, 1], 'color': '.7', 'linestyle': '-.'}

fig, axes = plt.subplots(2, 2, sharex=True, sharey=True)
fig.set_size_inches(6, 6)

compare_plot(axes[0][0], [d0, d1])
compare_plot(axes[0][1], [d0, d2])
compare_plot(axes[1][0], [d1, d0])
compare_plot(axes[1][1], [d3, d2])

fig.suptitle('A comparison chart')
fig.set_tight_layout({'rect': [0, 0.03, 1, 0.95]})
fig.text(0.5, 0.03, 'Position', ha='center')
fig.text(0.005, 0.5, 'Amplitude', va='center', rotation='vertical')
fig.text(0.975, 0.5, 'Error', color='red', va='center', rotation='vertical')

fig.savefig('demo.png', dpi=300)

That generates the following image:

[Image: Shared axes issue]

We can see that the x and y axes are correctly shared, but the secondary (twin) axis is repeated in every subplot.

Also, the secondary axis isn't scaling to fit the data. (It should autoscale independently of the limits set on the primary y axis.)

Lin
  • Concerning the secondary axis not scaling, this is because there is nothing in that axes that would allow it to scale. Presumably you want to plot the red curve to `ax2` instead of `ax`. Concerning the repetition of the secondary axes labels, do you want to share those axes? Or do you want to remove the inner labels? Or both? – ImportanceOfBeingErnest Nov 27 '18 at 00:31
  • Damn, didn't notice that. Thanks. I want to share the axes, which implies both: removing the inner labels and using the same scaling for the sake of comparison. – Lin Nov 27 '18 at 00:39
  • Changing `ax` to `ax2` in the `compare_plot` function fixed the non-scaling issue. Now I'm left with the unshared axis and the repeated labels. – Lin Nov 27 '18 at 00:41
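
To illustrate the scaling point from the comments, here is a minimal sketch (not the asker's code) in which the curve is drawn on the twin axes itself. The data values and the names `error`, `primary_limits` and `twin_limits` are made up for illustration. Because the twin has its own data, it autoscales on its own, while the fixed limits on the primary axis stay untouched.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs headless
import matplotlib.pyplot as plt
import numpy as np

fig, ax = plt.subplots()
ax.set_ylim(-1, 1)           # fixed limits on the primary y axis
ax2 = ax.twinx()             # twin shares x with ax, but has its own y axis

x = np.arange(10)
error = 5 + 1.5 * x          # illustrative values, well outside [-1, 1]

# Plot on ax2, not ax: the twin now has data it can autoscale to.
ax2.plot(x, error, '-', color='red', alpha=.2)

primary_limits = ax.get_ylim()   # still the limits set by hand
twin_limits = ax2.get_ylim()     # chosen by autoscaling around `error`
```

Had the curve been drawn with `ax.plot(...)` instead, the twin would contain no data and could not adapt its limits.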

1 Answer

You will need to share the twin axes manually and also remove the inner tick labels:

def compare_plot(ax, data):
    # ...
    ax2 = ax.twinx()
    # ...
    return ax2

sax1 = compare_plot(axes[0][0], [d0, d1])
sax2 = compare_plot(axes[0][1], [d0, d2])
sax3 = compare_plot(axes[1][0], [d1, d0])
sax4 = compare_plot(axes[1][1], [d3, d2])

for sax in [sax2, sax3, sax4]:
    sax1.get_shared_y_axes().join(sax1, sax)
sax1.autoscale()
for sax in [sax1,sax3]:
    sax.yaxis.set_tick_params(labelright=False)
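
For readers on a recent Matplotlib, a hedged, self-contained variant of the same idea follows. It assumes that `Axes.sharey` is available (Matplotlib 3.3 or later, as far as I know) and uses it in place of the grouper `join` call, which was deprecated in later releases. The sine data is made up so that each panel has a different range.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs headless
import matplotlib.pyplot as plt
import numpy as np

fig, axes = plt.subplots(2, 2, sharex=True, sharey=True)

# One twin per subplot, in row-major order (top-left, top-right, ...).
sax1, sax2, sax3, sax4 = [ax.twinx() for ax in axes.flat]

# Illustrative data: the amplitude grows from panel to panel.
x = np.arange(10)
for scale, sax in enumerate([sax1, sax2, sax3, sax4], start=1):
    sax.plot(x, scale * np.sin(x), '-', color='red', alpha=.2)

# Link the y axis of every other twin to the first one.
for sax in [sax2, sax3, sax4]:
    sax.sharey(sax1)
sax1.autoscale(axis='y')  # rescale the shared y axis to cover all panels

# The left-column twins lose their right-hand tick labels.
for sax in [sax1, sax3]:
    sax.yaxis.set_tick_params(labelright=False)
```

After this, all four twins report the same y limits, and only the right column shows the secondary tick labels.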


ImportanceOfBeingErnest
  • While this works, it will be pretty hard to implement in my code since the number of axes is variable. I will try to figure out how to handle this with the new knowledge. Thanks! – Lin Nov 27 '18 at 01:37
  • My advice would be to create all the twin axes directly after creating the `axes` array. – ImportanceOfBeingErnest Nov 27 '18 at 01:41
  • I'd like to use `sharex` and `sharey`; how would that approach work? I have put my code up for review at https://codereview.stackexchange.com/questions/208492/modularizing-matplolib-graphing-based-on-a-data-dictionary . There I explain all the cases. I'm really getting mad trying to create simple functions to do my drawings from a data structure. – Lin Nov 27 '18 at 02:02
  • You would still need to use the same approach as here, but you would then have an array of `axes` and another one, `axes2`, to use. – ImportanceOfBeingErnest Nov 27 '18 at 02:23
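
A rough sketch of that last suggestion for a variable number of subplots: `make_shared_twins` is a hypothetical helper (not part of Matplotlib) that builds an `axes2` array with the same shape as `axes` right after the grid is created. It assumes `Axes.sharey` is available and that `plt.subplots` is called with `squeeze=False`, so `axes` is always two-dimensional. The plotted data is made up.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs headless
import matplotlib.pyplot as plt
import numpy as np


def make_shared_twins(axes):
    """Return twin axes shaped like `axes`, all sharing a single y axis.

    Hypothetical helper for illustration; expects a 2-D array of Axes.
    """
    axes2 = np.empty(axes.shape, dtype=object)
    for idx, ax in np.ndenumerate(axes):
        axes2[idx] = ax.twinx()

    # Link every twin to the first one.
    first = axes2.flat[0]
    for ax2 in axes2.ravel()[1:]:
        ax2.sharey(first)

    # Keep right-hand tick labels on the last column only.
    for ax2 in axes2[:, :-1].flat:
        ax2.yaxis.set_tick_params(labelright=False)
    return axes2


nrows, ncols = 3, 2  # any grid size
fig, axes = plt.subplots(nrows, ncols, sharex=True, sharey=True,
                         squeeze=False)
axes2 = make_shared_twins(axes)

# Each drawing function can now receive a matching (ax, ax2) pair.
x = np.arange(10)
for k, (ax, ax2) in enumerate(zip(axes.flat, axes2.flat), start=1):
    ax.plot(x, np.cos(x), color='.7')                    # primary data
    ax2.plot(x, k * 0.1 * x, color='red', alpha=.2)      # secondary data
```

With this layout, a function such as `compare_plot` could take both `ax` and `ax2` as arguments instead of calling `twinx()` internally, so the sharing logic lives in one place regardless of the grid size.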