6

I am trying to create a figure showing image "reconstruction" as a function of the number of PCs. I want to animate this to show the original image, the cumulative image (over PCs 1,...,i) and the parts that still remain to be "reconstructed". Together with that I want to show the distance between the original and reconstructed image as a function of the number of PCs.

I managed to create the figure below, which animates the scatter plot at the bottom and also the images at the top.

enter image description here

The problem is that once the animation begins the two images on the right "disappear" and I think they appear under the "Original Image"

enter image description here

This is the code I have (creation of animation frames with all 3 images and scatters, and then formation of figure):

import numpy as np
import plotly.graph_objects as go
import plotly.express as px
from plotly.subplots import make_subplots
import plotly.io as pio
from sklearn.decomposition import PCA

# Register a reusable "custom" template that only tightens the figure margins,
# then make new figures default to "simple_white" with "custom" layered on top.
pio.templates["custom"] = go.layout.Template(
    layout=go.Layout(
        margin=dict(l=20, r=20, t=40, b=0)
    )
)
pio.templates.default = "simple_white+custom"


class AnimationButtons:
    """Factories for Plotly ``updatemenus`` buttons that control a frame animation.

    Each method returns a plain ``dict`` for use in
    ``updatemenus=[dict(type="buttons", buttons=[...])]``. The methods are
    static, so they behave the same whether called on the class
    (``AnimationButtons.play()``) or on an instance.
    """

    @staticmethod
    def play_scatter(frame_duration=500, transition_duration=300):
        """Return a "Play" button suited to scatter-only animations.

        Frames are not fully redrawn (``"redraw": False``) and consecutive
        frames are tweened with a quadratic ease-in-out.

        Args:
            frame_duration: Milliseconds each frame is shown.
            transition_duration: Milliseconds spent transitioning between frames.
        """
        return dict(
            label="Play",
            method="animate",
            args=[
                None,  # None: animate through the figure's frames
                {
                    "frame": {"duration": frame_duration, "redraw": False},
                    "fromcurrent": True,
                    "transition": {"duration": transition_duration, "easing": "quadratic-in-out"},
                },
            ],
        )

    @staticmethod
    def play(frame_duration=1000, transition_duration=0):
        """Return a "Play" button that fully redraws every frame.

        Unlike ``play_scatter``, each frame is redrawn (``"redraw": True``),
        which trace types without animated transitions (such as images) need.
        Frames switch in ``"immediate"`` mode with a linear transition.

        Args:
            frame_duration: Milliseconds each frame is shown.
            transition_duration: Milliseconds spent transitioning between frames.
        """
        return dict(
            label="Play",
            method="animate",
            args=[
                None,  # None: animate through the figure's frames
                {
                    "frame": {"duration": frame_duration, "redraw": True},
                    "mode": "immediate",
                    "fromcurrent": True,
                    "transition": {"duration": transition_duration, "easing": "linear"},
                },
            ],
        )

    @staticmethod
    def pause():
        """Return a "Pause" button.

        Animating to ``[None]`` with zero durations in ``"immediate"`` mode
        interrupts a running animation (the usual Plotly pause idiom).
        """
        return dict(
            label="Pause",
            method="animate",
            args=[
                [None],
                {"frame": {"duration": 0, "redraw": False}, "mode": "immediate", "transition": {"duration": 0}},
            ],
        )

# Fit PCA on the flattened images (one row per image).
pca = PCA(n_components=15).fit(X.reshape((X.shape[0], -1)))
# Components reshaped back to image shape. NOTE(review): `pcs` is not used below.
pcs = pca.components_.reshape((-1, X.shape[1], X.shape[2]))

# NOTE(review): PCA.transform expects (n_samples, n_features); reshape(-1, 1) passes
# the image as one column instead of one row (reshape(1, -1)) — confirm this is intended.
img, loadings = X[1], pca.transform(X[1].reshape(-1, 1)).T


reconstructed, distortion, frames = np.zeros_like(X[0]), [], []
for i in range(len(pca.components_)):
    # Reconstruct image using the first i+1 principal components (i is 0-based)
    reconstructed += loadings[i].reshape(img.shape) * pca.components_[i].reshape(img.shape)
    # Squared-error distance between the original and the current reconstruction
    distortion.append(np.sum((img - reconstructed) ** 2))

    # Append an animation frame every other reconstruction (even i), plus the last one
    if i % 2 == 0 or i == pca.n_components_-1:
        # NOTE(review): each px.imshow(...).data[0] trace carries xaxis='x'/yaxis='y',
        # so applying a frame moves all three images onto the first subplot's axes;
        # this is why the two right-hand images "disappear".
        # NOTE(review): the trace order here is [original, original - reconstructed,
        # reconstructed], but the subplot titles below are ["Original Image",
        # "Reconstructed Image", "Remaining Reconstruction"]: panels 2 and 3 look swapped.
        frames.append(go.Frame(
            data = [px.imshow(img, binary_string=True).data[0],
                    px.imshow((img - reconstructed).copy(), binary_string=True).data[0],
                    px.imshow(reconstructed.copy(), binary_string=True).data[0],
                    go.Scatter(x=list(range(1, len(distortion)+1)), y=distortion)],
            traces = [0, 1, 2, 3],
            layout = go.Layout(title=rf"$\text{{ Image Reconstruction - Number of PCs: {i+1} }}$")))


# 2x3 grid: three image panels on top, one distortion plot spanning the bottom row.
fig = make_subplots(rows=2, cols=3, 
                    subplot_titles=["Original Image", "Reconstructed Image", "Remaining Reconstruction", "Distortion Level"],
                    specs=[[{}, {}, {}], [{"colspan": 3}, None, None]], row_heights=[500, 200],)
# Seed the figure with the first frame's traces; rows/cols place each one in its subplot.
fig.add_traces(data=frames[0]["data"], rows = [1,1,1,2], cols = [1,2,3,1])
fig.update(frames=frames)

# xaxis4/yaxis4 belong to the bottom (4th) subplot; fixed ranges keep the distortion
# plot from rescaling between frames.
# NOTE(review): the x range is hard-coded to 50, but only pca.n_components_ (15) points are plotted.
fig.update_layout(title=frames[0]["layout"]["title"],
                  xaxis4=dict(range=[0, 50], autorange=False),
                  yaxis4=dict(range=[0, max(distortion)+1], autorange=False),
                  margin = dict(t = 100),
                  width=800,
                  updatemenus=[dict(type="buttons", buttons=[AnimationButtons.play(), AnimationButtons.pause()])])
fig.show()

I tried finding similar questions but wasn't able to find anything that would work for the showing of both px.imshow and go.Scatter with subplots and animation.

The data X are MNIST digit images after centering. Here is a numpy array built from one such image (X.shape=(16,5,5): 16 identical copies of the same 5x5 crop; the animation uses only one of them, X[1]):

# Sample data: 16 identical copies of one centred 5x5 image crop, shape (16, 5, 5).
# NOTE(review): because every sample is the same, the fitted PCA sees zero variance
# after mean-centring; use distinct images for a meaningful reconstruction demo.
X = np.tile(
    np.array([
        [0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00],
        [0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, -1.04166667e-06],
        [0.00000000e+00, 0.00000000e+00, -4.16666667e-06, -2.73437500e-06, -2.71484375e-05],
        [0.00000000e+00, 0.00000000e+00, -1.26302083e-05, -2.28515625e-05, -4.69401042e-05],
        [0.00000000e+00, -2.47395833e-06, -2.03776042e-05, -5.60546875e-05, -3.15950521e-04],
    ]),
    (16, 1, 1),
)

Placed the above code in a Jupyter notebook on GitHub

Gilad Green
  • 36,708
  • 7
  • 61
  • 95
  • If you take the time to [share a sample dataset](https://stackoverflow.com/questions/63163251/pandas-how-to-easily-share-a-sample-dataframe-using-df-to-dict/63163254#63163254), or a dataset that has a resembling structure of your dataset, then I'm pretty sure you'll get the needed assistance. – vestland Feb 05 '21 at 11:33
  • @vestland - You are right! Added the data used in the snippet – Gilad Green Feb 05 '21 at 11:44
  • for me `X = np.array...` does not work, maybe missing commas? – jayveesea Feb 05 '21 at 21:01
  • maybe `np.array2string(X, separator=',')` will help. – jayveesea Feb 05 '21 at 21:15
  • I think there's still some stuff missing here, like the `AnimationButtons` and how the animation happens. Also, I'm assuming you're using `from sklearn.decomposition import PCA` but with that I needed `n_components=1`. – jayveesea Feb 07 '21 at 12:55
  • I'm getting `ValueError: n_components=50 must be between 0 and min(n_samples, n_features)=1 with svd_solver='full'` ...but if I switch to `n_components=1` I get errors further down. Using scikit-learn-0.24.1. – jayveesea Feb 08 '21 at 13:05
  • @jayveesea - ya.. you are right. The problem was that the fitting is on a list of images but the animation, where the actual problem is, is on a single image. Added data for enough images such that there are more than one component -> so there should be some animation – Gilad Green Feb 08 '21 at 13:57
  • @GiladGreen Your code does not reproduce any images on my end. I would gladly have taken a closer look at this if it did. – vestland Feb 10 '21 at 12:25
  • @vestland - If I take the code above and run it (tried it now just to make sure) then I get what is seen above, with the one difference that the image is not the entire 28*28 image showing the entire zero digit but only the top left 5*5 part of the image (so that I can paste all the data `X` in the question above) Once I run it then I indeed see the animation problem described in the 2 pictures above – Gilad Green Feb 10 '21 at 14:24
  • @GiladGreen I'm only getting a red line for the bottom figure. No images. Not sure why though. – vestland Feb 10 '21 at 14:28
  • 1
    @vestland - here is a naive Git repository with the code I'm running for the example above: https://github.com/GreenGilad/Stackoverflow.git – Gilad Green Feb 10 '21 at 14:32
  • @GiladGreen OK! We'll just have to use the answer section for a bit of communication and screenshots. Can't guarantee that it will turn into an actual answer though... – vestland Feb 10 '21 at 14:40
  • @vestland - Cool!! :) Updated the github to have images of like 16*16 instead of only 5*5 – Gilad Green Feb 10 '21 at 14:44

4 Answers

3

Similar to what jayveesea suggested, I ended up playing with the structure of the px.imshow figure. I first created the px.imshow figure with both facets and animation, and then added both the scatter plot and the desired layout to it.

pca = PCA(n_components=50).fit(X.reshape((X.shape[0], -1)))
pcs = pca.components_.reshape((-1, X.shape[1], X.shape[2]))

# NOTE(review): X[150] and n_components=50 assume the larger data set from the linked
# repository; the 16-image sample X shown in the question has no index 150.
img, loadings = X[150], pca.transform(X[150].reshape(-1, 1)).T

# Per-frame collections: image triplets for the facets, distortion scatters, and titles.
reconstructed, distortion, images, scatters, titles = np.zeros_like(X[0]), [], [], [], []
for i in range(len(pca.components_)):
    # Reconstruct image using the first i+1 principal components (i is 0-based)
    reconstructed += loadings[i].reshape(img.shape) * pca.components_[i].reshape(img.shape)
    distortion.append(np.sum((img - reconstructed) ** 2))

    # Append animation frame every other reconstruction, plus the last one
    if i % 2 == 0 or i == pca.n_components_-1:
        # Same order as the subplot titles: original, reconstructed, remainder.
        images.append([img.copy(), reconstructed.copy(), (img - reconstructed).copy()])
        # xaxis="x4"/yaxis="y4" routes the scatter to the bottom (4th) subplot.
        scatters.append(go.Scatter(x=list(range(1, len(distortion)+1)), y=distortion, name=3, xaxis="x4", yaxis="y4", marker_color="black"))
        titles.append(rf"$\text{{ Image Reconstruction - Number of PCs: {i+1} }}$")



# Create figure on the basis of the animated facetted imshow figure.
# np.array(images) has shape (n_frames, 3, H, W): axis 0 drives the animation,
# axis 1 the three facet columns.
fig = px.imshow(np.array(images), facet_col=1, animation_frame=0, binary_string=True)
# Append the matching distortion scatter (4th trace, index 3) and title to every frame.
for i, (scatter, title) in enumerate(zip(*[scatters, titles])):
    fig["frames"][i]["data"] += (scatter, )
    fig["frames"][i]["traces"] = [0,1,2,3]
    fig["frames"][i]["layout"]["title"] = title 
# The static figure also needs the scatter trace; take it from the first frame.
fig.add_traces(data=fig["frames"][0]["data"][-1])

# Create "template" figure to transfer layout onto the `fig` figure
layout = make_subplots(rows=2, cols=3, 
                       subplot_titles=["Original Image", "Reconstructed Image", "Remaining Reconstruction", "Distortion Level"],
                       specs=[[{"type":"Image"}, {"type":"Image"}, {"type":"Image"}], [{"type":"xy","colspan": 3}, None, None]], row_heights=[500, 200],)

layout.update_layout(title=titles[0],
                     xaxis4=dict(range=[0, 50], autorange=False),
                     yaxis4=dict(range=[0, max(distortion)+1], autorange=False),
                     margin = dict(t = 100), width=800,
                     updatemenus=[dict(type="buttons", buttons=[AnimationButtons.play(), AnimationButtons.pause()])])

# Replace fig's layout wholesale with the template's; fig keeps its own data and frames.
fig["layout"] = layout["layout"]
fig

It is not a very elegant solution but it is a sufficient workaround.

enter image description here

Gilad Green
  • 36,708
  • 7
  • 61
  • 95
  • @jayveesea now the next challenge is to manage to export this as an animated gif haha – Gilad Green Feb 17 '21 at 15:35
  • @GiladGreen I absolutely agree with jayveesa. This turned into a great post all over. I'm sorry that my feeble attempt to contribute ended... peculiarly =) – vestland Feb 17 '21 at 21:32
2

While this is not a complete solution it may help get there...

Using animation_frame and facet_col you can build the upper part of the figure using facets. Unfortunately I'm not sure how to link this to an animated scatter plot. You could create scatter images and then tie them into this, but then you lose the ability to hover in the scatter and get info.

But, this may be of some value if you inspect the output print(fig0), and compare it to yours print(fig).

# X = see above

pca = PCA(n_components=15).fit(X.reshape((X.shape[0], -1)))
pcs = pca.components_.reshape((-1, X.shape[1], X.shape[2]))

img, loadings = X[1], pca.transform(X[1].reshape(-1, 1)).T
# One snapshot per animation frame of the original image, the part still to be
# reconstructed, and the reconstruction so far.
ORIG, DIFF, RECN = [], [], []

reconstructed, distortion = np.zeros_like(X[0]), []
for i in range(len(pca.components_)):
    # Reconstruct image using the first i+1 principal components (i is 0-based)
    reconstructed += loadings[i].reshape(img.shape) * pca.components_[i].reshape(img.shape)
    distortion.append(np.sum((img - reconstructed) ** 2))

    # Keep a snapshot every other reconstruction (even i), plus the final one
    if i % 2 == 0 or i == pca.n_components_ - 1:
        ORIG.append(img.copy())
        DIFF.append(img - reconstructed)
        # `reconstructed` is updated in place, so store a copy of its current state
        RECN.append(reconstructed.copy())

# Shape (3, n_frames, height, width): facet_col=0 splits the three panels and
# animation_frame=1 steps through the snapshots. Stacking the collected arrays
# works for any image size and frame count instead of assuming (8, 16, 16).
DATA = np.array([ORIG, DIFF, RECN])

fig0 = px.imshow(DATA, animation_frame=1, facet_col=0, binary_string=True)
fig0.show()

print(fig0) #inspect the layout

enter image description here

jayveesea
  • 2,886
  • 13
  • 25
1

Answer in progress...


This is what your github code produces on my end:

Initially...

enter image description here

And after animation ends:

enter image description here

vestland
  • 55,229
  • 37
  • 187
  • 305
  • Oh wow this is very strange.. hm.. maybe something with versions of different packages.. Just added on GitHub the conda environment yml I am using – Gilad Green Feb 10 '21 at 14:45
  • Did you see the env file I uploaded? – Gilad Green Feb 11 '21 at 08:46
  • @GiladGreen Yes, I did. Not sure what I'm missing. Are you absolutely sure that the exact code and exact data you've provided does not produce the same thing that my image is showing? Sometimes things can get lost when you try to put together a minimal reproducible example... – vestland Feb 11 '21 at 08:50
  • ya.. to verify I even downloaded it once more from Git and ran it with the environment file – Gilad Green Feb 11 '21 at 08:54
  • @GiladGreen OK, then it seems I'm stuck here. Sorry about that. – vestland Feb 11 '21 at 09:41
  • It's okay :) in any case thanks for trying!! – Gilad Green Feb 11 '21 at 10:44
  • just fyi, I was able to reproduce on colab with one addition: at the very beginning `!pip install plotly==4.14.3` (may need to reset kernel too). – jayveesea Feb 11 '21 at 12:10
  • @GiladGreen What version were you on again? – vestland Feb 11 '21 at 12:17
  • 1
    @vestland plotly 4.14.3 – Gilad Green Feb 11 '21 at 15:19
  • @GiladGreen This really is a strange one. I just updated to `4.14.3`, and still no images. *But*, now the frames for the images disappear during animation. That did not happen for my earlier plotly installation. I'll dig a bit more... – vestland Feb 11 '21 at 20:58
  • 1
    @vestland - at the end I went with some workaround using the `px.imshow` as my animated "template". Not a very elegant solution but at least it works.. :) – Gilad Green Feb 17 '21 at 15:08
  • 1
    @GiladGreen Glad to see you found something that worked for you! – vestland Feb 17 '21 at 15:10
0

TL;DR: px.imshow automatically sets the x and y axes to be the first set of axes, since it returns a whole figure; you either need to use go.Image (last codeblock) or remove the xaxis and yaxis attributes of the data.

One of the neat things about plotly is that every object is a dict at heart, which you can see when you print the object to the terminal. When you do px.imshow, the resulting figure has the following info:

In [1]: imfig = px.imshow(image, binary_string=True)

In [2]: print(imfig)
Figure({
    'data': [{'hovertemplate': 'x: %{x}<br>y: %{y}<extra></extra>',
              'name': '0',
              'source': ('' ... 'cPdQY3ZVUt1YkAAAAASUVORK5CYII='),
              'type': 'image',
              'xaxis': 'x',
              'yaxis': 'y'}],
    'layout': {'template': '...',
               'xaxis': {'anchor': 'y', 'domain': [0.0, 1.0]},
               'yaxis': {'anchor': 'x', 'domain': [0.0, 1.0]}}
})

Aha! In imfig.data[0], the image has xaxis and yaxis specified! This will always put the image in the first subplot. (With make_subplots, the first subplot uses axes ('x','y') and the others use ('x2','y2'), ('x3','y3'), etc. by default.) Indeed, in the frames list, every image has 'xaxis':'x' and 'yaxis':'y', meaning that each frame will place the images in the first subplot.

Knowing that, the fix is easy: just set the x and y axis to the proper values! The new loop would look like this:

for i in range(len(pca.components_)):
    # Reconstruct image using the first i+1 principal components (i is 0-based)
    reconstructed += loadings[i].reshape(img.shape) * pca.components_[i].reshape(img.shape)
    distortion.append(np.sum((img - reconstructed) ** 2))

    # Append an animation frame every other reconstruction, plus the final one
    if i % 2 == 0 or i == pca.n_components_ - 1:
        # Same order as the subplot titles: original, reconstructed, remainder
        imagedata = [px.imshow(img, binary_string=True).data[0],
                     px.imshow(reconstructed.copy(), binary_string=True).data[0],
                     px.imshow((img - reconstructed).copy(), binary_string=True).data[0]]
        # px.imshow pins every trace to xaxis='x'/yaxis='y' (the first subplot);
        # point each image at the axes of its own subplot instead ('x', 'x2', 'x3').
        for image, suffix in zip(imagedata, ['', '2', '3']):
            image.xaxis = 'x' + suffix
            image.yaxis = 'y' + suffix
        frames.append(go.Frame(
            data=imagedata +
                 [go.Scatter(x=list(range(1, len(distortion) + 1)), y=distortion)],
            traces=[0, 1, 2, 3],
            layout=go.Layout(title=rf"$\text{{ Image Reconstruction - Number of PCs: {i+1} }}$")))

In fact, there's an even easier way. When plotly updates the figure using the frame, it matches each data element with the corresponding data in the current figure and then updates the data accordingly. This means if you don't specify a parameter in a frame.data dict, it will just use the default from the initial layout, which, in the case of xaxis and yaxis, are actually set by add_traces (since you specify the rows and columns, the xaxis and yaxis properties of the frames[0].data objects are overwritten). So you can achieve the same effect without having to specify the location by just removing xaxis and yaxis altogether:

for i in range(len(pca.components_)):
    # Reconstruct image using the first i+1 principal components (i is 0-based)
    reconstructed += loadings[i].reshape(img.shape) * pca.components_[i].reshape(img.shape)
    distortion.append(np.sum((img - reconstructed) ** 2))

    # Append an animation frame every other reconstruction, plus the final one
    if i % 2 == 0 or i == pca.n_components_ - 1:
        # Same order as the subplot titles: original, reconstructed, remainder
        imagedata = [px.imshow(img, binary_string=True).data[0],
                     px.imshow(reconstructed.copy(), binary_string=True).data[0],
                     px.imshow((img - reconstructed).copy(), binary_string=True).data[0]]
        # Drop px.imshow's xaxis='x'/yaxis='y' so each frame keeps the axes that
        # add_traces(rows=..., cols=...) gave the figure's existing traces.
        for image in imagedata:
            image.xaxis = None  # fields with value None are treated as
            image.yaxis = None  # deleted when converted to dict/json
        frames.append(go.Frame(
            data=imagedata +
                 [go.Scatter(x=list(range(1, len(distortion) + 1)), y=distortion)],
            traces=[0, 1, 2, 3],
            layout=go.Layout(title=rf"$\text{{ Image Reconstruction - Number of PCs: {i+1} }}$")))

Finally, if you want to go all the way into the weeds and skip the px.imshow altogether, you can specify the image data yourself, and change only the source. I think the minimum data you need to show the image is 'type':'image' and 'source':[source] - or, in other words, just go.Image(source=source); from the images tutorial:

The source attribute of a go.layout.Image can be the URL of an image, or a PIL Image object (from PIL import Image; img = Image.open('filename.png'))

So, the simplest solution is:

import base64
import io

from PIL import Image


def _to_data_uri(array):
    """Encode a 2-D numeric *array* as a greyscale PNG data URI for ``go.Image(source=...)``.

    The image trace wants ``source`` as a raw data-URI string (like the one
    ``px.imshow(..., binary_string=True)`` produces), so the PNG is encoded here
    explicitly. ``Image.fromarray`` on a float array gives a mode "F" image that
    PNG cannot store, so values are min-max scaled to uint8 first; a constant
    array maps to all zeros instead of dividing by zero.
    """
    values = np.asarray(array, dtype=float)
    low, high = values.min(), values.max()
    scale = 255.0 / (high - low) if high > low else 0.0
    pixels = np.rint((values - low) * scale).astype(np.uint8)
    buffer = io.BytesIO()
    Image.fromarray(pixels).save(buffer, format="PNG")
    return "data:image/png;base64," + base64.b64encode(buffer.getvalue()).decode("ascii")


img, loadings = X[1], pca.transform(X[1].reshape(-1, 1)).T

reconstructed, distortion, frames = np.zeros_like(X[0]), [], []
for i in range(len(pca.components_)):
    # Reconstruct image using the first i+1 principal components (i is 0-based)
    reconstructed += loadings[i].reshape(img.shape) * pca.components_[i].reshape(img.shape)
    distortion.append(np.sum((img - reconstructed) ** 2))

    # Append an animation frame every other reconstruction, plus the final one
    if i % 2 == 0 or i == pca.n_components_ - 1:
        frames.append(go.Frame(
            # Same order as the subplot titles: original, reconstructed, remainder.
            # `source` is passed by keyword; a positional argument would be taken
            # as the trace's property dict.
            data=[go.Image(source=_to_data_uri(img)),
                  go.Image(source=_to_data_uri(reconstructed)),
                  go.Image(source=_to_data_uri(img - reconstructed)),
                  go.Scatter(x=list(range(1, len(distortion) + 1)), y=distortion)],
            traces=[0, 1, 2, 3],
            layout=go.Layout(title=rf"$\text{{ Image Reconstruction - Number of PCs: {i+1} }}$")))


fig = make_subplots(rows=2, cols=3,
                    subplot_titles=["Original Image", "Reconstructed Image", "Remaining Reconstruction", "Distortion Level"],
                    specs=[[{}, {}, {}], [{"colspan": 3}, None, None]], row_heights=[500, 200],)
# Seed the figure with the first frame's traces; rows/cols place each one in its subplot.
fig.add_traces(data=frames[0]["data"], rows=[1, 1, 1, 2], cols=[1, 2, 3, 1])
fig.update(frames=frames)

# xaxis4/yaxis4 belong to the bottom distortion plot; fixed ranges stop it rescaling
# between frames. The x range follows the number of components actually plotted.
fig.update_layout(title=frames[0]["layout"]["title"],
                  xaxis4=dict(range=[0, pca.n_components_ + 1], autorange=False),
                  yaxis4=dict(range=[0, max(distortion) + 1], autorange=False),
                  margin=dict(t=100),
                  width=800,
                  updatemenus=[dict(type="buttons", buttons=[AnimationButtons.play(), AnimationButtons.pause()])])
fig.show()

--EDIT--

For some reason, it seems plotly has either removed or broken PIL image interoperability for go.Image; I can't get any of the examples in the documentation or online to work. The px.imshow way works, though, so I think the most reliable approach is to take only the source from the px.imshow image data:

# Take only the encoded image `source` string from px.imshow's trace ...
img_data = px.imshow(image,binary_string=True).data[0].source
# ... and wrap it in a bare go.Image, so px.imshow's xaxis='x'/yaxis='y' are not
# carried into the frame and the updated trace keeps its subplot axes.
image = go.Image(source=img_data)
frame.data = [image]
minerharry
  • 101
  • 8
  • Hold on, the PIL thing seems to be breaking on .fromarray? currently investigating – minerharry Jun 02 '23 at 19:55
  • plotly doesn't seem to like it unless it's a raw string like output by px.imshow. not sure why. has worked for other people. trimming imshow still works – minerharry Jun 02 '23 at 20:03