
I am trying to replicate the following plot but with a different set of data:

Plot

My current plot has everything you see except the legend in the top right corner. I am having a hard time figuring out how I am supposed to add this in with my current code:

fig = plt.figure()

plt.subplot(3, 1, 1)
plt.title('Task Switches and Avg Task Switches by Timestep', fontsize=10)
plt.ylabel('Task Switches', fontsize=9)
plt.xlim(-35, timestep_num + 35)
plt.xticks(np.arange(0, timestep_num+1, 50), fontsize=-1, color='white')
plt.yticks(np.arange(0, 61, 20), fontsize=6)
plt.plot([stepsum_list[i][6] for i in range(len(stepsum_list))], color='royalblue', 
linewidth=0.7, linestyle='', marker='.', markersize=1)
plt.plot([stepsum_list[i][6]/(i+1) for i in range(len(stepsum_list))], color='limegreen', 
linewidth=0.6,)

plt.subplot(3, 1, 2)
plt.title('Task Demand per Timestep by Task', fontsize=10)
plt.ylabel('Task Demand', fontsize=9)
plt.xlim(-35, timestep_num + 35)
plt.xticks(np.arange(0, timestep_num+1, 50), fontsize=-1, color='white')
plt.yticks(np.arange(0, 6, 1), fontsize=6)
plt.plot([stepdem_list[i][1] for i in range(len(stepdem_list))], color='darkorange', 
linewidth=0.7, linestyle='', marker='.', markersize=1)
plt.plot([stepdem_list[i][2] for i in range(len(stepdem_list))], color='yellowgreen', 
linewidth=0.7, linestyle='', marker='.', markersize=1)
plt.plot([stepdem_list[i][3] for i in range(len(stepdem_list))], color='purple', 
linewidth=0.7, linestyle='', marker='.', markersize=1)
plt.plot([stepdem_list[i][4] for i in range(len(stepdem_list))], color='blue', linewidth=0.7, 
linestyle='', marker='.', markersize=1)

plt.subplot(3, 1, 3)
plt.title('Target and Tracker Movement',fontsize=10)
plt.ylabel('Movement', fontsize=9)
plt.xlabel('Timesteps', fontsize=9)
plt.xlim(-35, timestep_num + 35)
plt.xticks(np.arange(0, timestep_num+1, 50), fontsize=8)
plt.yticks(np.arange(-10, 11, 10), fontsize=6)
plt.plot([stepsum_list[i][4] for i in range(len(stepsum_list))], color='blue', linewidth=.5)
plt.plot([stepsum_list[i][2] for i in range(len(stepsum_list))], color='red', linewidth=.5)

fig.align_labels()
plt.subplots_adjust(left=None, bottom=None, right=None, top=None, wspace=0.4, hspace=0.4)
plt.savefig('prog02_output.png')
plt.show()

I apologize for all of the repetitive code; I'm brand new to Python and this is my first time making a plot, so I don't know all of the tricks just yet. I have found the function figlegend(), but I'm not sure whether it is what I want to use and, if so, how its parameters work. I am also trying to place the legend in the correct spot (aligned with the top subplot), but can't figure out how.

I'm not asking anyone to write any code or rewrite what I have, just for someone to point me in the right direction, whether that be explaining a function and what parameters it can take, or what might need to be changed in my current code to use figlegend().

Trenton McKinney
WeekendJedi
  • For people to try and provide an answer that works for your example, it would be helpful if you included some mock data that would demonstrate the result you're getting. – Grismar Oct 27 '21 at 22:43
  • First, you'll want to add a `label` keyword argument to everything you're plotting. Then look into [`Axes.get_legend_handles_labels`](https://matplotlib.org/stable/api/_as_gen/matplotlib.axes.Axes.get_legend_handles_labels.html) – AJ Biffl Oct 28 '21 at 00:08
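
A minimal sketch of what the comment above suggests, not a drop-in fix for the question's code. The data here is made up (the question's `stepsum_list` is not available), the label strings are placeholders, and the `bbox_to_anchor` numbers are a guess that would need tuning to line the legend up with the top subplot.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs anywhere; remove to use plt.show()
import matplotlib.pyplot as plt
import numpy as np

# Made-up data standing in for the question's stepsum_list (assumption)
rng = np.random.default_rng(0)
steps = np.arange(200)
switches = rng.integers(0, 60, size=steps.size)
running_avg = np.cumsum(switches) / (steps + 1)

fig, axes = plt.subplots(3, 1)
top = axes[0]

# Every artist that should appear in the legend gets a label
top.plot(steps, switches, color='royalblue', linestyle='', marker='.',
         markersize=1, label='Task switches')
top.plot(steps, running_avg, color='limegreen', linewidth=0.6,
         label='Avg task switches')

# Collect the labelled artists from the top subplot
handles, labels = top.get_legend_handles_labels()

# fig.legend positions the legend in figure coordinates (0 to 1 on both axes);
# loc names the corner of the legend box that is pinned to bbox_to_anchor
legend = fig.legend(handles, labels, loc='upper right',
                    bbox_to_anchor=(0.9, 0.88), fontsize=6, markerscale=6)

fig.canvas.draw()  # render once; use fig.savefig(...) or plt.show() to view
```

`markerscale` enlarges the legend markers relative to the plotted ones, which helps when `markersize=1` makes the dots nearly invisible in the legend.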

2 Answers


The way I add legends to my plots in Matplotlib is via the Axes.legend() method. A minimal example:

import matplotlib.pyplot as plt

fig, ax = plt.subplots()
ax.plot([0,1,2],[2,1,0], c='r', label='Plot 1')
ax.plot([0,1,2],[0,1,2], c='b', label='Plot 2')
ax.legend()
plt.show()

After you add labels to each of the data traces in your plot via the label keyword argument, you can add a legend to the current axes with:

plt.gca().legend()
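
For a figure with several subplots like the one in the question, the same idea can be applied per subplot. The following is only a sketch under assumptions: the data is random, the label strings are placeholders, and `loc='upper right'` is one guess at the placement the asker wants.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs anywhere; remove to use plt.show()
import matplotlib.pyplot as plt
import numpy as np

# Made-up data (assumption); replace with your own series
rng = np.random.default_rng(1)
steps = np.arange(200)
switches = rng.integers(0, 60, size=steps.size)
target = np.cumsum(rng.choice([-1, 1], size=steps.size))
tracker = np.cumsum(rng.choice([-1, 1], size=steps.size))

fig, (ax_top, ax_mid, ax_bot) = plt.subplots(3, 1)

ax_top.plot(steps, switches, color='royalblue', linestyle='', marker='.',
            markersize=1, label='Task switches')
ax_top.plot(steps, np.cumsum(switches) / (steps + 1), color='limegreen',
            linewidth=0.6, label='Avg task switches')

ax_bot.plot(steps, target, color='blue', linewidth=0.5, label='Target')
ax_bot.plot(steps, tracker, color='red', linewidth=0.5, label='Tracker')

# Each Axes.legend() call only sees the labelled artists of its own subplot,
# so the legend stays attached to (and aligned with) that subplot
for ax in (ax_top, ax_bot):
    ax.legend(loc='upper right', fontsize=6, markerscale=6)

fig.canvas.draw()  # render once; use fig.savefig(...) or plt.show() to view
```

The middle subplot has no labelled artists here, so no legend is requested for it.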
dsillman2000

If you are plotting three or more lines, you can collect the line handles like this:

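    import matplotlib.pyplot as plt

    fig, ax = plt.subplots()

    # x, x2, x3 and y are your own data; the label strings are placeholders
    lns1 = ax.plot(x, y, label='series 1')
    lns2 = ax.plot(x2, y, label='series 2')
    lns3 = ax.plot(x3, y, label='series 3')
    lns = lns1 + lns2 + lns3  # each plot call returns a list of Line2D handles

    labs = [l.get_label() for l in lns]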

Then, when you add the legend, pass it the handles and labels:

    ax.legend(lns,labs,loc='lower center')
D.L