I am trying to create a figure showing image "reconstruction" as function of number of PCs. I want to animate this to show the original image, the cumulative image (over PCs 1,...,i) and the parts that still remain to be "reconstructed". Together with that I want to show the distance between the original and reconstructed image as a function of the number of PCs.
I managed to create the figure below, which animates the scatter plot at the bottom and also the images at the top.
The problem is that once the animation begins the two images on the right "disappear" and I think they appear under the "Original Image"
This is the code I have (creation of animation frames with all 3 images and scatters, and then formation of figure):
import numpy as np
import plotly.graph_objects as go
import plotly.express as px
from plotly.subplots import make_subplots
import plotly.io as pio
from sklearn.decomposition import PCA
pio.templates["custom"] = go.layout.Template(
layout=go.Layout(
margin=dict(l=20, r=20, t=40, b=0)
)
)
pio.templates.default = "simple_white+custom"
class AnimationButtons():
def play_scatter(frame_duration = 500, transition_duration = 300):
return dict(label="Play", method="animate", args=
[None, {"frame": {"duration": frame_duration, "redraw": False},
"fromcurrent": True, "transition": {"duration": transition_duration, "easing": "quadratic-in-out"}}])
def play(frame_duration = 1000, transition_duration = 0):
return dict(label="Play", method="animate", args=
[None, {"frame": {"duration": frame_duration, "redraw": True},
"mode":"immediate",
"fromcurrent": True, "transition": {"duration": transition_duration, "easing": "linear"}}])
def pause():
return dict(label="Pause", method="animate", args=
[[None], {"frame": {"duration": 0, "redraw": False}, "mode": "immediate", "transition": {"duration": 0}}])
pca = PCA(n_components=15).fit(X.reshape((X.shape[0], -1)))
pcs = pca.components_.reshape((-1, X.shape[1], X.shape[2]))
img, loadings = X[1], pca.transform(X[1].reshape(-1, 1)).T
reconstructed, distortion, frames = np.zeros_like(X[0]), [], []
for i in range(len(pca.components_)):
# Reconstruct image using the first i principal components
reconstructed += loadings[i].reshape(img.shape) * pca.components_[i].reshape(img.shape)
distortion.append(np.sum((img - reconstructed) ** 2))
# Append animation frame every 5'th reconstruction
if i % 2 == 0 or i == pca.n_components_-1:
frames.append(go.Frame(
data = [px.imshow(img, binary_string=True).data[0],
px.imshow((img - reconstructed).copy(), binary_string=True).data[0],
px.imshow(reconstructed.copy(), binary_string=True).data[0],
go.Scatter(x=list(range(1, len(distortion)+1)), y=distortion)],
traces = [0, 1, 2, 3],
layout = go.Layout(title=rf"$\text{{ Image Reconstruction - Number of PCs: {i+1} }}$")))
fig = make_subplots(rows=2, cols=3,
subplot_titles=["Original Image", "Reconstructed Image", "Remaining Reconstruction", "Distortion Level"],
specs=[[{}, {}, {}], [{"colspan": 3}, None, None]], row_heights=[500, 200],)
fig.add_traces(data=frames[0]["data"], rows = [1,1,1,2], cols = [1,2,3,1])
fig.update(frames=frames)
fig.update_layout(title=frames[0]["layout"]["title"],
xaxis4=dict(range=[0, 50], autorange=False),
yaxis4=dict(range=[0, max(distortion)+1], autorange=False),
margin = dict(t = 100),
width=800,
updatemenus=[dict(type="buttons", buttons=[AnimationButtons.play(), AnimationButtons.pause()])])
fig.show()
I tried finding similar questions but wasn't able to find anything that would work for the showing of both px.imshow
and go.Scatter
with subplots and animation.
The data X
are the MNIST digits images after centering. Here is a numpy array with one such image: (X.shape=(16,5,5)
- 16 images of 5x5 - animation only on first image)
X=np.array( [[[ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
0.00000000e+00],
[ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
-1.04166667e-06],
[ 0.00000000e+00, 0.00000000e+00,-4.16666667e-06,-2.73437500e-06,
-2.71484375e-05],
[ 0.00000000e+00, 0.00000000e+00,-1.26302083e-05,-2.28515625e-05,
-4.69401042e-05],
[ 0.00000000e+00,-2.47395833e-06,-2.03776042e-05,-5.60546875e-05,
-3.15950521e-04]],
[[ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
0.00000000e+00],
[ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
-1.04166667e-06],
[ 0.00000000e+00, 0.00000000e+00,-4.16666667e-06,-2.73437500e-06,
-2.71484375e-05],
[ 0.00000000e+00, 0.00000000e+00,-1.26302083e-05,-2.28515625e-05,
-4.69401042e-05],
[ 0.00000000e+00,-2.47395833e-06,-2.03776042e-05,-5.60546875e-05,
-3.15950521e-04]],
[[ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
0.00000000e+00],
[ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
-1.04166667e-06],
[ 0.00000000e+00, 0.00000000e+00,-4.16666667e-06,-2.73437500e-06,
-2.71484375e-05],
[ 0.00000000e+00, 0.00000000e+00,-1.26302083e-05,-2.28515625e-05,
-4.69401042e-05],
[ 0.00000000e+00,-2.47395833e-06,-2.03776042e-05,-5.60546875e-05,
-3.15950521e-04]],
[[ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
0.00000000e+00],
[ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
-1.04166667e-06],
[ 0.00000000e+00, 0.00000000e+00,-4.16666667e-06,-2.73437500e-06,
-2.71484375e-05],
[ 0.00000000e+00, 0.00000000e+00,-1.26302083e-05,-2.28515625e-05,
-4.69401042e-05],
[ 0.00000000e+00,-2.47395833e-06,-2.03776042e-05,-5.60546875e-05,
-3.15950521e-04]],
[[ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
0.00000000e+00],
[ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
-1.04166667e-06],
[ 0.00000000e+00, 0.00000000e+00,-4.16666667e-06,-2.73437500e-06,
-2.71484375e-05],
[ 0.00000000e+00, 0.00000000e+00,-1.26302083e-05,-2.28515625e-05,
-4.69401042e-05],
[ 0.00000000e+00,-2.47395833e-06,-2.03776042e-05,-5.60546875e-05,
-3.15950521e-04]],
[[ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
0.00000000e+00],
[ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
-1.04166667e-06],
[ 0.00000000e+00, 0.00000000e+00,-4.16666667e-06,-2.73437500e-06,
-2.71484375e-05],
[ 0.00000000e+00, 0.00000000e+00,-1.26302083e-05,-2.28515625e-05,
-4.69401042e-05],
[ 0.00000000e+00,-2.47395833e-06,-2.03776042e-05,-5.60546875e-05,
-3.15950521e-04]],
[[ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
0.00000000e+00],
[ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
-1.04166667e-06],
[ 0.00000000e+00, 0.00000000e+00,-4.16666667e-06,-2.73437500e-06,
-2.71484375e-05],
[ 0.00000000e+00, 0.00000000e+00,-1.26302083e-05,-2.28515625e-05,
-4.69401042e-05],
[ 0.00000000e+00,-2.47395833e-06,-2.03776042e-05,-5.60546875e-05,
-3.15950521e-04]],
[[ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
0.00000000e+00],
[ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
-1.04166667e-06],
[ 0.00000000e+00, 0.00000000e+00,-4.16666667e-06,-2.73437500e-06,
-2.71484375e-05],
[ 0.00000000e+00, 0.00000000e+00,-1.26302083e-05,-2.28515625e-05,
-4.69401042e-05],
[ 0.00000000e+00,-2.47395833e-06,-2.03776042e-05,-5.60546875e-05,
-3.15950521e-04]],
[[ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
0.00000000e+00],
[ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
-1.04166667e-06],
[ 0.00000000e+00, 0.00000000e+00,-4.16666667e-06,-2.73437500e-06,
-2.71484375e-05],
[ 0.00000000e+00, 0.00000000e+00,-1.26302083e-05,-2.28515625e-05,
-4.69401042e-05],
[ 0.00000000e+00,-2.47395833e-06,-2.03776042e-05,-5.60546875e-05,
-3.15950521e-04]],
[[ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
0.00000000e+00],
[ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
-1.04166667e-06],
[ 0.00000000e+00, 0.00000000e+00,-4.16666667e-06,-2.73437500e-06,
-2.71484375e-05],
[ 0.00000000e+00, 0.00000000e+00,-1.26302083e-05,-2.28515625e-05,
-4.69401042e-05],
[ 0.00000000e+00,-2.47395833e-06,-2.03776042e-05,-5.60546875e-05,
-3.15950521e-04]],
[[ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
0.00000000e+00],
[ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
-1.04166667e-06],
[ 0.00000000e+00, 0.00000000e+00,-4.16666667e-06,-2.73437500e-06,
-2.71484375e-05],
[ 0.00000000e+00, 0.00000000e+00,-1.26302083e-05,-2.28515625e-05,
-4.69401042e-05],
[ 0.00000000e+00,-2.47395833e-06,-2.03776042e-05,-5.60546875e-05,
-3.15950521e-04]],
[[ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
0.00000000e+00],
[ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
-1.04166667e-06],
[ 0.00000000e+00, 0.00000000e+00,-4.16666667e-06,-2.73437500e-06,
-2.71484375e-05],
[ 0.00000000e+00, 0.00000000e+00,-1.26302083e-05,-2.28515625e-05,
-4.69401042e-05],
[ 0.00000000e+00,-2.47395833e-06,-2.03776042e-05,-5.60546875e-05,
-3.15950521e-04]],
[[ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
0.00000000e+00],
[ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
-1.04166667e-06],
[ 0.00000000e+00, 0.00000000e+00,-4.16666667e-06,-2.73437500e-06,
-2.71484375e-05],
[ 0.00000000e+00, 0.00000000e+00,-1.26302083e-05,-2.28515625e-05,
-4.69401042e-05],
[ 0.00000000e+00,-2.47395833e-06,-2.03776042e-05,-5.60546875e-05,
-3.15950521e-04]],
[[ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
0.00000000e+00],
[ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
-1.04166667e-06],
[ 0.00000000e+00, 0.00000000e+00,-4.16666667e-06,-2.73437500e-06,
-2.71484375e-05],
[ 0.00000000e+00, 0.00000000e+00,-1.26302083e-05,-2.28515625e-05,
-4.69401042e-05],
[ 0.00000000e+00,-2.47395833e-06,-2.03776042e-05,-5.60546875e-05,
-3.15950521e-04]],
[[ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
0.00000000e+00],
[ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
-1.04166667e-06],
[ 0.00000000e+00, 0.00000000e+00,-4.16666667e-06,-2.73437500e-06,
-2.71484375e-05],
[ 0.00000000e+00, 0.00000000e+00,-1.26302083e-05,-2.28515625e-05,
-4.69401042e-05],
[ 0.00000000e+00,-2.47395833e-06,-2.03776042e-05,-5.60546875e-05,
-3.15950521e-04]],
[[ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
0.00000000e+00],
[ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
-1.04166667e-06],
[ 0.00000000e+00, 0.00000000e+00,-4.16666667e-06,-2.73437500e-06,
-2.71484375e-05],
[ 0.00000000e+00, 0.00000000e+00,-1.26302083e-05,-2.28515625e-05,
-4.69401042e-05],
[ 0.00000000e+00,-2.47395833e-06,-2.03776042e-05,-5.60546875e-05,
-3.15950521e-04]]] )
Placed the above code in a Jupyter notebook on GitHub