I'm using scikit-learn to classify some data. I have 13 different class values/categories to classify the data into. I have been able to run cross-validation and print the confusion matrix. However, it only shows the counts (TP, FP, etc.) without the class labels, so I don't know which class is which. Below is my code and my output:
# RandomOverSampler and a pipeline that accepts samplers come from imbalanced-learn
from imblearn.over_sampling import RandomOverSampler
from imblearn.pipeline import make_pipeline
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score, confusion_matrix
from sklearn.model_selection import cross_val_predict

def classify_data(df, feature_cols, file):
    nbr_folds = 5
    RANDOM_STATE = 0
    attributes = df.loc[:, feature_cols]  # Also known as x
    class_label = df['task']  # Class label, also known as y.
    file.write("\nFeatures used: ")
    for feature in feature_cols:
        file.write(feature + ",")
    print("Features used", feature_cols)
    sampler = RandomOverSampler(random_state=RANDOM_STATE)
    print("RandomForest")
    file.write("\nRandomForest")
    rfc = RandomForestClassifier(max_depth=2, random_state=RANDOM_STATE)
    pipeline = make_pipeline(sampler, rfc)
    # Out-of-fold predictions for every row, then the confusion matrix
    class_label_predicted = cross_val_predict(pipeline, attributes, class_label, cv=nbr_folds)
    conf_mat = confusion_matrix(class_label, class_label_predicted)
    print(conf_mat)
    accuracy = accuracy_score(class_label, class_label_predicted)
    print("Rows classified: " + str(len(class_label_predicted)))
    print("Accuracy: {0:.3f}%\n".format(accuracy * 100))
    file.write("\nClassifier settings:" + str(pipeline) + "\n")
    file.write("\nRows classified: " + str(len(class_label_predicted)))
    file.write("\nAccuracy: {0:.3f}%\n".format(accuracy * 100))
    # One tab-separated line per matrix row
    file.writelines('\t'.join(str(j) for j in i) + '\n' for i in conf_mat)
#Output
Rows classified: 23504
Accuracy: 17.925%
0 372 46 88 5 73 0 536 44 317 0 200 127
0 501 29 85 0 136 0 655 9 154 0 172 67
0 97 141 78 1 56 0 336 37 429 0 435 198
0 135 74 416 5 37 0 507 19 323 0 128 164
0 247 72 145 12 64 0 424 21 296 0 304 223
0 190 41 36 0 178 0 984 29 196 0 111 43
0 218 13 71 7 52 0 917 139 177 0 111 103
0 215 30 84 3 71 0 1175 11 55 0 102 62
0 257 55 156 1 13 0 322 184 463 0 197 160
0 188 36 104 2 34 0 313 99 827 0 69 136
0 281 80 111 22 16 0 494 19 261 0 313 211
0 207 66 87 18 58 0 489 23 157 0 464 239
0 113 114 44 6 51 0 389 30 408 0 338 315
As you can see, there is no way to tell which row or column belongs to which class, and the output is also misaligned, so it's difficult to read.
Is there a way to print the labels as well?
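To make clear what kind of output I'm after, here is a minimal sketch in plain Python. It assumes the label order is known: by default `confusion_matrix` orders rows and columns by the sorted unique labels found in the true and predicted values, and the order can be fixed explicitly through its `labels` parameter. The helper name `format_confusion_matrix` and the example labels and counts are made up for illustration; they are not from my data.

```python
def format_confusion_matrix(conf_mat, labels):
    """Return the matrix as an aligned text table with row and column labels.

    Rows are the true classes and columns the predicted classes,
    which is the convention scikit-learn's confusion_matrix uses.
    """
    labels = [str(label) for label in labels]
    rows = [[str(value) for value in row] for row in conf_mat]
    if len(rows) != len(labels) or any(len(row) != len(labels) for row in rows):
        raise ValueError("conf_mat must be square and match the number of labels")

    # One shared width so every column lines up, whatever the label length.
    all_cells = labels + [cell for row in rows for cell in row]
    width = max(len(cell) for cell in all_cells)

    # Header: an empty corner cell followed by the predicted-class labels.
    header = " " * width + " " + " ".join(label.rjust(width) for label in labels)
    lines = [header]
    for label, row in zip(labels, rows):
        cells = " ".join(cell.rjust(width) for cell in row)
        lines.append(label.rjust(width) + " " + cells)
    return "\n".join(lines)


# Hypothetical labels and counts, for illustration only.
example_labels = ["read", "walk", "write"]
example_matrix = [[5, 1, 0], [2, 30, 4], [0, 3, 112]]
print(format_confusion_matrix(example_matrix, example_labels))
```

In my function, the labels could presumably be obtained with `sklearn.utils.multiclass.unique_labels(class_label, class_label_predicted)` and passed both to `confusion_matrix(..., labels=labels)` and to the helper, so that the two always agree. Wrapping the matrix in a `pandas.DataFrame` with `index=labels, columns=labels` would be another option, since the data is already in a DataFrame.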