
I'm generating a figure with 4 curves (for example) divided into 2 types, Type1 and Type2 (2 curves of each type). I draw Type1 as a solid line and Type2 as a dashed line. To avoid cluttering the figure, I want to add text somewhere in the figure explaining that solid lines are Type1 and dashed lines are Type2, instead of repeating this in every legend entry as in the following example:

import numpy as np
import matplotlib.pyplot as plt
from matplotlib.legend_handler import HandlerTuple


x = np.arange(1,10)
a_1 = 2 * x
a_2 = 3 * x
b_1 = 5 * x
b_2 = 6 * x

p1, = plt.plot(x, a_1)
p2, = plt.plot(x, a_2, linestyle='--')
p3, = plt.plot(x, b_1)
p4, = plt.plot(x, b_2, linestyle='--')

plt.legend([(p1, p2), (p3, p4)], ['A Type1/Type2', 'B Type1/Type2'],
           numpoints=1, handler_map={tuple: HandlerTuple(ndivide=None)},
           handlelength=3)

plt.show()

The result is:

[Image: Try1, the figure produced by the code above]

What I would like is something like this:

[Image: What I want, the desired figure]

Here I removed Type1/Type2 from each legend entry and instead added it, in black, somewhere appropriate in the figure (marked by a red circle).

Can anybody help?

Trenton McKinney
MRm

1 Answer

  • I think it's easiest to let the plotting API handle the legend rather than constructing it manually, which means properly labeling the data before passing it to the API.
  • In the following example, the data is loaded into a dict in which each value has been assigned a category and a type.
    • ['A']*len(a_1) creates a list of labels based on the length of a given array
    • ['A']*len(a_1) + ['A']*len(a_2) combines multiple lists into a single list
    • 'x': np.concatenate((x, x, x, x)) ensures that each value in vals is plotted with the correct x value.
  • seaborn.lineplot, a high-level API for matplotlib, can load the data directly from the dict; the hue parameter maps the category to line color and the style parameter maps the type to line style.
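import numpy as np
import seaborn as sns

# sample data from the question
x = np.arange(1, 10)
a_1, a_2, b_1, b_2 = 2 * x, 3 * x, 5 * x, 6 * x

# load the data from the OP into the dict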
data = {'x': np.concatenate((x, x, x, x)),
        'vals': np.concatenate((a_1, a_2, b_1, b_2)),
        'cat': ['A']*len(a_1) + ['A']*len(a_2) + ['B']*len(b_1) + ['B']*len(b_2),
        'type': ['T1']*len(a_1) + ['T2']*len(a_2) + ['T1']*len(b_1) + ['T2']*len(b_2)}

# plot the data
p = sns.lineplot(data=data, x='x', y='vals', hue='cat', style='type')

# move the legend
p.legend(bbox_to_anchor=(1.05, 1), loc='upper left')

[Image: line plot produced by sns.lineplot, with the legend outside the axes]

# alternatively, the figure-level sns.relplot places the legend outside the axes by default
sns.relplot(data=data, kind='line', x='x', y='vals', hue='cat', style='type')

[Image: line plot produced by sns.relplot]
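With many curves, writing ['A']*len(a_1) for every array gets repetitive. The following is a minimal sketch of a hypothetical helper, to_long_form (not part of seaborn or NumPy), that builds the same long-form dict from a mapping of (category, type) to y-array, repeating each label once per point exactly as the bullets above describe. It assumes every curve shares the same x array.

```python
import numpy as np


def to_long_form(x, curves):
    """Hypothetical helper: flatten {(cat, type): y} into long-form columns.

    Each y array is paired with the shared x array, and its category and type
    labels are repeated once per point, as ['A']*len(a_1) does by hand.
    """
    xs, vals, cats, types = [], [], [], []
    for (cat, typ), y in curves.items():
        y = np.asarray(y)
        if len(y) != len(x):
            raise ValueError(f"curve {(cat, typ)} does not match the length of x")
        xs.append(x)
        vals.append(y)
        cats += [cat] * len(y)
        types += [typ] * len(y)
    return {'x': np.concatenate(xs), 'vals': np.concatenate(vals),
            'cat': cats, 'type': types}


x = np.arange(1, 10)
data = to_long_form(x, {('A', 'T1'): 2 * x, ('A', 'T2'): 3 * x,
                        ('B', 'T1'): 5 * x, ('B', 'T2'): 6 * x})
```

The resulting dict can be passed to seaborn unchanged, e.g. sns.lineplot(data=data, x='x', y='vals', hue='cat', style='type'). Because Python dicts preserve insertion order, the curves appear in the order they are listed.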

Manual Legend Creation
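If you would rather stay with plain matplotlib, one option is to build two legends from proxy artists: colored lines for the categories and black solid/dashed lines for Type1/Type2, which is close to the black annotation the question asks for. This is a minimal sketch; the colors, legend positions, and the Agg backend line are assumptions you can change. It relies on matplotlib.lines.Line2D proxies and ax.add_artist, because a second call to ax.legend otherwise replaces the first legend.

```python
import matplotlib
matplotlib.use('Agg')  # headless backend for this sketch; remove it to use plt.show()
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D
import numpy as np

x = np.arange(1, 10)
# (category, linestyle) -> y values; solid = Type1, dashed = Type2
curves = {('A', '-'): 2 * x, ('A', '--'): 3 * x,
          ('B', '-'): 5 * x, ('B', '--'): 6 * x}
colors = {'A': 'tab:blue', 'B': 'tab:orange'}  # assumed color choice

fig, ax = plt.subplots()
for (cat, ls), y in curves.items():
    ax.plot(x, y, color=colors[cat], linestyle=ls)

# first legend: one colored proxy line per category
cat_handles = [Line2D([], [], color=c, label=name) for name, c in colors.items()]
cat_legend = ax.legend(handles=cat_handles, loc='upper left')
ax.add_artist(cat_legend)  # keep it when the second legend is created

# second legend: black proxy lines explaining the line styles
style_handles = [Line2D([], [], color='black', linestyle='-', label='Type1'),
                 Line2D([], [], color='black', linestyle='--', label='Type2')]
style_legend = ax.legend(handles=style_handles, loc='lower right')
```

If plain text is preferred over a second legend, ax.text(0.98, 0.02, 'solid: Type1\ndashed: Type2', transform=ax.transAxes, ha='right', va='bottom', color='black') places the note in axes coordinates instead.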

Trenton McKinney