I am trying to reproduce a figure from a paper in matplotlib, shown below. Basically, each cell has a percentage in it, and the higher the percentage, the darker the background of the cell:

enter image description here

The code below produces something similar, but each cell is a square pixel and I would like for them to be flatter rectangles rather than squares, as in the above image. How can I achieve this with matplotlib?

import numpy as np
import matplotlib.pyplot as plt
import itertools

table = np.random.uniform(low=0.0, high=1.0, size=(10,5))

class_names = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']

plt.figure()
plt.imshow(table, interpolation='nearest', cmap=plt.cm.Greys, vmin=0, vmax=1)
plt.yticks(np.arange(10), class_names)

for i,j in itertools.product(range(table.shape[0]), range(table.shape[1])):
    plt.text(j, i, format(table[i,j], '.2f'),
             horizontalalignment="center",
             color="white" if table[i,j] > 0.5 else "black")

plt.show()

Here's what the code above produces: enter image description here

I suspect that changing the aspect and the extent for imshow may help me here. I don't fully understand how this works, but here's what I've tried:

plt.imshow(table, interpolation='nearest', cmap=plt.cm.Greys, vmin=0, vmax=1, aspect='equal', extent=[0,14,10,0])

This produces the following: enter image description here

I realise that I also need to add the borders between the cells, remove the tick marks, and change the values to percentages rather than decimals, and I am confident that I will be able to do this by myself, but if you want to help me out with that too then please feel free!

timleathart
  • Maybe this Stack Overflow link, [Imshow: extent and aspect](https://stackoverflow.com/questions/13384653/imshow-extent-and-aspect), answers your question. – Spezi94 Oct 19 '17 at 08:46
  • @Spezi94 Thanks, I saw this answer and tried it, but I can't work out how the text placement works after the aspect and extent have been changed. I'll update the question to reflect this. – timleathart Oct 19 '17 at 08:48
  • Using `extent=[0,14,0,10]` gives a nice-looking figure (for me at least), though I also haven't figured out how to place the text. – DavidG Oct 19 '17 at 08:54
  • @DavidG Looks nice, thanks! I've just noticed that the order of the labels is reversed, and this can be fixed by using `extent=[0,14,10,0]` instead. – timleathart Oct 19 '17 at 08:59

2 Answers

You will get non-square pixels when using aspect="auto" in the imshow call:

import numpy as np
import matplotlib.pyplot as plt
import itertools

table = np.random.uniform(low=0.0, high=1.0, size=(10,5))

class_names = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']

plt.figure()
plt.imshow(table, interpolation='nearest', cmap=plt.cm.Greys, vmin=0, vmax=1, aspect="auto")
plt.yticks(np.arange(10), class_names)

for i,j in itertools.product(range(table.shape[0]), range(table.shape[1])):
    plt.text(j, i, format(table[i,j], '.2f'),
             ha="center", va="center",
             color="white" if table[i,j] > 0.5 else "black")

plt.show()

enter image description here
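
With `aspect="auto"` the cell shape follows the figure size. If a fixed cell shape is wanted instead, `imshow` also accepts a number for `aspect`, which sets how tall one data unit is drawn relative to how wide. The sketch below is one possible variation, not part of the answer above: the helper names `aspect_for_cells` and `plot_flat_table` and the 2.8:1 default ratio are assumptions made for illustration. Because the extent is left at its default, the text coordinates stay at the plain `(j, i)` indices. It also formats the values as percentages, as the question asks.

```python
def aspect_for_cells(width_to_height):
    """Return the imshow `aspect` value that draws each cell with the given
    width:height ratio (hypothetical helper, for illustration only).

    A numeric aspect is the displayed height of one data unit divided by the
    displayed width of one data unit, so it is the inverse of the ratio.
    """
    if width_to_height <= 0:
        raise ValueError("width_to_height must be positive")
    return 1.0 / width_to_height


def plot_flat_table(table, row_labels, width_to_height=2.8):
    """Draw `table` (values in [0, 1]) as a grid of flat rectangular cells."""
    # Imported here so the helper above can be used without Matplotlib.
    import matplotlib.pyplot as plt

    fig, ax = plt.subplots()
    ax.imshow(table, interpolation="nearest", cmap="Greys", vmin=0, vmax=1,
              aspect=aspect_for_cells(width_to_height))
    ax.set_yticks(range(len(row_labels)))
    ax.set_yticklabels(row_labels)

    # With the default extent, cell (i, j) is centred at x=j, y=i.
    for i, row in enumerate(table):
        for j, value in enumerate(row):
            ax.text(j, i, format(value, ".0%"), ha="center", va="center",
                    color="white" if value > 0.5 else "black")
    return fig, ax
```

Calling `plot_flat_table(table, class_names)` followed by `plt.show()` should give cells roughly 2.8 times as wide as they are tall, whatever the figure size.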

ImportanceOfBeingErnest

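After a lot of experimentation, I figured out how the extent works and how it affects the coordinates of text placed later on. With `extent=[0,14,10,0]` and a 10×5 table, each cell is 2.8 units wide and 1 unit tall, so ticks and text have to be positioned in those units. I also added the borders and so on, and this code generates a pretty good replica of the original style!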

import numpy as np
import matplotlib
import matplotlib.pyplot as plt
import itertools

table = np.random.uniform(low=0.0, high=1.0, size=(10,5))

class_names = ['airplane', 'auto', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']

plt.figure()

plt.imshow(table, interpolation='nearest', cmap=plt.cm.Greys, vmin=0, vmax=1, aspect='equal', extent=[0,14,10,0])
plt.yticks(np.arange(10)+0.5, class_names)
plt.xticks(np.arange(5)*2.8 + 1.4, ['1', '2', '3', '4', '5'])
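# Reuse the axes that imshow drew on; in recent Matplotlib versions
# plt.axes() would add a new, empty axes on top of the image instead.
ax = plt.gca()
ax.yaxis.set_ticks_position('none')
ax.xaxis.set_ticks_position('none')

matplotlib.rcParams.update({'font.size': 14})

# Minor ticks on the cell boundaries
ax.set_xticks(np.arange(1, 5) * 2.8, minor=True)
ax.set_yticks(np.arange(1, 10, 1), minor=True)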

# Gridlines based on minor ticks
ax.grid(which='minor', color='black', linestyle='-', linewidth=1)

for i,j in itertools.product(range(table.shape[0]), range(table.shape[1])):
    plt.text(j*2.8+1.5, i+0.6, format(table[i,j], '.2f'),
             horizontalalignment="center",
             color="white" if table[i,j] > 0.5 else "black")

plt.show()

enter image description here
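
The offsets used above (`j*2.8+1.5`, `i+0.6`) were found by hand. As a rough sketch of where they come from, the cell centres can be computed directly from the extent and the array shape. The function name `cell_centers` is an assumption introduced here, and the sketch assumes the default `origin='upper'`, where row 0 is drawn at the `top` edge of the extent.

```python
def cell_centers(extent, shape):
    """Map array indices to data coordinates for imshow(..., extent=extent).

    extent: [left, right, bottom, top], exactly as passed to imshow.
    shape:  (n_rows, n_cols) of the displayed array.
    Assumes origin='upper' (the imshow default), so row 0 sits at `top`.
    Returns (xs, ys): the x centre of each column and the y centre of each row.
    """
    left, right, bottom, top = extent
    n_rows, n_cols = shape
    cell_w = (right - left) / n_cols   # signed width of one cell
    cell_h = (bottom - top) / n_rows   # signed height, measured from the top
    xs = [left + (j + 0.5) * cell_w for j in range(n_cols)]
    ys = [top + (i + 0.5) * cell_h for i in range(n_rows)]
    return xs, ys


# For the extent used in this answer, each cell is 2.8 wide and 1 tall.
xs, ys = cell_centers([0, 14, 10, 0], (10, 5))
```

In the loop above, `plt.text(xs[j], ys[i], ..., ha="center", va="center")` would then centre each label in its cell without hand-tuned offsets. The same function applied to `extent=[0,14,0,10]` gives `ys[0] == 9.5`, which is consistent with the reversed label order mentioned in the comments.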

Thanks to DavidG and Spezi94 who helped with their comments!

timleathart