I'm struggling with FuncAnimation in matplotlib.animation, and I could not find out any examples or post looking similar to my problem (I mean, yes there is post concerning contourf used in funcAnimation but in those posts they succeed to delete the PathCollection object but in my case something is not working).
Context:
In a school project concerning One-vs-All notion (multiple binary classifiers), I want to implement functions to animate a figure having 3 Axes and containing multiple Line2D objects, an PathCollection object from scatter method and a QuadContourSet from contourf method.
Here a screen of how it looks like (obtained when I plot the data at the end of the training of the One-vs-All):
Representation of the static graph
Legend:
- Left: Boundary decision in (Herbology)-(Defense against Dark Arts) plane,
- Top right: Loss function of each binary classifiers,
- Bottom right: Precision and Recall metrics of each classifiers.
Methods:
I am trying to have a animated version of the plot using FuncAnimation from matplotlib.Animation module. Animated version of the plot is a bonus feature of my project, then the animation part/core is made in functions, you can see a simplification below (a bare bones representation) :
def anim_visu(models, data):
# initialization of the figure and object representing the data
...
def f_anim():
# Function which update the data at each frames
visu = FuncAnimation(fig, f_anim, ...)
return fig
[...]
if __name__ == "__main__":
[...]
if bool_dynamic: # activation of the dynamic visualization
anim_visu(models, data)
And here a minimal workish example:
# =========================================================================== #
# |Importation des lib/packages| #
# =========================================================================== #
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib.animation import FuncAnimation
from matplotlib.gridspec import GridSpec
dct_palet = {"C1":"dodgerblue",
"C2":"red",
"C3":"green",
"C4":"goldenrod"}
fps = 15
# =========================================================================== #
# | Definition des fonctions | #
# =========================================================================== #
def one_vs_all_prediction(classifiers:list, X:np.array) -> np.array:
"""
... Docstring ...
"""
preds = np.zeros((X.shape[0],1))
for clf in classifiers:
tmp = clf.predict(X)
mask = preds == 0
preds[mask] = tmp[mask]
return preds
def one_vs_all_class_onehot(class_pred:np.array):
"""
... Docstring ...
"""
house = {"C1":1., "C2":2., "C3":3., "C4":4.}
onehot_pred = np.chararray((class_pred.shape[0],1), itemsize=2)
for key, item in house.items():
mask = class_pred == item
onehot_pred[mask] = key
return onehot_pred
def do_animation(clfs:list, data:np.ndarray):
""" Core function for the animated vizualisation.
The function defines all the x/y_labels, the titles.
"""
global idx, cost_clf1, cost_clf2, cost_clf3, cost_clf4, \
met1_clf1, met1_clf2, met1_clf3, met1_clf4, \
met2_clf1, met2_clf2, met2_clf3, met2_clf4, \
boundary, axes, \
l_cost_clf1, l_cost_clf2, l_cost_clf3, l_cost_clf4, \
l_met1_clf1, l_met1_clf2, l_met1_clf3, l_met1_clf4, \
l_met2_clf1, l_met2_clf2, l_met2_clf3, l_met2_clf4
plt.style.use('seaborn-pastel')
# -- Declaring the figure and the axes -- #
fig = plt.figure(figsize=(15,9.5))
gs = GridSpec(2, 2, figure=fig)
axes = [fig.add_subplot(gs[:, 0]), fig.add_subplot(gs[0, 1]), fig.add_subplot(gs[1, 1])]
# --formatting the different axes -- #
axes[0].set_xlabel("X_1")
axes[0].set_ylabel("X_2")
axes[0].set_title("Decision boundary")
axes[1].set_xlabel("i: iteration")
axes[1].set_xlim(-10, 1000)
axes[1].set_ylim(-10, 350)
axes[1].set_ylabel(r"$\mathcal{L}_{\theta_0,\theta_1}$")
axes[1].grid()
axes[2].set_xlabel("i: iteration")
axes[2].set_ylabel("Scores (metric_1 & metric_2)")
axes[2].set_xlim(-10, 1000)
axes[2].set_ylim(0.0,1.01)
axes[2].grid()
# -- Reading min and max values along X dimensions-- #
X = data[:,0:2]
X = X.astype(np.float64)
Y = data[:,2].reshape(-1,1)
idx = np.array([0])
X_min, X_max = X[:,:2].min(axis=0), X[:,:2].max(axis=0)
# -- Generate a grid of points with distance h between them -- #
h = 0.01
XX_1, XX_2 = np.meshgrid(np.arange(X_min[0], X_max[0], h),
np.arange(X_min[1], X_max[1], h))
zeros_arr = np.zeros((XX_1.shape[0] * XX_1.shape[1], 1))
XX = np.c_[XX_1.ravel(), XX_2.ravel(),
zeros_arr.ravel(), zeros_arr.ravel(), zeros_arr.ravel()]
# -- Predict the function value for the whole grid -- #
preds = one_vs_all_prediction(clfs, XX)
Z = preds.reshape(XX_1.shape)
## Initialisation of the PathCollection for the Axes[0] objects
boundary = axes[0].contourf(XX_1, XX_2, Z, 3,
colors=["red", "green", "goldenrod", "dodgerblue"], alpha=0.5)
lst_colors = np.array([dct_palet[house] for house in data[:,2]])
raw_data = axes[0].scatter(X[:,0], X[:,1], c=lst_colors, edgecolor="k")
## Initialisation of the Line2D object for the Axes[1] objects
cost_clf1 = clfs[0].cost()
cost_clf2 = clfs[1].cost()
cost_clf3 = clfs[2].cost()
cost_clf4 = clfs[3].cost()
l_cost_clf1, = axes[1].plot(idx, cost_clf1,
ls='-', marker='o', ms=2, lw=1.2, color=dct_palet[clfs[0].house])
l_cost_clf2, = axes[1].plot(idx, cost_clf2,
ls='-', marker='o', ms=2, lw=1.2, color=dct_palet[clfs[1].house])
l_cost_clf3, = axes[1].plot(idx, cost_clf3,
ls='-', marker='o', ms=2, lw=1.2, color=dct_palet[clfs[2].house])
l_cost_clf4, = axes[1].plot(idx, cost_clf4,
ls='-', marker='o', ms=2, lw=1.2, color=dct_palet[clfs[3].house])
## Initialisation of the Line2D object for the Axes[2] objects
met1_clf1 = clfs[0].dummy_metric1()
met1_clf2 = clfs[1].dummy_metric1()
met1_clf3 = clfs[2].dummy_metric1()
met1_clf4 = clfs[3].dummy_metric1()
met2_clf1 = clfs[0].dummy_metric2()
met2_clf2 = clfs[1].dummy_metric2()
met2_clf3 = clfs[2].dummy_metric2()
met2_clf4 = clfs[3].dummy_metric2()
l_met1_clf1, = axes[2].plot(idx, met1_clf1,
ls='-', marker='o', ms=2, lw=1.2, color=dct_palet[clfs[0].house])
l_met1_clf2, = axes[2].plot(idx, met1_clf2,
ls='-', marker='o', ms=2, lw=1.2, color=dct_palet[clfs[1].house])
l_met1_clf3, = axes[2].plot(idx, met1_clf3,
ls='-', marker='o', ms=2, lw=1.2, color=dct_palet[clfs[2].house])
l_met1_clf4, = axes[2].plot(idx, met1_clf4,
ls='-', marker='o', ms=2, lw=1.2, color=dct_palet[clfs[3].house])
l_met2_clf1, = axes[2].plot(idx, met2_clf1,
ls='--', marker='o', ms=2, lw=1.2, color=dct_palet[clfs[0].house])
l_met2_clf2, = axes[2].plot(idx, met2_clf2,
ls='--', marker='o', ms=2, lw=1.2, color=dct_palet[clfs[1].house])
l_met2_clf3, = axes[2].plot(idx, met2_clf3,
ls='--', marker='o', ms=2, lw=1.2, color=dct_palet[clfs[2].house])
l_met2_clf4, = axes[2].plot(idx, met2_clf4,
ls='--', marker='o', ms=2, lw=1.2, color=dct_palet[clfs[3].house])
fig.canvas.mpl_connect('close_event', f_close)
anim_fig = FuncAnimation(fig, f_animate, fargs=(XX_1, XX_2, XX,), frames=int(1000/fps), repeat=False, cache_frame_data = False, blit=False)
plt.waitforbuttonpress()
return fig
def f_animate(i, XX_1, XX_2, XX):
"""
... Docstring ...
"""
global clfs, idx, \
cost_clf1, cost_clf2, cost_clf3, cost_clf4, \
met1_clf1, met1_clf2, met1_clf3, met1_clf4, \
met2_clf1, met2_clf2, met2_clf3, met2_clf4, \
boundary, axes, l_cost_clf1, l_cost_clf2, l_cost_clf3, l_cost_clf4, \
l_met1_clf1, l_met1_clf2, l_met1_clf3, l_met1_clf4, \
l_met2_clf1, l_met2_clf2, l_met2_clf3, l_met2_clf4
n_cycle = 100
clfs[0].fit(n_cycle)
clfs[1].fit(n_cycle)
clfs[2].fit(n_cycle)
clfs[3].fit(n_cycle)
idx = np.concatenate((idx, np.array([i * n_cycle])))
preds = one_vs_all_prediction(clfs, XX)
Z = preds.reshape(XX_1.shape)
cost_clf1 = np.concatenate((cost_clf1, clfs[0].cost()))
cost_clf2 = np.concatenate((cost_clf2, clfs[1].cost()))
cost_clf3 = np.concatenate((cost_clf3, clfs[2].cost()))
cost_clf4 = np.concatenate((cost_clf4, clfs[3].cost()))
tmp_met1_clf1 = clfs[0].dummy_metric1()
tmp_met1_clf2 = clfs[1].dummy_metric1()
tmp_met1_clf3 = clfs[2].dummy_metric1()
tmp_met1_clf4 = clfs[3].dummy_metric1()
tmp_met2_clf1 = clfs[0].dummy_metric2()
tmp_met2_clf2 = clfs[1].dummy_metric2()
tmp_met2_clf3 = clfs[2].dummy_metric2()
tmp_met2_clf4 = clfs[3].dummy_metric2()
met1_clf1 = np.concatenate((met1_clf1, tmp_met1_clf1))
met1_clf2 = np.concatenate((met1_clf2, tmp_met1_clf2))
met1_clf3 = np.concatenate((met1_clf3, tmp_met1_clf3))
met1_clf4 = np.concatenate((met1_clf4, tmp_met1_clf4))
met2_clf1 = np.concatenate((met2_clf1, tmp_met2_clf1))
met2_clf2 = np.concatenate((met2_clf2, tmp_met2_clf2))
met2_clf3 = np.concatenate((met2_clf3, tmp_met2_clf3))
met2_clf4 = np.concatenate((met2_clf4, tmp_met2_clf4))
# -- Plot the contour and training examples -- #
# Update the plot objects: remove the previous collections to save memory.
#l = len(boundary.collections)
for coll in boundary.collections:
# Remove the existing contours
boundary.collections.remove(coll)
boundary = axes[0].contourf(XX_1, XX_2, Z, 3, colors=["red", "green", "goldenrod", "dodgerblue"], alpha=0.5)
l_cost_clf1.set_data(idx, cost_clf1)
l_cost_clf2.set_data(idx, cost_clf2)
l_cost_clf3.set_data(idx, cost_clf3)
l_cost_clf4.set_data(idx, cost_clf4)
l_met1_clf1.set_data(idx, met1_clf1)
l_met1_clf2.set_data(idx, met1_clf2)
l_met1_clf3.set_data(idx, met1_clf3)
l_met1_clf4.set_data(idx, met1_clf4)
l_met2_clf1.set_data(idx, met2_clf1)
l_met2_clf2.set_data(idx, met2_clf2)
l_met2_clf3.set_data(idx, met2_clf3)
l_met2_clf4.set_data(idx, met2_clf4)
return boundary.collections, l_cost_clf1, l_cost_clf2, l_cost_clf3, l_cost_clf4, \
l_met1_clf1, l_met1_clf2, l_met1_clf3, l_met1_clf4, \
l_met2_clf1, l_met2_clf2, l_met2_clf3, l_met2_clf4
def f_close(event):
""" Functions called when the graphical window is closed.
It prints the last value of the theta vector and the last value of the
cost function.
"""
plt.close()
class DummyBinary():
def __init__(self, house, theta0, theta1, alpha=1e-3):
self.house = house
self.theta0 = theta0
self.theta1 = theta1
self.alpha = alpha
if self.house == "C1":
self.border_x = 6
self.border_y = 6
if self.house == "C2":
self.border_x = 6
self.border_y = 13
if self.house == "C3":
self.border_x = 13
self.border_y = 6
if self.house == "C4":
self.border_x = 13
self.border_y = 13
def fit(self, n_cycle:int):
for _ in range(n_cycle):
self.theta0 = self.theta0 + self.alpha * (self.border_x - self.theta0)
self.theta1 = self.theta1 + self.alpha * (self.border_y - self.theta1)
def cost(self) -> float:
cost = (self.theta0 - self.border_x)**2 + (self.theta1 - self.border_y)**2
return cost
def predict(self, X:np.array) -> np.array:
if self.house == 'C1':
mask = (X[:,0] < self.theta0) & (X[:,1] < self.theta1)
if self.house == 'C2':
mask = (X[:,0] < self.theta0) & (X[:,1] > self.theta1)
if self.house == 'C3':
mask = (X[:,0] > self.theta0) & (X[:,1] < self.theta1)
if self.house == 'C4':
mask = (X[:,0] > self.theta0) & (X[:,1] > self.theta1)
pred =np.zeros((X.shape[0], 1))
pred[mask] = int(self.house[1])
return pred
def dummy_metric1(self):
return np.array([0.5 * (self.theta0 / self.border_x + self.theta1 / self.border_y)])
def dummy_metric2(self):
return np.array([0.5 * ((self.theta0 / self.border_x)**2 + (self.theta1 / self.border_y)**2)])
# =========================================================================== #
# _________________________________ MAIN __________________________________ #
# =========================================================================== #
if __name__ == "__main__":
# -- Dummy data -- #
x1 = np.random.randn(60,1) * 2.5 + 3.5
x2 = np.random.randn(60,1) * 2.5 + 3.5
x3 = np.random.randn(60,1) * 2.5 + 15.5
x4 = np.random.randn(60,1) * 2.5 + 15.5
stud_house = 60 * ['C1'] + 60 * ['C2'] + 60 * ['C3'] + 60 * ['C4']
c_house = [dct_palet[house] for house in stud_house]
y1 = np.random.randn(60,1) * 2.5 + 3.5
y2 = np.random.randn(60,1) * 2.5 + 15.5
y3 = np.random.randn(60,1) * 2.5 + 3.5
y4 = np.random.randn(60,1) * 2.5 + 15.5
X = np.concatenate((x1, x2, x3, x4)) # shape: (240,1)
Y = np.concatenate((y1, y2, y3, y4)) # shape: (240,1)
data = np.concatenate((X, Y, np.array(stud_house).reshape(-1,1)), axis=1) # shape: (240,3)
clf1 = DummyBinary("C1", np.random.rand(1), np.random.rand(1))
clf2 = DummyBinary("C2", np.random.rand(1), np.random.rand(1))
clf3 = DummyBinary("C3", np.random.rand(1), np.random.rand(1))
clf4 = DummyBinary("C4", np.random.rand(1), np.random.rand(1))
clfs = [clf1, clf2, clf3, clf4]
## Visualize the raw dummy data.
#plt.scatter(X, Y, c=c_house, s=5)
#plt.show()
do_animation(clfs, data)
The Class DummyBinary mimics in a simplified way, what my One-vs-All class can do. You can see a bunch of global in anim_visu and f_anim, in this way the code "works", but I'm aware there is something very wrong.
Attempts:
- No global variables, everything were passed to
f_anim
viafargs
, but when returning fromf_anim
, all the modification of the variables inf_anim
scope were lost (normal behavior obviously), - Moving the definition of the
f_anim
within the body ofanim_visu
, to makef_anim
an inner_function. I'm not experienced enough, so I did not succeed to make it works this way, I noticed that It may appeared it is not possible to modify variable declared in theanim_visu
scope in the inner function. - Declare all the variables I need as global, it work in a way, but as you can see by running the code (in the axes[0]), the PathCollections are not cleared/deleted (despite the loop with
boundary.collections.remove(coll)
) and the number of PathCollection in the axes[0] seems to increased, leading to a drop of the speed the frames are updated.
Looking forward for your advice (and solution+explanation I hope). And thank you for your times and neurons.