I have a grid of subplots showing confusion matrices, each drawn as a heatmap.
I would like to make the figure more readable, for example:
1) Add one big title 'Targets' above the columns.
2) Add one big y-label 'Predictions'.
3) Have only one big legend (colorbar) per column, since all heatmaps in a column show the same thing.
4) Add the column names `['Train CM', 'Train Norm CM', 'Validation CM', 'Validation Norm CM']` and the row names `[f'Epoch {i}' for i in range(n_epoch)]`.

I did it as described here, but it only works for the columns and not for the rows, and I don't know why.
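The four layout changes listed above can be sketched roughly as follows. This is a minimal sketch, not the asker's code: it assumes matplotlib 3.4 or newer (for `Figure.supylabel`), uses random placeholder matrices in a hypothetical 4-D array `matrices` instead of the real `cm_Train`/`cm_Validation` objects, and assumes train and validation matrices have the same number of classes. 'Targets' goes above the whole grid with `fig.suptitle`, 'Predictions' along the left edge with `fig.supylabel`, and each column gets a single colorbar by giving its heatmaps a common `vmin`/`vmax`, passing `cbar=False`, and then calling `fig.colorbar(..., ax=axes[:, c])` once per column.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns

# Placeholder sizes and data (assumptions): replace with the real n_epoch and matrices.
n_epoch, n_classes = 3, 4
rng = np.random.default_rng(0)
matrices = rng.integers(1, 50, size=(n_epoch, 4, n_classes, n_classes)).astype(float)
# Columns 1 and 3 stand in for the normalized matrices (each row sums to 1).
matrices[:, [1, 3]] /= matrices[:, [1, 3]].sum(axis=-1, keepdims=True)

cols = ['Train CM', 'Train Norm CM', 'Validation CM', 'Validation Norm CM']
rows = [f'Epoch {i}' for i in range(n_epoch)]

# squeeze=False keeps `axes` 2-D even when n_epoch == 1.
fig, axes = plt.subplots(nrows=n_epoch, ncols=4, figsize=(18, 4 * n_epoch),
                         constrained_layout=True, squeeze=False)

for c in range(4):
    # One colour scale per column, so a single colorbar is valid for the whole column.
    vmin, vmax = matrices[:, c].min(), matrices[:, c].max()
    fmt = 'g' if c % 2 == 0 else '.2f'  # counts vs. normalized values
    for e in range(n_epoch):
        sns.heatmap(matrices[e, c], annot=True, fmt=fmt, ax=axes[e, c],
                    linewidths=.3, vmin=vmin, vmax=vmax, cbar=False)
    # Colorbar built from the first heatmap's mesh, sized to span every Axes in the column.
    fig.colorbar(axes[0, c].collections[0], ax=axes[:, c], shrink=0.8)

# Label after plotting, because sns.heatmap sets each Axes' x/y labels when it draws.
for ax, col in zip(axes[0], cols):
    ax.set_title(col, size='large')
for ax, row in zip(axes[:, 0], rows):
    ax.set_ylabel(row, rotation=0, ha='right', va='center', size='large')

fig.suptitle('Targets', size='x-large')       # one title above all columns
fig.supylabel('Predictions', size='x-large')  # one y-label for the whole grid
```

With the real data, `matrices[e, c]` would be filled from `cm_Train[e]` / `cm_Validation[e]` the same way the question's loop builds `cm`. Sharing `vmin`/`vmax` per column is what makes one colorbar honest for all epochs; without it each heatmap would use its own colour scale.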
My code:
```python
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

# n_epoch, cm_Train and cm_Validation are defined earlier
cols = ['Train CM', 'Train Norm CM', 'Validation CM', 'Validation Norm CM']
rows = [f'Epoch {i}' for i in range(n_epoch)]

f, axes = plt.subplots(nrows=n_epoch, ncols=4, figsize=(40, 30))

# Column names on the top row, row names on the first column
for ax, col in zip(axes[0], cols):
    ax.set_title(col, size='large')
for ax, row in zip(axes[:, 0], rows):
    ax.set_ylabel(row, rotation=0, size='large')
f.tight_layout()

for e in range(n_epoch):
    for c in range(4):
        # take conf matrix from lists cm_Train or cm_Validation of ConfusionMatrix() objects
        if c == 0:
            cm = np.transpose(np.array([list(item.values()) for item in cm_Train[e].matrix.values()]))
        elif c == 1:
            cm = np.transpose(np.array([list(item.values()) for item in cm_Train[e].normalized_matrix.values()]))
        elif c == 2:
            cm = np.transpose(np.array([list(item.values()) for item in cm_Validation[e].matrix.values()]))
        else:
            cm = np.transpose(np.array([list(item.values()) for item in cm_Validation[e].normalized_matrix.values()]))
        sns.heatmap(cm, annot=True, fmt='g', ax=axes[e, c], linewidths=.3)
```
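A likely reason the row names disappear: `sns.heatmap` sets the Axes' x- and y-labels itself when it draws (from the index and column names of the data, which are empty for a plain NumPy array), so the `set_ylabel` calls made before the plotting loop get overwritten. It does not touch the Axes title, which is why the column names set with `set_title` survive. A minimal sketch of the fix, assuming random placeholder matrices in a hypothetical array `cms` instead of the `ConfusionMatrix` objects, is simply to draw first and label afterwards:

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns

n_epoch = 2
rng = np.random.default_rng(1)
# Placeholder: random 3x3 matrices standing in for the transposed confusion matrices.
cms = rng.integers(0, 20, size=(n_epoch, 4, 3, 3))

cols = ['Train CM', 'Train Norm CM', 'Validation CM', 'Validation Norm CM']
rows = [f'Epoch {i}' for i in range(n_epoch)]
# squeeze=False keeps `axes` 2-D even when n_epoch == 1.
f, axes = plt.subplots(nrows=n_epoch, ncols=4, figsize=(16, 4 * n_epoch), squeeze=False)

# 1) Draw every heatmap first ...
for e in range(n_epoch):
    for c in range(4):
        sns.heatmap(cms[e, c], annot=True, fmt='g', ax=axes[e, c], linewidths=.3)

# 2) ... then label, so heatmap() can no longer overwrite the y-labels.
for ax, col in zip(axes[0], cols):
    ax.set_title(col, size='large')
for ax, row in zip(axes[:, 0], rows):
    ax.set_ylabel(row, rotation=0, ha='right', size='large')

f.tight_layout()  # after labelling, so the layout accounts for the new labels
```

Calling `tight_layout()` last also matters: run before the heatmaps are drawn, it cannot reserve space for tick labels, annotations or colorbars that do not exist yet.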