I want to plot the topics of the 20 Newsgroups dataset
using matplotlib. I have written the code below:

    import numpy as np
    import matplotlib.pyplot as plt
    from sklearn.manifold import TSNE

    # classes_to_visual, doc_codes, doc_labels and save_file are defined elsewhere.
    markers = ["o", "v", "8", "s", "p", "*", "h", "H", "+", "x", "D"]
    plt.rc('legend', **{'fontsize': 10})
    classes_to_visual = list(set(classes_to_visual))
    C = len(classes_to_visual)
    # Repeat the marker list until there is at least one marker per class.
    while True:
        if C <= len(markers):
            break
        markers += markers

    class_ids = dict(zip(classes_to_visual, range(C)))

    # Keep only the documents whose label is among the classes to plot.
    if isinstance(doc_codes, dict) and isinstance(doc_labels, dict):
        codes, labels = zip(*[(code, doc_labels[doc]) for doc, code in doc_codes.items()
                              if doc_labels[doc] in classes_to_visual])
    else:
        codes, labels = doc_codes, doc_labels

    X = np.r_[list(codes)]
    # Project the document codes to 2-D with t-SNE.
    tsne = TSNE(perplexity=40, n_components=2, init='pca', n_iter=5000)
    np.set_printoptions(suppress=True)
    X = tsne.fit_transform(X)

    # NUM_COLORS = 20
    plt.figure(figsize=(10, 10), facecolor='white')
    # One plot call per class, so each class gets its own legend entry.
    for c in classes_to_visual:
        idx = np.array(labels) == c
        plt.plot(X[idx, 0], X[idx, 1], linestyle='None', alpha=1, marker=markers[class_ids[c]],
                 markersize=10, label=c)
    legend = plt.legend(loc='upper right', shadow=True)
    plt.savefig(save_file, format='eps', dpi=2000)
    plt.show()

My only problem is that it plots the 20 newsgroups using 11
main colors and relies on markers to distinguish the rest,
but I want each class to have its own clearly distinguishable color.
I tried various approaches, such as defining the colors explicitly:

    ax.set_color_cycle([cm(1. * i / NUM_COLORS) for i in range(NUM_COLORS)])

but in some cases the colors were only different shades of one another and did not differ much.
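
To make the shade problem concrete, here is a minimal, self-contained sketch. It is not my real pipeline: the data is random placeholder points standing in for the t-SNE output, and both colormap names (`viridis` and `tab20`) are assumptions chosen for illustration. Sampling a continuous colormap at 20 evenly spaced positions gives neighbouring classes very similar shades, whereas a qualitative colormap such as matplotlib's `tab20` has 20 separate entries intended for categorical data, which may be closer to what I am after.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np

NUM_COLORS = 20

# Continuous colormap sampled at even steps (assumed stand-in for `cm` above):
# adjacent samples end up as slightly different shades of one another.
continuous = plt.get_cmap("viridis")
sampled = [continuous(i / NUM_COLORS) for i in range(NUM_COLORS)]

# Qualitative colormap: integer indices pick its 20 individual entries.
qualitative = plt.get_cmap("tab20")
distinct = [qualitative(i) for i in range(NUM_COLORS)]

# Placeholder data: random class ids and 2-D points (NOT real t-SNE output).
rng = np.random.default_rng(0)
labels = rng.integers(0, NUM_COLORS, size=400)
X = rng.normal(size=(400, 2)) + labels[:, None]

fig, ax = plt.subplots(figsize=(10, 10), facecolor="white")
for c in range(NUM_COLORS):
    idx = labels == c
    # Pass the colour explicitly instead of relying on the default colour cycle.
    ax.plot(X[idx, 0], X[idx, 1], linestyle="None", marker="o",
            markersize=10, color=distinct[c], label=str(c))
ax.legend(loc="upper right", shadow=True)
```

In my actual code the same idea would mean passing `color=...` to `plt.plot` inside the loop over `classes_to_visual`, indexed by `class_ids[c]`. Colors and markers could also be combined if 20 colors alone are still hard to tell apart.
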
What can I do to get this to work properly?
So far, this is the result: