
I am seeking a function that would work as follows:

from typing import List

import pandas as pd

def plot_df(df: pd.DataFrame, x_column: str, columns: List[List[str]]):
  """Plot `df` with `x_column` on the x-axis and `len(columns)` different
  y-axes, where axis number `i` is scaled to fit the columns in `columns[i]`.

  Important: the plot has only one legend.
  Important: each column has a distinct color.
    If you wonder what color each axis should have, it can take one of its
    line colors and simply carry a label (e.g., one axis for price, another
    for returns, another for growth).
  """

For example, for a DataFrame with the columns time, price1, price2, returns and growth, you could call it like this:

plot_df(df, 'time', [['price1', 'price2'], ['returns'], ['growth']])

This would result in a chart with:

  • 3 y-axes
  • one y-axis shared by price1 and price2
  • each axis scaled independently

I've looked at a couple of solutions which don't work for this.

Possible solution #1:

https://matplotlib.org/stable/gallery/ticks_and_spines/multiple_yaxis_with_spines.html

In this example each axis can hold only one column, which doesn't work for my case. In particular, in the following code each axis has exactly one series:

p1, = ax.plot([0, 1, 2], [0, 1, 2], "b-", label="Density")
p2, = twin1.plot([0, 1, 2], [0, 3, 2], "r-", label="Temperature")
p3, = twin2.plot([0, 1, 2], [50, 30, 15], "g-", label="Velocity")

If you add another line to one of these axes, its color ends up duplicated:

[Image: the resulting chart with a duplicated line color]

Moreover, this version does not use the built-in DataFrame.plot() method.

Possible solution #2:

PANDAS plot multiple Y axes

This example also puts only a single DataFrame column on each axis.

Possible solution #3:

Adapt solution #2 by changing df.A to df[['A', 'B']]. Unfortunately this doesn't work: the two columns end up with the same axis color, and multiple legends pop up.
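For reference, the two symptoms of attempt #3 might be avoided along these lines. This is a minimal sketch, assuming a made-up DataFrame with columns x, A, B and C, and arbitrarily chosen "tab:" color names: DataFrame.plot gets one color per column through `color=[...]`, each call's own legend is suppressed with `legend=False`, and a single legend is built from `get_legend_handles_labels()` of every axis.

```python
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

# Hypothetical data: A and B share the left axis, C gets its own axis.
df = pd.DataFrame({
    "x": np.arange(10),
    "A": np.arange(10) * 1.0,
    "B": np.arange(10) * 2.0,
    "C": np.arange(10) * 100.0,
})

fig, ax = plt.subplots()
twin = ax.twinx()

# One color per column, so A and B stay distinguishable on the shared axis;
# legend=False stops each DataFrame.plot call from drawing its own legend.
df.plot(x="x", y=["A", "B"], ax=ax, color=["tab:blue", "tab:orange"], legend=False)
df.plot(x="x", y="C", ax=twin, color="tab:green", legend=False)

# Collect the labelled lines of every axis into one list of handles/labels.
handles, labels = [], []
for a in (ax, twin):
    h, l = a.get_legend_handles_labels()
    handles += h
    labels += l

# Draw the single legend on the top-most axes so the twin's lines don't cover it.
twin.legend(handles, labels)
```

Because pandas labels each line with its column name, the combined legend lists A, B and C once each.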

So, pandas/matplotlib experts: can you figure out how to overcome this?

Peteris
  • Do you always have those 3 columns? Can you explain how solution #1 is not working and also provide a small (dummy) example – mozway Aug 27 '21 at 14:29
  • 3 columns are just for illustration, it should work in the general case. I will amend to illustrate how solution 1 does not work. – Peteris Aug 27 '21 at 14:30
  • why can't you just add a second plot to `ax` in the example case? So `p1, = ax.plot(...); p2, = ax.plot(...); p3, = twin1.plot(...); p4, = twin2.plot(...)` – tmdavison Aug 27 '21 at 15:12
  • @tmdavison you can but the line will end up sharing the same color as the other plot on that axis. I added a picture for clarity – Peteris Aug 27 '21 at 15:29
  • well, no, you can control the colour of the line and make it anything you want (that's what the `"r-"`, `"b-"`, etc. is doing) – tmdavison Aug 27 '21 at 15:31
  • or you could keep them the same colour, and change the linestyle to dashed or something: `"r--"` – tmdavison Aug 27 '21 at 15:32
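tmdavison's suggestion can be sketched by extending the gallery snippet: a second line goes on `ax` with its own explicit `color`, and each axis is identified by its `set_ylabel` text rather than by a color. The "Density A"/"Density B" series and all data values below are made up for illustration.

```python
import matplotlib.pyplot as plt

fig, ax = plt.subplots()
twin1 = ax.twinx()
twin2 = ax.twinx()
# Push the second twin's right spine outward so its ticks don't overlap twin1's.
twin2.spines["right"].set_position(("axes", 1.2))

# Two series on the same (left) axis, each with its own explicit color.
p1, = ax.plot([0, 1, 2], [0, 1, 2], color="tab:blue", label="Density A")
p2, = ax.plot([0, 1, 2], [0, 2, 1], color="tab:orange", label="Density B")
p3, = twin1.plot([0, 1, 2], [0, 3, 2], color="tab:red", label="Temperature")
p4, = twin2.plot([0, 1, 2], [50, 30, 15], color="tab:green", label="Velocity")

# Identify each axis by the quantity it measures instead of by a color.
ax.set_ylabel("Density")
twin1.set_ylabel("Temperature")
twin2.set_ylabel("Velocity")

# A single legend built from the explicit handles of all axes.
ax.legend(handles=[p1, p2, p3, p4])
fig.tight_layout()
```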

3 Answers


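You can chain the axes from one df.plot() call to the next: create extra y-axes with twinx() and pass each one to df.plot() through the ax argument.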

import pandas as pd
import numpy as np

Create the data and put it in a DataFrame:

x=np.arange(0,2*np.pi,0.01)
b=np.sin(x)
c=np.cos(x)*10
d=np.sin(x+np.pi/4)*100
e=np.sin(x+np.pi/3)*50
df = pd.DataFrame({'x':x,'y1':b,'y2':c,'y3':d,'y4':e})

Define the first plot and the subsequent axes:

ax1 = df.plot(x='x',y='y1',legend=None,color='black',figsize=(10,8))
ax2 = ax1.twinx()
ax2.tick_params(axis='y', labelcolor='r')

ax3 = ax1.twinx()
ax3.spines['right'].set_position(('axes',1.15))
ax3.tick_params(axis='y', labelcolor='g')

ax4=ax1.twinx()
ax4.spines['right'].set_position(('axes',1.30))
ax4.tick_params(axis='y', labelcolor='b')

You can add as many axes as you want; just move each new right spine further out.
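That generalization could be written as a loop. A minimal sketch, assuming the same data as above and a hypothetical helper `add_twin_axes`; the 0.15 step between spines mirrors the 1.15/1.30 offsets used above, and `subplots_adjust(right=0.7)` is an assumed value that leaves room for the outer spines.

```python
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

x = np.arange(0, 2 * np.pi, 0.01)
df = pd.DataFrame({"x": x, "y1": np.sin(x), "y2": np.cos(x) * 10,
                   "y3": np.sin(x + np.pi / 4) * 100, "y4": np.sin(x + np.pi / 3) * 50})

def add_twin_axes(ax, n, step=0.15):
    """Hypothetical helper: create n y-axes sharing ax's x-axis.

    Twin k (k = 0, 1, ...) has its right spine at 1 + k * step in axes
    coordinates, so the first twin keeps the default position.
    """
    twins = []
    for k in range(n):
        twin = ax.twinx()
        twin.spines["right"].set_position(("axes", 1 + k * step))
        twins.append(twin)
    return twins

ys = ["y1", "y2", "y3", "y4"]
colors = ["black", "r", "g", "b"]

ax1 = df.plot(x="x", y=ys[0], color=colors[0], legend=False, figsize=(10, 8))
axes = [ax1] + add_twin_axes(ax1, len(ys) - 1)
ax1.figure.subplots_adjust(right=0.7)  # leave room for the outer spines

# Plot each remaining column on its own axis and color that axis's tick labels.
for a, y, c in zip(axes[1:], ys[1:], colors[1:]):
    df.plot(x="x", y=y, ax=a, color=c, legend=False)
    a.tick_params(axis="y", labelcolor=c)
```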

Plot the remaining columns, each on its own axis:

df.plot(x='x',y='y2',ax=ax2,color='r',legend=None)
df.plot(x='x',y='y3',ax=ax3,color='g',legend=None)
df.plot(x='x',y='y4',ax=ax4,color='b',legend=None)

Results:

DanDevost

I assume you are working with a dataframe like this:

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import cm

np.random.seed(42)  # makes the random values below reproducible
df = pd.DataFrame({'time': pd.date_range(start = '2020-01-01', end = '2020-01-10', freq = 'D')})
df['price1'] = np.random.random(len(df))
df['price2'] = np.random.random(len(df))
df['returns'] = np.random.random(len(df))
df['growth'] = np.random.random(len(df))
        time    price1    price2   returns    growth
0 2020-01-01  0.374540  0.020584  0.611853  0.607545
1 2020-01-02  0.950714  0.969910  0.139494  0.170524
2 2020-01-03  0.731994  0.832443  0.292145  0.065052
3 2020-01-04  0.598658  0.212339  0.366362  0.948886
4 2020-01-05  0.156019  0.181825  0.456070  0.965632
5 2020-01-06  0.155995  0.183405  0.785176  0.808397
6 2020-01-07  0.058084  0.304242  0.199674  0.304614
7 2020-01-08  0.866176  0.524756  0.514234  0.097672
8 2020-01-09  0.601115  0.431945  0.592415  0.684233
9 2020-01-10  0.708073  0.291229  0.046450  0.440152

Then a possible function could be:

def plot_df(df, x_column, columns):

    cmap = plt.get_cmap('tab10')  # matplotlib.cm.get_cmap was removed in Matplotlib 3.9
    line_styles = ["-", "--", "-.", ":"]

    fig, ax = plt.subplots()

    axes = [ax]
    handles = []

    for i in range(len(columns) - 1):
        twin = ax.twinx()
        axes.append(twin)
        twin.spines.right.set_position(("axes", 1 + i/10))

    for i, col in enumerate(columns):
        if len(col) == 1:
            p, = axes[i].plot(df[x_column], df[col[0]], label = col[0], color = cmap(i)[:3])
            handles.append(p)
        else:
            for j, sub_col in enumerate(col):
                # wrap around if an axis holds more columns than there are line styles
                p, = axes[i].plot(df[x_column], df[sub_col], label = sub_col, color = cmap(i)[:3], linestyle = line_styles[j % len(line_styles)])
                handles.append(p)

    ax.legend(handles = handles, frameon = True)

    for i, ax in enumerate(axes):
        ax.tick_params(axis = 'y', colors = cmap(i)[:3])
        if i == 0:
            ax.spines['left'].set_color(cmap(i)[:3])
            ax.spines['right'].set_visible(False)
        else:
            ax.spines['left'].set_visible(False)
            ax.spines['right'].set_color(cmap(i)[:3])

    plt.tight_layout()

    plt.show()

If you call the above function with:

plot_df(df, 'time', [['price1', 'price2'], ['returns'], ['growth']])

then you will get:

[Image: chart produced by plot_df]

NOTES

  1. Since price1 and price2 share the same y-axis, they must share the same color too, so I use different line styles to distinguish them.
  2. The first element of the columns list (['price1', 'price2'] in this case) is always drawn on the left axis, the other elements on the right ones.
  3. If you want to set axis limits and labels, you would need to pass them to plot_df as additional parameters.
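Note 3 could look something like this. A minimal sketch, assuming a hypothetical helper `apply_axis_settings` that receives the `axes` list built inside plot_df (left axis first, then the twins) together with hypothetical `ylabels` and `ylims` lists aligned with `columns`; it is not part of the answer's code.

```python
import matplotlib.pyplot as plt

def apply_axis_settings(axes, ylabels=None, ylims=None):
    """Hypothetical helper: set per-axis y-labels and y-limits.

    ylabels and ylims are aligned with axes; a None entry (or passing
    None for the whole list) leaves that axis at matplotlib's defaults.
    """
    for i, a in enumerate(axes):
        if ylabels is not None and ylabels[i] is not None:
            a.set_ylabel(ylabels[i])
        if ylims is not None and ylims[i] is not None:
            a.set_ylim(*ylims[i])

# Usage on a small two-axis figure with made-up data.
fig, ax = plt.subplots()
axes = [ax, ax.twinx()]
axes[0].plot([0, 1, 2], [1, 2, 3])
axes[1].plot([0, 1, 2], [30, 10, 20])
apply_axis_settings(axes, ylabels=["price", "returns"], ylims=[(0, 5), None])
```

Inside plot_df, a call like this would go just before plt.tight_layout().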
Zephyr
  • Hi Zephyr, this is an awesome answer, however, in the question it does state needing to use different colors to distinguish lines. Given @tmdavison 's comment, I suspect that is possible and we don't need to resort to line styles. If we can figure out the colors, this is the winning answer! – Peteris Aug 27 '21 at 23:11
  • I was forced to use the style for an important reason: if each line is a different color, then I can no longer associate a line to the corresponding axis by color (if `price1` is blue and `price2` is orange, then which color should be the associated axis?). If I remove the color from the axes, then I no longer know on what scale I have to read the various lines – Zephyr Aug 27 '21 at 23:27
  • I mean: every y-axis needs a color, so that we can easily associate a line to the axis on which its values are to be read. If two lines are associated to the same y-axis (as `price1` and `price2` do), they have the same color, which is why I have to find a way to distinguish them. I have chosen to use the style because I think it is the cleanest way; alternatively you could use the thickness of the line, the markers, etc. – Zephyr Aug 27 '21 at 23:30
  • Oh I see, I think that should be solved by labeling the axis units instead, e.g., price, returns, growth. I will clarify in the question. – Peteris Aug 28 '21 at 12:01
  • Since the answer would be significantly different as concept I propose my idea in an other answer :-) – Zephyr Aug 28 '21 at 13:03

I assume you are working with a dataframe like this:

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import cm

np.random.seed(42)  # makes the random values below reproducible
df = pd.DataFrame({'time': pd.date_range(start = '2020-01-01', end = '2020-01-10', freq = 'D')})
df['price1'] = np.random.random(len(df))
df['price2'] = np.random.random(len(df))
df['returns'] = np.random.random(len(df))
df['growth'] = np.random.random(len(df))
        time    price1    price2   returns    growth
0 2020-01-01  0.374540  0.020584  0.611853  0.607545
1 2020-01-02  0.950714  0.969910  0.139494  0.170524
2 2020-01-03  0.731994  0.832443  0.292145  0.065052
3 2020-01-04  0.598658  0.212339  0.366362  0.948886
4 2020-01-05  0.156019  0.181825  0.456070  0.965632
5 2020-01-06  0.155995  0.183405  0.785176  0.808397
6 2020-01-07  0.058084  0.304242  0.199674  0.304614
7 2020-01-08  0.866176  0.524756  0.514234  0.097672
8 2020-01-09  0.601115  0.431945  0.592415  0.684233
9 2020-01-10  0.708073  0.291229  0.046450  0.440152

Then a possible function could be:

def plot_df(df, x_column, columns):

    cmap = plt.get_cmap('tab10')  # matplotlib.cm.get_cmap was removed in Matplotlib 3.9

    fig, ax = plt.subplots()

    axes = [ax]
    handles = []

    for i in range(len(columns) - 1):
        twin = ax.twinx()
        axes.append(twin)
        twin.spines.right.set_position(("axes", 1 + i/10))

    j = 0
    for i, col in enumerate(columns):
        ylabel = []
        if len(col) == 1:
            p, = axes[i].plot(df[x_column], df[col[0]], label = col[0], color = cmap(j)[:3])
            ylabel.append(col[0])
            handles.append(p)
            j += 1
        else:

            for sub_col in col:
                p, = axes[i].plot(df[x_column], df[sub_col], label = sub_col, color = cmap(j)[:3])
                ylabel.append(sub_col)
                handles.append(p)
                j += 1
        axes[i].set_ylabel(', '.join(ylabel))

    ax.legend(handles = handles, frameon = True)

    plt.tight_layout()

    plt.show()

If you call the above function with:

plot_df(df, 'time', [['price1', 'price2'], ['returns'], ['growth']])

then you will get:

[Image: chart produced by plot_df]

NOTES

The first element of the columns list (['price1', 'price2'] in this case) is always drawn on the left axis, the other elements on the right ones.

Zephyr
  • Hey this does seem to have an issue with large numbers. Large numbers lead to an axis needing a multiplier (1e6, etc.) and those multipliers seem to be stacking on the chart itself rather than on top of the corresponding axes. – Peteris Aug 30 '21 at 21:33
  • You should open an other question about multipliers over multiple y axis, since it is a question qualitatively different from this one. Also, provide a dataset with your example of _large numbers_ – Zephyr Aug 30 '21 at 22:25
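One possible workaround for the stacked multipliers, assuming it is acceptable to drop them altogether: switch each y-axis to plain tick labels with `ticklabel_format(style="plain", useOffset=False)`, so no "1e6"-style offset text is drawn above the axes. The large-number data below is made up for illustration; moving the offset text next to its spine would be an alternative not shown here.

```python
import numpy as np
import matplotlib.pyplot as plt

fig, ax = plt.subplots()
twin1 = ax.twinx()
twin2 = ax.twinx()
twin2.spines["right"].set_position(("axes", 1.1))

# Made-up large values that would normally trigger a multiplier on each axis.
x = np.arange(10)
ax.plot(x, x * 1e6, color="tab:blue")
twin1.plot(x, x * 1e7, color="tab:orange")
twin2.plot(x, x * 1e8, color="tab:green")

for a in (ax, twin1, twin2):
    # Show full tick values instead of a shared multiplier/offset.
    a.ticklabel_format(axis="y", style="plain", useOffset=False)

fig.canvas.draw()  # compute ticks and offset texts now
```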