
I want to plot a 3D tensor plane by plane in a loop using matplotlib. However, in this example matplotlib keeps adding a new colorbar to the figure on every iteration:

import numpy as np
import matplotlib.pyplot as plt

data = np.random.rand(100,100,10)
for i in range(10):
    plt.imshow(np.squeeze(data[:, :, i]))
    plt.colorbar()  # a new colorbar is created on every pass
    plt.pause(2)
    print(i)

Caveat: I've seen some complicated answers to this seemingly simple question, and they didn't work for me. The problem may sound simple, but I suspect there is an easy (short) solution.
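To see what the loop above actually does, here is a minimal headless sketch. It assumes the non-interactive Agg backend (so no window opens and the pause is dropped) and counts what ends up in the figure: each `plt.colorbar()` call adds another colorbar axes, and each `plt.imshow()` stacks another image on the same axes.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend, no GUI window
import matplotlib.pyplot as plt
import numpy as np

data = np.random.rand(100, 100, 10)

fig = plt.figure()
for i in range(10):
    plt.imshow(data[:, :, i])  # adds one more image to the same axes
    plt.colorbar()             # adds one more colorbar axes

n_axes = len(fig.axes)               # the image axes plus every colorbar axes
n_images = len(fig.axes[0].images)   # images stacked on the first (image) axes
print(n_axes, n_images)
```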

JP Maulion
mcExchange
  • First of all you need to make sure to use the same normalization for all images, else the colorbar will be wrong. Then you can decide to only add the colorbar once. Finally consider *updating* a single image, not plotting 10 images one on top of the other. I should think existing solutions to this problem work fine, but if you want to point to any that don't, I'm happy to take a look at them. – ImportanceOfBeingErnest Mar 11 '20 at 17:09
  • I would rather have the colorbar adapt to every new image, since each 2D map may have a different range. – mcExchange Mar 11 '20 at 17:17
  • In that case, just use [the official example](https://matplotlib.org/gallery/animation/animation_demo.html#sphx-glr-gallery-animation-animation-demo-py)? – ImportanceOfBeingErnest Mar 11 '20 at 22:33
  • The important part for me is the colorbar, which should adapt to each new 2D map's data range on the fly. Anyway, I've given up searching for a 'short solution'. I've used this functionality a million times in MATLAB and thought it would be similar in matplotlib. – mcExchange Mar 12 '20 at 08:23
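The first suggestion in the comments (one normalization for all planes, one colorbar, one image that gets updated) can be sketched as follows. This is a minimal sketch under the assumption that a single fixed colour range over the whole tensor is acceptable, which is not what the asker ultimately wanted (a per-plane range). The `vmin`/`vmax` arguments of `imshow` pin the normalization; the 0.2 s pause is an arbitrary choice.

```python
import matplotlib.pyplot as plt
import numpy as np

data = np.random.rand(100, 100, 10) * np.linspace(1, 7, 10)

# One normalization for the whole tensor, so the colorbar is valid for every plane
vmin, vmax = data.min(), data.max()

fig, ax = plt.subplots()
im = ax.imshow(data[:, :, 0], vmin=vmin, vmax=vmax)
cbar = fig.colorbar(im, ax=ax)  # created exactly once

for i in range(data.shape[2]):
    im.set_data(data[:, :, i])  # update the existing image instead of stacking new ones
    ax.set_title(f"plane {i}")
    plt.pause(0.2)
```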

3 Answers


The easy solution

Clear the figure on every iteration of the loop, so the previous image and colorbar are removed before new ones are drawn.

import numpy as np
import matplotlib.pyplot as plt

data = np.random.rand(100,100,10) * np.linspace(1,7,10)

fig = plt.figure()

for i in range(10):
    plt.clf()
    plt.imshow(np.squeeze(data[:, :, i]))
    plt.colorbar()
    plt.pause(2)

plt.show()

The efficient solution

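Create the image once and only update its data; im.autoscale() then refits the colour limits to each new plane, and the colorbar follows. Also use a FuncAnimation instead of a loop, so that everything runs within the GUI event loop.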

import numpy as np
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation

data = np.random.rand(100,100,10) * np.linspace(1,7,10)

fig, ax = plt.subplots()
im = ax.imshow(np.squeeze(data[:, :, 0]))
cbar = fig.colorbar(im, ax=ax)

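def update(i):
    # swap in the i-th plane instead of drawing a new image
    im.set_data(data[:, :, i])
    # refit vmin/vmax to this plane; the colorbar updates with it
    im.autoscale()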

ani = FuncAnimation(fig, update, frames=data.shape[2], interval=2000)
plt.show()
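Because FuncAnimation simply calls update(i) for each frame, the mechanism can be checked without a window. The sketch below assumes the non-interactive Agg backend and calls update by hand instead of through the animation. It confirms that after im.autoscale() the image's colour limits follow each plane's range, and that the colorbar shares the image's norm, which is why it follows too.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend for this check
import matplotlib.pyplot as plt
import numpy as np

data = np.random.rand(100, 100, 10) * np.linspace(1, 7, 10)

fig, ax = plt.subplots()
im = ax.imshow(data[:, :, 0])
cbar = fig.colorbar(im, ax=ax)

def update(i):
    im.set_data(data[:, :, i])
    im.autoscale()  # reset vmin/vmax to this plane's min/max

limits = []
for i in range(data.shape[2]):
    update(i)          # what FuncAnimation would do for frame i
    fig.canvas.draw()  # render the frame
    limits.append(im.get_clim())

print(limits[-1])
```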
ImportanceOfBeingErnest

So here is a solution. Unfortunately it is not short at all. If someone knows how to make this less complicated, feel free to post another answer.

This is a slightly modified version of this answer.

import matplotlib.pyplot as plt
import numpy as np


def visualize_tensor(data, delay=0.5):
    """Show a 3D array channel by channel.

    data must be a 3-dimensional array with the layout
    [height x width x channels]."""
    assert np.ndim(data) == 3, "data must be 3-dimensional"

    # Get number of channels from last dimension
    num_channels = np.shape(data)[-1]

    # Plot data of first channel
    fig, ax = plt.subplots()
    plot = ax.imshow(data[:, :, 0])

    # Create the colorbar once; it stays linked to the image
    cbar = fig.colorbar(plot, ax=ax)
    plt.show(block=False)

    # Iterate over all channels
    for i in range(num_channels):
        print(f"channel = {i}")
        data_nth_channel = data[:, :, i]
        plot.set_data(data_nth_channel)
        # Refit vmin/vmax to this channel; the colorbar follows the image's norm
        plot.autoscale()
        vmin = np.min(data_nth_channel)  # minimum of nth channel
        vmax = np.max(data_nth_channel)  # maximum of nth channel
        cbar.set_ticks(np.linspace(vmin, vmax, num=11, endpoint=True))
        plt.draw()
        plt.pause(delay)

Example execution:

data = np.random.rand(20,20,10)
visualize_tensor(data)

Update: calling plot.autoscale() makes the colorbar adapt dynamically to each channel; see this answer.
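The per-channel step of visualize_tensor can be pulled out into a small helper and checked headlessly. This is a minimal sketch: `show_channel` is a hypothetical name, the Agg backend is an assumption so no window opens, and the 11 evenly spaced ticks mirror the loop above.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend
import matplotlib.pyplot as plt
import numpy as np


def show_channel(img, cbar, channel, num_ticks=11):
    """Hypothetical helper: show one channel and fit the colorbar to its range."""
    img.set_data(channel)
    img.autoscale()  # vmin/vmax become the channel's min/max
    vmin, vmax = img.get_clim()
    cbar.set_ticks(np.linspace(vmin, vmax, num=num_ticks))
    return vmin, vmax


data = np.random.rand(20, 20, 10)
fig, ax = plt.subplots()
img = ax.imshow(data[:, :, 0])
cbar = fig.colorbar(img, ax=ax)

ranges = [show_channel(img, cbar, data[:, :, i]) for i in range(data.shape[2])]
fig.canvas.draw()
```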

mcExchange

This question intrigued me, as hacking at matplotlib is something of a hobby of mine. Besides the solution posted by @mcExchange, one could use this:

from matplotlib.pyplot import subplots
import numpy as np

# Jupyter magic for an interactive backend (only valid inside a notebook)
%matplotlib notebook
d = np.random.rand(10, 10)
fig, ax = subplots(figsize=(2, 2))
# create mappable
h = ax.imshow(d)
# create colorbar
cb = fig.colorbar(h)
# show non-blocking
fig.show(warn=False)
for i in range(100):
    # generate new data
    h.set_data(np.random.randn(*d.shape) + 1)
    # refit the colour limits; the colorbar follows
    h.autoscale()
    # show the current step in the title
    ax.set_title(f't = {i}')
    # redraw and process pending GUI events
    fig.canvas.draw()
    fig.canvas.flush_events()
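Outside Jupyter the %matplotlib notebook line is a syntax error. A minimal plain-script sketch of the same idea, assuming an interactive backend is available: plt.ion() stands in for the magic, and plt.pause() gives the GUI time to redraw (the 20 frames and 0.1 s interval are arbitrary choices).

```python
import matplotlib.pyplot as plt
import numpy as np

plt.ion()  # interactive mode: drawing calls do not block

d = np.random.rand(10, 10)
fig, ax = plt.subplots(figsize=(2, 2))
h = ax.imshow(d)             # the mappable
cb = fig.colorbar(h, ax=ax)

for i in range(20):
    h.set_data(np.random.randn(*d.shape) + 1)  # new data, same image
    h.autoscale()                             # colour limits (and colorbar) follow the data
    ax.set_title(f"t = {i}")
    plt.pause(0.1)                            # redraw and process GUI events

plt.ioff()
```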

How did I get this solution?

The docs state that colorbar.update_normal only fully updates the colorbar if the norm on the mappable is different than before. Setting the data doesn't change the norm, so the update has to be triggered manually. Behind the scenes, h.autoscale() does roughly the following:

    # refit the norm's vmin/vmax to the new data
    h.norm.autoscale(h._A)  # h._A is the private array holding the image data
    # redraw the colorbar from its mappable
    h.colorbar.update_normal(h.colorbar.mappable)
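A minimal headless sketch (Agg backend assumed) of what those two lines do: after set_data alone the colour limits still belong to the old data, while norm.autoscale refits them and update_normal redraws the colorbar. Here the public h.get_array() is used in place of the private h._A; both refer to the image's data array.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend
import matplotlib.pyplot as plt
import numpy as np

fig, ax = plt.subplots()
h = ax.imshow(np.random.rand(10, 10))   # values in [0, 1)
cb = fig.colorbar(h, ax=ax)

new = np.random.rand(10, 10) * 100 + 50  # values in [50, 150)
h.set_data(new)
stale_clim = h.get_clim()                # still fitted to the first image

# the two manual steps, using the public get_array() instead of h._A
h.norm.autoscale(h.get_array())                # refit vmin/vmax to the new data
h.colorbar.update_normal(h.colorbar.mappable)  # redraw the colorbar
fresh_clim = h.get_clim()
```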
cvanelteren
  • Also works, though I find it a little harder to read. What is a `mappable`? Is it a synonym for `data`? Also, `h._A` and `h.norm` look a bit cryptic to me. – mcExchange Mar 12 '20 at 15:15
  • See the edit. The cryptic variables are the private representations of the data. The norm is the object that normalizes the data to fit into the clim range. – cvanelteren Mar 13 '20 at 08:55
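To make "norm" less cryptic: a matplotlib.colors.Normalize instance maps values linearly from the clim range [vmin, vmax] to [0, 1], and the colormap then turns those numbers into colours. A small sketch with made-up limits:

```python
from matplotlib.colors import Normalize
import numpy as np

norm = Normalize(vmin=2.0, vmax=6.0)  # hypothetical clim range
values = np.array([2.0, 4.0, 6.0])
scaled = norm(values)  # (value - vmin) / (vmax - vmin)
print(scaled)
```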