I am drawing many arrows of a certain color with `annotate` to add data to the graph (to mark where events occurred). Is there a way to add them to the legend? One answer might be to add them manually, as I do in the code below, but I assume that should be the last resort. What is the "right" way to do it? (Bonus for also having a small arrow mark in the legend.)
Here is an example, but it is only there for convenience; the question is simply how to add a legend label for arrows drawn with `line.axes.annotate`.
Here is code that is almost identical to the one in the link. A function to add arrows to a line:
```python
import numpy as np


def add_arrow(line, position=None, direction='right', size=15, color=None, length=None):
    """
    Add an arrow to a line.

    line: Line2D object
    position: x-position of the arrow. If None, the mean of xdata is taken
    direction: 'left' or 'right'
    size: size of the arrow in fontsize points
    color: if None, the line color is taken
    length: the number of data points the arrow spans; leave None for an automatic choice
    """
    if color is None:
        color = line.get_color()

    # work on arrays so that arithmetic and .mean() also work for list input
    xdata = np.asarray(line.get_xdata())
    ydata = np.asarray(line.get_ydata())

    if not length:
        length = max(1, len(xdata) // 1500)
    if position is None:
        position = xdata.mean()

    # index of the data point closest to `position`
    start_ind = np.argmin(np.absolute(xdata - position))
    if direction == 'right':
        end_ind = start_ind + length
    else:
        end_ind = start_ind - length

    # the arrow tip must land on an existing data point (checks both ends)
    if end_ind < 0 or end_ind >= len(xdata):
        print("skipped arrow, arrow would end outside the line")
        return

    # an empty annotation whose only visible part is the arrow
    line.axes.annotate('',
        xytext=(xdata[start_ind], ydata[start_ind]),
        xy=(xdata[end_ind], ydata[end_ind]),
        arrowprops=dict(
            arrowstyle="Fancy,head_width=" + str(size / 150), color=color),
        size=size
    )
```
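A quick check of why the label does not show up on its own: `annotate` accepts a `label` keyword like any artist, but `get_legend_handles_labels()` (which `legend()` uses) only collects lines, patches, collections and similar artists, not text objects such as annotations. A minimal sketch, assuming matplotlib 3.x; the label "events" is just an example:

```python
import matplotlib.pyplot as plt

fig, ax = plt.subplots()
ax.plot([0, 1, 2], [0, 1, 2], label="class days")

# An empty annotation drawn only for its arrow, given a label like any Artist
ann = ax.annotate("", xy=(1.5, 1.5), xytext=(1, 1),
                  arrowprops=dict(arrowstyle="Fancy,head_width=0.1", color="red"),
                  label="events")

# Only the line is collected; the annotation (a Text subclass) is skipped
handles, labels = ax.get_legend_handles_labels()
print(labels)  # ['class days']
```

So some kind of stand-in (proxy) artist is needed for the arrows' legend entry.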
A function that uses it:
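```python
import matplotlib.patches as mpatches


def add_arrows(line, xs, direction='right', size=15, color=None, name=None):
    if color is None:
        color = line.get_color()
    if name:
        # manual legend entry; Patch has no `marker` argument, so it is a plain color swatch
        patch = mpatches.Patch(color=color, label=name)
        # keep the entries already on the axes (e.g. the line itself)
        handles, labels = line.axes.get_legend_handles_labels()
        line.axes.legend(handles=handles + [patch])
    for x in xs:
        add_arrow(line, x, direction=direction, size=size, color=color)
```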
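One common way to get an arrow mark into the legend is a proxy artist: an empty `Line2D` that is never drawn on the axes, whose marker is the mathtext glyph `$\rightarrow$`. A minimal sketch; `add_arrow_legend_entry` is a hypothetical helper name, and the marker size is an arbitrary choice:

```python
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D


def add_arrow_legend_entry(ax, label, color):
    """Hypothetical helper: append an arrow-shaped proxy entry to ax's legend."""
    # Proxy artist: no data, no line, only a mathtext arrow used as the marker
    proxy = Line2D([], [], linestyle="None", marker=r"$\rightarrow$",
                   markersize=15, color=color, label=label)
    # Keep the entries the axes already knows about (e.g. the plotted line)
    handles, labels = ax.get_legend_handles_labels()
    handles.append(proxy)
    labels.append(label)
    return ax.legend(handles=handles, labels=labels)


fig, ax = plt.subplots()
ax.plot(range(10), range(10), label="class days")
leg = add_arrow_legend_entry(ax, "events", "red")
print([t.get_text() for t in leg.get_texts()])  # ['class days', 'events']
```

Because the proxy is not added to the axes, a second call would not see the first proxy; with several arrow groups, collect all proxies and call `legend()` once.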
An example of what `line` is:
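```python
import matplotlib.pyplot as plt

x, y = list(range(10000)), list(range(10000))
line = plt.plot(x, y, label="class days")[0]
# xs are the x-positions of the events (example values); passing (x, y) here
# would treat every data point as an arrow position
add_arrows(line, [2000, 5000, 8000], name="events")
plt.show()
```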
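For a legend glyph that looks more like the drawn arrows, matplotlib's `HandlerPatch` accepts a `patch_func` that builds the legend artist. A minimal sketch, assuming a `FancyArrow` is used as the proxy handle; `make_legend_arrow` is a hypothetical name and the head and tail proportions are arbitrary:

```python
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
from matplotlib.legend_handler import HandlerPatch


def make_legend_arrow(legend, orig_handle, xdescent, ydescent, width, height, fontsize):
    # Horizontal arrow spanning the legend handle box (coordinates in that box)
    return mpatches.FancyArrow(0, 0.5 * height, width, 0,
                               width=0.2 * height,
                               length_includes_head=True,
                               head_width=0.75 * height)


fig, ax = plt.subplots()
line, = ax.plot(range(10), range(10), label="class days")

# Proxy handle: only its color and label matter; the handler draws the glyph
arrow_proxy = mpatches.FancyArrow(0, 0, 1, 0, color="red", label="events")

leg = ax.legend(handles=[line, arrow_proxy],
                handler_map={mpatches.FancyArrow: HandlerPatch(patch_func=make_legend_arrow)})
fig.canvas.draw()  # render once; use plt.show() in an interactive session
```

`HandlerPatch` copies color and line properties from the proxy onto the arrow it creates, so the proxy's `color` should match the one passed to `add_arrow`.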