
I currently have a graph of exoplanets' insolation vs. density, with different colors corresponding to different orbital periods. I have the colors figured out; I'm just confused about how to set the legend. Here's what I have:

plt.figure(figsize = (9,7))
plt.title('Insolation vs Density', fontsize = 24, fontweight='bold')
plt.xlabel('Density [g/cm$^3$]', fontsize = 16)
plt.ylabel('Insolation [Earth Flux]', fontsize=16)
plt.xscale('log')
plt.yscale('log')
x = data['Density [g/cm**3]']
y = data['Insolation [Earth Flux]']
z = data['Orbital Period']

# Map each orbital period in lst to a color
def pltcolor(lst):
    cols = []
    for i in lst:
        if i <= 3:
            cols.append('mediumturquoise')
        elif i >= 20:
            cols.append('blue')
        else:
            cols.append('crimson')
    return cols
cols = pltcolor(z)

plt.scatter(x=x,y=y,c=cols)
plt.scatter(circum_data['Density [g/cm**3]'], circum_data['Insolation [Earth Flux]'],
            color='fuchsia', label='Circumbinary Planets')
plt.legend();

2 Answers


You could put your legend labels (let's call them "labels") in a list with one entry per plotted artist (each plt.scatter call is one artist, so this is one label per scatter call, not one per point) and just do:

plt.legend(labels)
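plt.legend also accepts explicit handles. If you would rather keep the question's single scatter call with per-point colors, one option is to build one "proxy" legend marker per color with matplotlib.lines.Line2D. This is a minimal sketch; the sample values and the legend label strings are made up for illustration, and period_color is a hypothetical helper using the question's thresholds.

```python
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D

# Hypothetical sample values standing in for the question's data columns
density = [1.2, 3.5, 5.5, 0.7, 8.0]
insolation = [900.0, 40.0, 3.0, 150.0, 0.8]
period = [1.5, 12.0, 45.0, 2.8, 100.0]

def period_color(p):
    # Same thresholds as the question's pltcolor()
    if p <= 3:
        return 'mediumturquoise'
    elif p >= 20:
        return 'blue'
    return 'crimson'

fig, ax = plt.subplots(figsize=(9, 7))
ax.set_xscale('log')
ax.set_yscale('log')
# One scatter call with per-point colors: this is still a single artist
ax.scatter(density, insolation, c=[period_color(p) for p in period])

# Proxy handles: empty Line2D objects that only exist to be drawn in the legend
handles = [
    Line2D([], [], marker='o', linestyle='', color='mediumturquoise', label='P <= 3'),
    Line2D([], [], marker='o', linestyle='', color='crimson', label='3 < P < 20'),
    Line2D([], [], marker='o', linestyle='', color='blue', label='P >= 20'),
]
ax.legend(handles=handles)
```

Call plt.show() afterwards to display the figure outside a notebook. The trade-off is that the legend is maintained by hand, so the handle colors must be kept in sync with period_color.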

You can also do it differently (I think that's what you're trying to do): you could just plot each group in a loop and set the label. It would be something like:

# groups: one boolean mask (or index array) per group; group_names: one label per group
for mask, name in zip(groups, group_names):
    plt.scatter(x[mask], y[mask], label=name)
plt.legend()
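A runnable version of that loop, assuming x, y and z are NumPy arrays so that each group can be expressed as a boolean mask. The sample values and group names are hypothetical; the thresholds and colors follow the question.

```python
import numpy as np
import matplotlib.pyplot as plt

# Hypothetical sample data; NumPy arrays can be indexed with boolean masks
x = np.asarray([1.2, 3.5, 5.5, 0.7, 8.0])       # density
y = np.asarray([900.0, 40.0, 3.0, 150.0, 0.8])  # insolation
z = np.asarray([1.5, 12.0, 45.0, 2.8, 100.0])   # orbital period

# One boolean mask per group, using the question's thresholds
groups = [z <= 3, (z > 3) & (z < 20), z >= 20]
group_names = ['P <= 3', '3 < P < 20', 'P >= 20']
group_colors = ['mediumturquoise', 'crimson', 'blue']

fig, ax = plt.subplots()
for mask, name, color in zip(groups, group_names, group_colors):
    # Each call creates its own artist, so each group gets its own legend entry
    ax.scatter(x[mask], y[mask], color=color, label=name)
ax.legend()
```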
Nakor

From my understanding, you need to call plt.scatter once for each group. For reference, take a look at this question. Right now you work out what color each data point should be, store those colors in cols, and then call plt.scatter once, which plots all the points with their colors. However, matplotlib still treats all those points as a single artist (one PathCollection) with a single label, so when you call plt.legend() you only get one entry for them.
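A small check of that explanation, using made-up points: one scatter call with several colors yields one legend entry, while one call per color yields one entry per color.

```python
import matplotlib.pyplot as plt

# Made-up points and per-point colors
x = [1, 2, 3, 4]
y = [4, 3, 2, 1]
cols = ['mediumturquoise', 'crimson', 'blue', 'blue']

# Single call: one PathCollection, however many colors it uses
fig, ax = plt.subplots()
single = ax.scatter(x, y, c=cols, label='all points')
ax.legend()
n_entries = len(ax.get_legend().get_texts())
print(type(single).__name__, n_entries)  # PathCollection 1

# One call per color: one artist, and therefore one legend entry, per group
fig2, ax2 = plt.subplots()
for color in ['mediumturquoise', 'crimson', 'blue']:
    xs = [xi for xi, ci in zip(x, cols) if ci == color]
    ys = [yi for yi, ci in zip(y, cols) if ci == color]
    ax2.scatter(xs, ys, color=color, label=color)
ax2.legend()
n_entries_grouped = len(ax2.get_legend().get_texts())
print(n_entries_grouped)  # 3
```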

I have tried to make a workaround for you using your code. It was a little tricky because you removed the data from your example (understandably). I am assuming that your data is a dict of lists, so I created some fake data to test my approach.

My approach is as follows: go through your data, and if a point's z value lies inside a specific range, add that point to the group for that range. Once you have processed all your data for that group (z range), plot it. Then repeat this for each of the groups. Below is a sample of what I have in mind. There are probably cleaner ways of doing this, but the overall approach is the same: plot each group individually.

import matplotlib.pyplot as plt
import math

# Fake data I created
data = {}
data['Density [g/cm**3]'] = [10,15, 31, 24,55]
data['Insolation [Earth Flux]'] = [10,15,8,4,55]
data['Orbital Period'] = [10,15,3,2,55]

circum_data = {}
circum_data['Density [g/cm**3]'] = [10,15,7,5,55]
circum_data['Insolation [Earth Flux]'] = [10,15,4,3,55]

# ----- Your code------
plt.figure(figsize = (9,7))
plt.title('Insolation vs Density', fontsize = 24, fontweight='bold')
plt.xlabel('Density [g/cm$^3$]', fontsize = 16)
plt.ylabel('Insolation [Earth Flux]', fontsize=16)
plt.xscale('log')
plt.yscale('log')
x = data['Density [g/cm**3]']
y = data['Insolation [Earth Flux]']
z = data['Orbital Period']
# -----------------

# Create the ranges you want: (-inf, 3], (3, 20], (20, inf)
# Note: a period of exactly 20 falls in the middle range here,
# whereas the question's pltcolor() puts it in the >= 20 group
distances_max = [3, 20, math.inf]
distances_min = [-math.inf, 3, 20]
# Select your colors (middle range crimson, longest periods blue, as in the question)
colors = ['mediumturquoise', 'crimson', 'blue']
# The legend names you want
names = ['name1', 'name2', 'name3']

# For each of the ranges
for i in range(len(names)):
    # Create a new group of data
    data_x = []
    data_y = []
    # Go through your data and put each point in the group if its z lies inside the range
    for (xi, yi, zi) in zip(x, y, z):
        if distances_min[i] < zi <= distances_max[i]:
            data_x.append(xi)
            data_y.append(yi)

    # Plot the group: one scatter call per group gives one legend entry per group
    plt.scatter(x=data_x, y=data_y, color=colors[i], label=names[i])

# Uncomment to add the circumbinary planets as another legend entry
# plt.scatter(circum_data['Density [g/cm**3]'],
#             circum_data['Insolation [Earth Flux]'],
#             color = 'fuchsia',
#             label ='Circumbinary Planets')
plt.legend()
plt.show()

Running this code produces the following plot, where name1, name2 and name3 are the labels defined in the names list.

[Figure: Example graph]
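As one of the cleaner ways mentioned above, the group assignment can be done in one step with numpy.digitize instead of the nested loops. This is a sketch that reuses the fake data from this answer and assumes NumPy is available; right=True makes the bins (-inf, 3], (3, 20] and (20, inf), and the colors follow the question (middle group crimson, longest periods blue).

```python
import numpy as np
import matplotlib.pyplot as plt

# Same fake data as in this answer
x = np.asarray([10, 15, 31, 24, 55])  # Density
y = np.asarray([10, 15, 8, 4, 55])    # Insolation
z = np.asarray([10, 15, 3, 2, 55])    # Orbital Period

# Group index per point: 0 for z <= 3, 1 for 3 < z <= 20, 2 for z > 20
group_index = np.digitize(z, bins=[3, 20], right=True)
colors = ['mediumturquoise', 'crimson', 'blue']
names = ['name1', 'name2', 'name3']

fig, ax = plt.subplots(figsize=(9, 7))
ax.set_xscale('log')
ax.set_yscale('log')
for i, (color, name) in enumerate(zip(colors, names)):
    mask = group_index == i
    # Still one scatter call per group, hence one legend entry per group
    ax.scatter(x[mask], y[mask], color=color, label=name)
ax.legend()
```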

I hope this helped. Good luck!

Watchdog101