I'm trying to plot a 3D scatter plot using matplotlib, but for some reason the output doesn't show the legend I expect. I want the legend entries to come from one of my dataframe columns (category).
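import matplotlib.pyplot as plt

fig = plt.figure(figsize=(12, 8))
# Axes3D(fig) no longer attaches itself to the figure in recent matplotlib versions
ax = fig.add_subplot(projection='3d')
color_dict = {'Beauty': 'red', 'Kids': 'green', 'Food': 'blue', 'Jewelry': 'yellow'}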
names = df['category'].unique()
for s in names:
    # s is already the category name, so it is used directly as the label
    sc = ax.scatter(embedding[:, 0], embedding[:, 1], embedding[:, 2], s=40,
                    color=[color_dict[i] for i in df_brands['category']], marker='x', label=s)
ax.set_xlabel('X Label')
ax.set_ylabel('Y Label')
ax.set_zlabel('Z Label')
ax.legend()
# plt.show()
plt.legend(*sc.legend_elements(), bbox_to_anchor=(1.05, 1), loc=2)
As can be seen in the output below, the legend in the top right corner shows every category in the same color (red). (I simplified the code to 4 colours; ignore the fact that the plot has more.)
Any help would be appreciated. Thanks!
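One likely cause is that every pass through the loop plots all of the points, so each legend entry takes the colour of the first point. A minimal sketch of one possible approach is below: plot only the rows of the current category in each iteration, with a single colour and the category name as the label. The `df` and `embedding` here are made-up stand-in data (an assumption, not the real dataset), and the `Agg` backend is only there so the sketch runs without a display.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend (assumption: no display needed for this sketch)
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

# Hypothetical stand-in data: 40 random points, one category per row.
rng = np.random.default_rng(0)
color_dict = {'Beauty': 'red', 'Kids': 'green', 'Food': 'blue', 'Jewelry': 'yellow'}
df = pd.DataFrame({'category': rng.choice(list(color_dict), size=40)})
embedding = rng.normal(size=(len(df), 3))

fig = plt.figure(figsize=(12, 8))
ax = fig.add_subplot(projection='3d')

# One scatter call per category: only that category's rows, one colour, one label.
for name in df['category'].unique():
    mask = (df['category'] == name).to_numpy()
    ax.scatter(embedding[mask, 0], embedding[mask, 1], embedding[mask, 2],
               s=40, color=color_dict[name], marker='x', label=name)

ax.set_xlabel('X Label')
ax.set_ylabel('Y Label')
ax.set_zlabel('Z Label')

# The legend is built from the label of each scatter call above.
ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
fig.savefig('scatter3d_legend.png', bbox_inches='tight')
```

With this layout the separate `plt.legend(*sc.legend_elements(), ...)` call should not be needed, since `ax.legend()` picks up one handle per category.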