
I'm struggling with FuncAnimation from matplotlib.animation, and I could not find any example or post similar to my problem. (There are posts about using contourf with FuncAnimation, but in those the PathCollection objects are deleted successfully, whereas in my case something is not working.)

Context:

In a school project about the One-vs-All approach (multiple binary classifiers), I want to implement functions that animate a figure with 3 Axes containing several Line2D objects, a PathCollection object from the scatter method and a QuadContourSet from the contourf method.
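The three kinds of artists mentioned above come from three different plotting calls, which matters later because each kind is updated or removed differently. A minimal sketch (the Agg backend is only there so it runs without a window; the data is made up):

```python
import matplotlib
matplotlib.use("Agg")  # headless backend, only so the sketch runs without a window
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.collections import PathCollection
from matplotlib.contour import QuadContourSet
from matplotlib.lines import Line2D

fig, ax = plt.subplots()
xx, yy = np.meshgrid(np.linspace(0, 1, 20), np.linspace(0, 1, 20))

(line,) = ax.plot([0, 1], [0, 1])            # plot() returns a list of Line2D
points = ax.scatter([0.2, 0.8], [0.3, 0.7])  # scatter() returns one PathCollection
filled = ax.contourf(xx, yy, xx + yy, 3)     # contourf() returns a QuadContourSet

print(type(line).__name__, type(points).__name__, type(filled).__name__)
# Line2D PathCollection QuadContourSet
```

Line2D objects can simply be updated with `set_data`, whereas a contour set has no equivalent setter and is usually removed and redrawn.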

Here is a screenshot of what it looks like (obtained by plotting the data at the end of the One-vs-All training):

[Figure: representation of the static graph]

Legend:

  • Left: decision boundary in the (Herbology)-(Defense against Dark Arts) plane,
  • Top right: loss function of each binary classifier,
  • Bottom right: precision and recall metrics of each classifier.

Methods:

I am trying to build an animated version of the plot using FuncAnimation from the matplotlib.animation module. The animated plot is a bonus feature of my project, so the animation core is written as functions; a simplified, bare-bones version is shown below:

def anim_visu(models, data):
    # initialization of the figure and object representing the data
    ...

    def f_anim(frame):
        # Function which updates the data at each frame
        ...

    visu = FuncAnimation(fig, f_anim, ...)

    return fig

[...]

if __name__ == "__main__":
    [...]
    if bool_dynamic: # activation of the dynamic visualization
        anim_visu(models, data)
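One way to fill in this skeleton without globals is to keep the state in anim_visu's own scope and let the inner update function rebind it with `nonlocal`. The sketch below assumes a single line and made-up values in place of the real models; it also returns the FuncAnimation object, because an animation that nothing references can be garbage-collected and stop updating.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend for this sketch; drop it to get a window
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.animation import FuncAnimation


def anim_visu(n_frames=5):
    """Build the figure and its animation; the animated state lives in this scope."""
    fig, ax = plt.subplots()
    ax.set_xlim(0, n_frames + 1)
    ax.set_ylim(0, 1.05)

    # Placeholder history standing in for the models' cost (hypothetical values).
    idx = np.array([0.0])
    cost = np.array([1.0])
    (l_cost,) = ax.plot(idx, cost, marker="o")

    def f_anim(frame):
        # np.concatenate builds new arrays, i.e. it *rebinds* idx and cost.
        # `nonlocal` makes that rebinding target anim_visu's variables; without it
        # Python treats them as new locals and raises UnboundLocalError here.
        nonlocal idx, cost
        idx = np.concatenate((idx, [frame + 1]))
        cost = np.concatenate((cost, [1.0 / (frame + 2)]))
        l_cost.set_data(idx, cost)
        return (l_cost,)

    anim = FuncAnimation(fig, f_anim, frames=n_frames, repeat=False)
    # Return `anim` and keep a reference to it for as long as it should run.
    return fig, anim
```

Usage: `fig, anim = anim_visu()` followed by `plt.show()` (with an interactive backend).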

And here is a minimal, more or less working example:

# =========================================================================== #
#                     | Importing libraries/packages |                        #
# =========================================================================== #
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib.animation import FuncAnimation
from matplotlib.gridspec import GridSpec

dct_palet = {"C1":"dodgerblue",
             "C2":"red",
             "C3":"green",
             "C4":"goldenrod"}
fps = 15

# =========================================================================== #
#                          | Function definitions |                           #
# =========================================================================== #

def one_vs_all_prediction(classifiers:list, X:np.array) -> np.array:
    """
    ... Docstring ...
    """
    preds = np.zeros((X.shape[0],1))

    for clf in classifiers:
        tmp = clf.predict(X)
        mask = preds == 0
        preds[mask] = tmp[mask]
    
    return preds


def one_vs_all_class_onehot(class_pred:np.array):
    """
    ... Docstring ...
    """
    house = {"C1":1., "C2":2., "C3":3., "C4":4.}
    onehot_pred = np.chararray((class_pred.shape[0],1), itemsize=2)
    
    for key, item in house.items():
        mask = class_pred == item
        onehot_pred[mask] = key
    
    return onehot_pred


def do_animation(clfs:list, data:np.ndarray):
    """ Core function for the animated vizualisation.
    The function defines all the x/y_labels, the titles.
    """

    global idx, cost_clf1, cost_clf2, cost_clf3, cost_clf4, \
        met1_clf1, met1_clf2, met1_clf3, met1_clf4, \
        met2_clf1, met2_clf2, met2_clf3, met2_clf4, \
        boundary, axes, \
        l_cost_clf1, l_cost_clf2, l_cost_clf3, l_cost_clf4, \
        l_met1_clf1, l_met1_clf2, l_met1_clf3, l_met1_clf4, \
        l_met2_clf1, l_met2_clf2, l_met2_clf3, l_met2_clf4
    
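    # The style was renamed 'seaborn-v0_8-pastel' in Matplotlib 3.6.
    plt.style.use('seaborn-v0_8-pastel' if 'seaborn-v0_8-pastel' in plt.style.available else 'seaborn-pastel')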
    
    # -- Declaring the figure and the axes -- #
    fig = plt.figure(figsize=(15,9.5))
    gs = GridSpec(2, 2, figure=fig)
    axes = [fig.add_subplot(gs[:, 0]), fig.add_subplot(gs[0, 1]), fig.add_subplot(gs[1, 1])]

    # --formatting the different axes -- #
    axes[0].set_xlabel("X_1")
    axes[0].set_ylabel("X_2")
    axes[0].set_title("Decision boundary")
    axes[1].set_xlabel("i: iteration")
    axes[1].set_xlim(-10, 1000)
    axes[1].set_ylim(-10, 350)
    axes[1].set_ylabel(r"$\mathcal{L}_{\theta_0,\theta_1}$")
    axes[1].grid()
    axes[2].set_xlabel("i: iteration")
    axes[2].set_ylabel("Scores (metric_1 & metric_2)")
    axes[2].set_xlim(-10, 1000)
    axes[2].set_ylim(0.0,1.01)
    axes[2].grid()

    # -- Reading min and max values along X dimensions-- #
    X = data[:,0:2]
    X = X.astype(np.float64)
    Y = data[:,2].reshape(-1,1)

    idx = np.array([0])
    X_min, X_max = X[:,:2].min(axis=0), X[:,:2].max(axis=0)

    # -- Generate a grid of points with distance h between them -- #
    h = 0.01
    XX_1, XX_2 = np.meshgrid(np.arange(X_min[0], X_max[0], h),
                              np.arange(X_min[1], X_max[1], h))
    zeros_arr = np.zeros((XX_1.shape[0] * XX_1.shape[1], 1))
    XX = np.c_[XX_1.ravel(), XX_2.ravel(),
               zeros_arr.ravel(), zeros_arr.ravel(), zeros_arr.ravel()]

    # -- Predict the function value for the whole grid -- #
    preds = one_vs_all_prediction(clfs, XX)
    Z = preds.reshape(XX_1.shape)

    ## Initialisation of the QuadContourSet (filled contours) for axes[0]
    boundary = axes[0].contourf(XX_1, XX_2, Z, 3,
                                colors=["red", "green", "goldenrod", "dodgerblue"], alpha=0.5)

    lst_colors = np.array([dct_palet[house] for house in data[:,2]])
    raw_data = axes[0].scatter(X[:,0], X[:,1], c=lst_colors, edgecolor="k")

    ## Initialisation of the Line2D object for the Axes[1] objects
    cost_clf1 = clfs[0].cost()
    cost_clf2 = clfs[1].cost()
    cost_clf3 = clfs[2].cost()
    cost_clf4 = clfs[3].cost()
    l_cost_clf1, = axes[1].plot(idx, cost_clf1,
                                ls='-', marker='o', ms=2, lw=1.2, color=dct_palet[clfs[0].house])
    l_cost_clf2, = axes[1].plot(idx, cost_clf2,
                                ls='-', marker='o', ms=2, lw=1.2, color=dct_palet[clfs[1].house])
    l_cost_clf3, = axes[1].plot(idx, cost_clf3,
                                ls='-', marker='o', ms=2, lw=1.2, color=dct_palet[clfs[2].house])
    l_cost_clf4, = axes[1].plot(idx, cost_clf4,
                                ls='-', marker='o', ms=2, lw=1.2, color=dct_palet[clfs[3].house])

    ## Initialisation of the Line2D object for the Axes[2] objects
    met1_clf1 = clfs[0].dummy_metric1()
    met1_clf2 = clfs[1].dummy_metric1()
    met1_clf3 = clfs[2].dummy_metric1()
    met1_clf4 = clfs[3].dummy_metric1()
    met2_clf1 = clfs[0].dummy_metric2()
    met2_clf2 = clfs[1].dummy_metric2()
    met2_clf3 = clfs[2].dummy_metric2()
    met2_clf4 = clfs[3].dummy_metric2()
    l_met1_clf1, = axes[2].plot(idx, met1_clf1,
                                ls='-', marker='o', ms=2, lw=1.2, color=dct_palet[clfs[0].house])
    l_met1_clf2, = axes[2].plot(idx, met1_clf2,
                                ls='-', marker='o', ms=2, lw=1.2, color=dct_palet[clfs[1].house])
    l_met1_clf3, = axes[2].plot(idx, met1_clf3,
                                ls='-', marker='o', ms=2, lw=1.2, color=dct_palet[clfs[2].house])
    l_met1_clf4, = axes[2].plot(idx, met1_clf4,
                                ls='-', marker='o', ms=2, lw=1.2, color=dct_palet[clfs[3].house])
    l_met2_clf1, = axes[2].plot(idx, met2_clf1,
                                ls='--', marker='o', ms=2, lw=1.2, color=dct_palet[clfs[0].house])
    l_met2_clf2, = axes[2].plot(idx, met2_clf2,
                                ls='--', marker='o', ms=2, lw=1.2, color=dct_palet[clfs[1].house])
    l_met2_clf3, = axes[2].plot(idx, met2_clf3,
                                ls='--', marker='o', ms=2, lw=1.2, color=dct_palet[clfs[2].house])
    l_met2_clf4, = axes[2].plot(idx, met2_clf4,
                                ls='--', marker='o', ms=2, lw=1.2, color=dct_palet[clfs[3].house])

    fig.canvas.mpl_connect('close_event', f_close)
    anim_fig = FuncAnimation(fig, f_animate, fargs=(XX_1, XX_2, XX,), frames=int(1000/fps), repeat=False, cache_frame_data = False, blit=False)
    plt.waitforbuttonpress()

    return fig


def f_animate(i, XX_1, XX_2, XX):
    """
    ... Docstring ...
    """
    global clfs, idx, \
        cost_clf1, cost_clf2, cost_clf3, cost_clf4, \
        met1_clf1, met1_clf2, met1_clf3, met1_clf4, \
        met2_clf1, met2_clf2, met2_clf3, met2_clf4, \
        boundary, axes, l_cost_clf1, l_cost_clf2, l_cost_clf3, l_cost_clf4, \
        l_met1_clf1, l_met1_clf2, l_met1_clf3, l_met1_clf4, \
        l_met2_clf1, l_met2_clf2, l_met2_clf3, l_met2_clf4

    n_cycle = 100
    clfs[0].fit(n_cycle)
    clfs[1].fit(n_cycle)
    clfs[2].fit(n_cycle)
    clfs[3].fit(n_cycle)
    
    idx = np.concatenate((idx, np.array([i * n_cycle])))

    preds = one_vs_all_prediction(clfs, XX)
    Z = preds.reshape(XX_1.shape)

    cost_clf1 = np.concatenate((cost_clf1, clfs[0].cost()))
    cost_clf2 = np.concatenate((cost_clf2, clfs[1].cost()))
    cost_clf3 = np.concatenate((cost_clf3, clfs[2].cost()))
    cost_clf4 = np.concatenate((cost_clf4, clfs[3].cost()))

    tmp_met1_clf1 = clfs[0].dummy_metric1()
    tmp_met1_clf2 = clfs[1].dummy_metric1()
    tmp_met1_clf3 = clfs[2].dummy_metric1()
    tmp_met1_clf4 = clfs[3].dummy_metric1()
    tmp_met2_clf1 = clfs[0].dummy_metric2()
    tmp_met2_clf2 = clfs[1].dummy_metric2()
    tmp_met2_clf3 = clfs[2].dummy_metric2()
    tmp_met2_clf4 = clfs[3].dummy_metric2()

    met1_clf1 = np.concatenate((met1_clf1, tmp_met1_clf1))
    met1_clf2 = np.concatenate((met1_clf2, tmp_met1_clf2))
    met1_clf3 = np.concatenate((met1_clf3, tmp_met1_clf3))
    met1_clf4 = np.concatenate((met1_clf4, tmp_met1_clf4))
    met2_clf1 = np.concatenate((met2_clf1, tmp_met2_clf1))
    met2_clf2 = np.concatenate((met2_clf2, tmp_met2_clf2))
    met2_clf3 = np.concatenate((met2_clf3, tmp_met2_clf3))
    met2_clf4 = np.concatenate((met2_clf4, tmp_met2_clf4))

    # -- Plot the contour and training examples -- #

    # Update the plot objects: remove the previous collections to save memory.
    #l = len(boundary.collections)
    for coll in boundary.collections:
    # Remove the existing contours
        boundary.collections.remove(coll)

    boundary = axes[0].contourf(XX_1, XX_2, Z, 3, colors=["red", "green", "goldenrod", "dodgerblue"], alpha=0.5)

    l_cost_clf1.set_data(idx, cost_clf1)
    l_cost_clf2.set_data(idx, cost_clf2)
    l_cost_clf3.set_data(idx, cost_clf3)
    l_cost_clf4.set_data(idx, cost_clf4)

    l_met1_clf1.set_data(idx, met1_clf1)
    l_met1_clf2.set_data(idx, met1_clf2)
    l_met1_clf3.set_data(idx, met1_clf3)
    l_met1_clf4.set_data(idx, met1_clf4)
    l_met2_clf1.set_data(idx, met2_clf1)
    l_met2_clf2.set_data(idx, met2_clf2)
    l_met2_clf3.set_data(idx, met2_clf3)
    l_met2_clf4.set_data(idx, met2_clf4)

    return boundary.collections, l_cost_clf1, l_cost_clf2, l_cost_clf3, l_cost_clf4, \
        l_met1_clf1, l_met1_clf2, l_met1_clf3, l_met1_clf4, \
            l_met2_clf1, l_met2_clf2, l_met2_clf3, l_met2_clf4


def f_close(event):
    """ Functions called when the graphical window is closed.
    It prints the last value of the theta vector and the last value of the
    cost function.
    """
    plt.close()

class DummyBinary():
    def __init__(self, house, theta0, theta1, alpha=1e-3):
        self.house = house
        self.theta0 = theta0
        self.theta1 = theta1
        self.alpha = alpha
        if self.house == "C1":
            self.border_x = 6
            self.border_y = 6
        if self.house == "C2":
            self.border_x = 6
            self.border_y = 13
        if self.house == "C3":
            self.border_x = 13
            self.border_y = 6
        if self.house == "C4":
            self.border_x = 13
            self.border_y = 13
    

    def fit(self, n_cycle:int):
        for _ in range(n_cycle):
            self.theta0 = self.theta0 + self.alpha * (self.border_x - self.theta0)
            self.theta1 = self.theta1 + self.alpha * (self.border_y - self.theta1)
    

    def cost(self) -> float:
        cost = (self.theta0 - self.border_x)**2 + (self.theta1 - self.border_y)**2
        return cost
    

    def predict(self, X:np.array) -> np.array:
        if self.house == 'C1':
            mask = (X[:,0] < self.theta0) & (X[:,1] < self.theta1)
        if self.house == 'C2':
            mask = (X[:,0] < self.theta0) & (X[:,1] > self.theta1)
        if self.house == 'C3':
            mask = (X[:,0] > self.theta0) & (X[:,1] < self.theta1)
        if self.house == 'C4':
            mask = (X[:,0] > self.theta0) & (X[:,1] > self.theta1)
        pred = np.zeros((X.shape[0], 1))
        pred[mask] = int(self.house[1])
        return pred


    def dummy_metric1(self):
        return np.array([0.5 * (self.theta0 / self.border_x + self.theta1 / self.border_y)])


    def dummy_metric2(self):
        return np.array([0.5 * ((self.theta0 / self.border_x)**2 + (self.theta1 / self.border_y)**2)])

# =========================================================================== #
# _________________________________  MAIN  __________________________________ #
# =========================================================================== #

if __name__ == "__main__":
    # -- Dummy data -- #
    x1 = np.random.randn(60,1) * 2.5 + 3.5
    x2 = np.random.randn(60,1) * 2.5 + 3.5
    x3 = np.random.randn(60,1) * 2.5 + 15.5
    x4 = np.random.randn(60,1) * 2.5 + 15.5
    stud_house = 60 * ['C1'] + 60 * ['C2'] + 60 * ['C3'] + 60 * ['C4']
    c_house = [dct_palet[house] for house in stud_house]

    y1 = np.random.randn(60,1) * 2.5 + 3.5
    y2 = np.random.randn(60,1) * 2.5 + 15.5
    y3 = np.random.randn(60,1) * 2.5 + 3.5
    y4 = np.random.randn(60,1) * 2.5 + 15.5
    
    X = np.concatenate((x1, x2, x3, x4)) # shape: (240,1)
    Y = np.concatenate((y1, y2, y3, y4)) # shape: (240,1)
    data = np.concatenate((X, Y, np.array(stud_house).reshape(-1,1)), axis=1)  # shape: (240,3)

    clf1 = DummyBinary("C1", np.random.rand(1), np.random.rand(1))
    clf2 = DummyBinary("C2", np.random.rand(1), np.random.rand(1))
    clf3 = DummyBinary("C3", np.random.rand(1), np.random.rand(1))
    clf4 = DummyBinary("C4", np.random.rand(1), np.random.rand(1))
    clfs = [clf1, clf2, clf3, clf4]

    ## Visualize the raw dummy data.
    #plt.scatter(X, Y, c=c_house, s=5)
    #plt.show()

    do_animation(clfs, data)

The DummyBinary class mimics, in a simplified way, what my One-vs-All class does. You can see a lot of `global` declarations in do_animation and f_animate (the anim_visu and f_anim of the skeleton); this way the code "works", but I'm aware that something is very wrong.
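An alternative to the long `global` lists is to keep the artists and the histories as attributes of one object and pass one of its methods to FuncAnimation: attribute updates persist between frames, which is what was lost with plain local variables. The class names `TrainingAnimator` and `ToyClassifier` are hypothetical; the classifier stand-in only assumes a `fit(n_cycle)` / `cost()` interface like DummyBinary's.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend for this sketch
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation


class ToyClassifier:
    """Hypothetical stand-in exposing fit(n_cycle) and cost() like DummyBinary."""

    def __init__(self, target, alpha=1e-2):
        self.theta, self.target, self.alpha = 0.0, target, alpha

    def fit(self, n_cycle):
        for _ in range(n_cycle):
            self.theta += self.alpha * (self.target - self.theta)

    def cost(self):
        return (self.theta - self.target) ** 2


class TrainingAnimator:
    """Hypothetical holder for everything the globals used to store."""

    def __init__(self, clfs, n_cycle=100, n_frames=10):
        self.clfs, self.n_cycle = clfs, n_cycle
        self.fig, self.ax = plt.subplots()
        self.ax.set_xlim(0, n_cycle * n_frames)
        self.ax.set_ylim(0, max(c.cost() for c in clfs) * 1.05)
        self.idx = [0]
        # One history and one Line2D per classifier, instead of cost_clf1..4.
        self.costs = [[c.cost()] for c in clfs]
        self.lines = [self.ax.plot(self.idx, h, marker="o", ms=2)[0] for h in self.costs]
        # Stored as an attribute, so the animation stays referenced.
        self.anim = FuncAnimation(self.fig, self.update, frames=n_frames, repeat=False)

    def update(self, frame):
        # `frame` is supplied by FuncAnimation; the x axis uses a running counter.
        self.idx.append(self.idx[-1] + self.n_cycle)
        for clf, hist, line in zip(self.clfs, self.costs, self.lines):
            clf.fit(self.n_cycle)
            hist.append(clf.cost())  # in-place mutation: survives between frames
            line.set_data(self.idx, hist)
        return self.lines
```

The same pattern extends to the metric lines and the contour set: each becomes an attribute that `update` reads and replaces.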

Attempts:

  1. No global variables: everything was passed to f_anim via fargs, but when returning from f_anim, all the modifications made to the variables in f_anim's scope were lost (which is, obviously, the expected behavior).
  2. Moving the definition of f_anim into the body of anim_visu, to make f_anim an inner function. I'm not experienced enough and did not manage to make it work this way; it seems that variables declared in the anim_visu scope cannot be modified in the inner function.
  3. Declaring all the variables I need as global. It works to some extent, but as you can see by running the code, the PathCollections in axes[0] are not cleared/deleted (despite the loop with boundary.collections.remove(coll)), and the number of PathCollections in axes[0] seems to increase, which slows down the frame updates.
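About the third attempt: `boundary.collections` is just a Python list held by the contour set, so `boundary.collections.remove(coll)` only drops the reference from that list; the PathCollections are still in `axes[0].collections` and keep being drawn. Removing items from a list while iterating over it also skips every other item. Detaching an artist from its Axes is done with the artist's own `remove()` method. In Matplotlib 3.8 and later a contour set is itself a single Collection (and `.collections` is deprecated), so this sketch handles both cases; the helper name `remove_contour_set` is hypothetical.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend for this sketch
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.collections import Collection

# Removing from a list while iterating over it skips elements:
lst = [1, 2, 3, 4]
for item in lst:
    lst.remove(item)
print(lst)  # [2, 4]


def remove_contour_set(cs):
    """Detach a filled contour set from its Axes (hypothetical helper)."""
    if isinstance(cs, Collection):
        # Matplotlib >= 3.8: the ContourSet is one Collection.
        cs.remove()
    else:
        # Older Matplotlib: one PathCollection per filled band.
        for coll in cs.collections:
            coll.remove()  # Artist.remove() takes it off the Axes


fig, ax = plt.subplots()
xx, yy = np.meshgrid(np.linspace(0, 20, 50), np.linspace(0, 20, 50))
levels = [0.5, 1.5, 2.5, 3.5, 4.5]  # fixed levels: same number of bands every frame
colors = ["dodgerblue", "red", "green", "goldenrod"]

boundary = ax.contourf(xx, yy, 1 + (xx > 10) + 2 * (yy > 10), levels, colors=colors, alpha=0.5)
n_start = len(ax.collections)
for shift in range(5):  # stand-in for successive animation frames
    remove_contour_set(boundary)
    z = 1 + (xx > 10 - shift) + 2 * (yy > 10 - shift)
    boundary = ax.contourf(xx, yy, z, levels, colors=colors, alpha=0.5)
print(n_start, len(ax.collections))  # the count does not grow
```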

I'm looking forward to your advice (and, I hope, a solution with an explanation). Thank you for your time and neurons.

MD4
  • Does it improve if I set `axes[0].clear()` at the beginning of the animation function as shown in [this answer](https://stackoverflow.com/questions/23070305/how-can-i-make-an-animation-with-contourf/38401705#38401705)? Issue-specific questions will be answered faster and better. – r-beginners Sep 04 '21 at 13:59
  • Indeed, adding the line ```axes[0].clear()``` solves the problem of PathCollection objects accumulating in ```axes[0]```. Thank you for your answer. Do you know what the problem is in my implementation? And do you have a better implementation than mine that avoids all the global variables? – MD4 Sep 04 '21 at 14:10
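`axes[0].clear()` works because it detaches every artist from the Axes, the old contour bands included. The trade-off is that it also removes the scatter of the training points and resets the title and axis labels, so all of them have to be redrawn in every frame. A sketch of that per-frame redraw, with a hypothetical helper `redraw_boundary` and a toy decision function standing in for `one_vs_all_prediction`:

```python
import matplotlib
matplotlib.use("Agg")  # headless backend for this sketch
import matplotlib.pyplot as plt
import numpy as np

fig, ax = plt.subplots()
rng = np.random.default_rng(0)
pts = rng.uniform(0, 20, size=(40, 2))  # made-up training points
xx, yy = np.meshgrid(np.linspace(0, 20, 60), np.linspace(0, 20, 60))


def redraw_boundary(ax, threshold):
    """Redraw the whole decision-boundary panel for one frame (hypothetical helper)."""
    ax.clear()  # drops the old contour set, but also the scatter, labels and title
    z = 1 + (xx > threshold) + 2 * (yy > threshold)  # toy stand-in for the predictions
    ax.contourf(xx, yy, z, [0.5, 1.5, 2.5, 3.5, 4.5],
                colors=["dodgerblue", "red", "green", "goldenrod"], alpha=0.5)
    ax.scatter(pts[:, 0], pts[:, 1], c="k", s=8)
    ax.set_xlabel("X_1")
    ax.set_ylabel("X_2")
    ax.set_title("Decision boundary")


for threshold in (8, 9, 10):  # stand-in for successive frames
    redraw_boundary(ax, threshold)
```

Compared with removing only the contour set, this redraws more per frame, but it cannot leak artists.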
