I'm having trouble getting the legend to display correctly in matplotlib.

EDIT: I have 4 subplots in a figure in a 2 by 2 layout, and I want a legend only on the first subplot, which has two lines plotted on it. The legend I got from the code below contained endless entries and extended vertically across the whole figure. When I run the same code with fake data generated by linspace, the legend works fine.

import numpy as np
import matplotlib.pyplot as plt
import matplotlib.ticker as mtick
import os

#------------------set default directory, import data and create column output vectors---------------------------#

path="C:/Users/Pacman/Data files"
os.chdir(path)
data =np.genfromtxt('vrp.txt')

x=np.array([data[:,][:,0]])

y1=np.array([data[:,][:,6]])
y2=np.array([data[:,][:,7]])
y3=np.array([data[:,][:,9]])
y4=np.array([data[:,][:,11]])
y5=np.array([data[:,][:,10]])

nrows=2
ncols=2
tick_l=6   #length of ticks
fs_axis=16 #font size of axis labels


plt.rcParams['axes.linewidth'] = 2         #Sets global line width of all the axis
plt.rcParams['xtick.labelsize']=14         #Sets global font size for x-axis labels
plt.rcParams['ytick.labelsize']=14         #Sets global font size for y-axis labels



plt.subplot(nrows, ncols, 1)

ax=plt.subplot(nrows, ncols, 1)
l1=plt.plot(x, y2, 'yo',label='Flow rate-fan')
l2=plt.plot(x,y3,'ro',label='Flow rate-discharge')
plt.title('(a)')
plt.ylabel('Flow rate ($m^3 s^{-1}$)',fontsize=fs_axis)
plt.xlabel('Rupture Position (ft)',fontsize=fs_axis)

# This part is not working
plt.legend(loc='upper right', fontsize='x-large')

#Same code for rest of the subplots

I tried to implement a fix suggested in the following question, but could not make it work: "How do I make a single legend for many subplots with matplotlib?"

Any help in this regard will be highly appreciated.

SAkht312

2 Answers

If I understand correctly, you need to tell plt.legend what to use as legend entries; at this point it is being called with nothing to show. The entries you get must be coming from another source. I quickly put together the following, and of course when I run fig.legend the way you do, I get nothing.

import numpy as np
import matplotlib.pyplot as plt

fig = plt.figure()
ax1 = fig.add_axes([0.1, 0.1, 0.4, 0.7])
ax2 = fig.add_axes([0.55, 0.1, 0.4, 0.7])

x = np.arange(0.0, 2.0, 0.02)
y1 = np.sin(2*np.pi*x)
y2 = np.exp(-x)
l1, l2 = ax1.plot(x, y1, 'rs-', x, y2, 'go')

y3 = np.sin(4*np.pi*x)
y4 = np.exp(-2*x)
l3, l4 = ax2.plot(x, y3, 'yd-', x, y4, 'k^')

fig.legend(loc='upper right', fontsize='x-large')

#fig.legend((l1, l2), ('Line 1', 'Line 2'), 'upper left')
#fig.legend((l3, l4), ('Line 3', 'Line 4'), 'upper right')
plt.show()

I'd suggest getting the legend working on one subplot first, and then applying the same approach to the rest.

Joseph
  • Thank you for your comment. I think I did not phrase my problem properly. The legend that I get using my code above contains every data point as the title. Consequently, the legend shows for every point in the plot and thus it extends throughout the whole screen and beyond vertically. I tried the same code but generated fake data using linspace and it worked. – SAkht312 Jun 15 '16 at 19:19

It is useful to work with the axes directly (ax in your case) when working with subplots. So if you set up two plots in a figure and only wish to have a legend in your second plot:

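import numpy as np
import matplotlib.pyplot as plt

t = np.linspace(0, 10, 100)

plt.figure()

ax1 = plt.subplot(2, 1, 1)
ax1.plot(t, t * t)

ax2 = plt.subplot(2, 1, 2)
ax2.plot(t, t * t * t)
# Pass the labels as a list; a bare string would be split into single characters
ax2.legend(['Cubic Function'])

plt.show()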

Note that when creating the legend, I am doing so on ax2 as opposed to plt. If you wish to create a second legend for the first subplot, you can do so in the same way but on ax1.
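For the 2 by 2 layout described in the question, the same idea might look like the sketch below. It is only an illustration: the data are synthetic, and the labels 'Series A' and 'Series B' are placeholders rather than anything from the original script. The Agg backend is selected so the snippet runs without a display; drop that line for interactive use.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend (assumption: no display is needed here)
import matplotlib.pyplot as plt
import numpy as np

# Synthetic stand-in data; the real script reads its columns from a file
t = np.linspace(0, 10, 50)

# Create the 2x2 grid and keep a handle on every axes
fig, axes = plt.subplots(2, 2)

# First subplot: two labelled series and a legend attached to this axes only
ax_a = axes[0, 0]
ax_a.plot(t, np.sin(t), 'yo', label='Series A')
ax_a.plot(t, np.cos(t), 'ro', label='Series B')
ax_a.set_title('(a)')
legend = ax_a.legend(loc='upper right', fontsize='x-large')

# Remaining subplots: plotted without labels and without a legend call
for ax in axes.flat[1:]:
    ax.plot(t, t)

# Inspect the result: which labels ended up in the legend, and on how many axes
legend_labels = [text.get_text() for text in legend.get_texts()]
axes_with_legend = sum(ax.get_legend() is not None for ax in axes.flat)
print(legend_labels, axes_with_legend)
```

Calling legend on the axes object rather than on plt makes it explicit which subplot receives the legend, regardless of which subplot happens to be current.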

Sajjan Singh
  • Thank you for your comment. I think I did not phrase my problem properly. The legend that I get using my code above contains every data point as the title. Consequently, the legend shows for every point in the plot and thus it extends throughout the whole screen and beyond vertically. I tried the same code but generated fake data using linspace and it worked. – SAkht312 Jun 15 '16 at 19:17
  • I would first check what the shapes of your arrays `x`, `y2`, and `y3` are in case they are not simple 1D arrays – Sajjan Singh Jun 15 '16 at 21:20
  • Thanks, this worked by removing the outer square brackets from all the vectors such as y1=np.array([data[:,][:,6]]) – SAkht312 Jun 16 '16 at 21:45
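The resolution in the comments above can be reproduced with a small sketch. Wrapping a column slice in an extra pair of brackets gives a 2D array of shape (1, N), and matplotlib treats each column of a 2D input as a separate dataset, so every point becomes its own line carrying the same label. The data below are made up to stand in for vrp.txt, and count_legend_entries is a hypothetical helper written only for this illustration.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np

# Made-up stand-in for np.genfromtxt('vrp.txt'): 5 rows, 12 columns
data = np.arange(60, dtype=float).reshape(5, 12)

# Pattern from the question: the outer brackets add a leading axis
x_2d = np.array([data[:, 0]])
y_2d = np.array([data[:, 7]])

# Pattern after the fix: plain column slices stay one-dimensional
x_1d = data[:, 0]
y_1d = data[:, 7]


def count_legend_entries(x, y):
    """Plot y against x on a fresh axes and return how many legend entries result."""
    fig, ax = plt.subplots()
    ax.plot(x, y, 'yo', label='Flow rate-fan')
    handles, labels = ax.get_legend_handles_labels()
    plt.close(fig)
    return len(labels)


entries_2d = count_legend_entries(x_2d, y_2d)
entries_1d = count_legend_entries(x_1d, y_1d)

# The 2D version yields one legend entry per data point, the 1D version a single entry
print(x_2d.shape, entries_2d)
print(x_1d.shape, entries_1d)
```

Printing the array shapes before plotting, as suggested in the comments, is usually the quickest way to spot this kind of problem.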