I am having trouble plotting a grid of images in Dash. In particular, I want to create a grid of n_rows x n_cols, where n_cols can be anywhere between 5 and 10 and n_rows needs to be user-specified. All images are the same size.
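To fill a grid like this from a flat list of images, you need the usual row-major arithmetic: image k (0-based) goes to 1-based subplot row k // n_cols + 1 and column k % n_cols + 1. Those are the `row`/`col` values that `fig.add_trace(..., row=..., col=...)` expects after `make_subplots`. A minimal sketch follows. The helper names `grid_position` and `grid_shape` are hypothetical and are not part of Plotly.

```python
def grid_position(k, n_cols):
    """Map image index k (0-based, row-major) to a 1-based (row, col) pair."""
    return k // n_cols + 1, k % n_cols + 1


def grid_shape(n_images, n_cols):
    """Number of rows and columns needed; the last row may be partly empty."""
    n_rows = -(-n_images // n_cols)  # ceiling division without math.ceil
    return n_rows, n_cols


# Typical use with make_subplots (shown as a comment, not executed here):
#   rows, cols = grid_shape(len(images), n_cols)
#   fig = make_subplots(rows=rows, cols=cols)
#   for k, img in enumerate(images):
#       r, c = grid_position(k, n_cols)
#       fig.add_trace(go.Image(z=img), row=r, col=c)
```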
So far, I have used px.imshow() with the facet_col option. However, with a large number of images, each image becomes very small.
I have also used plotly.subplots.make_subplots() directly. However, when I change n_rows within the app, the images do not stay in their respective positions and get moved around. For example, if I first plot 5 images stacked vertically, the figure looks fine: the five images are drawn close together. If I then choose to plot 2 images, they are drawn in the same space the five images occupied. Conversely, if I first plot 2 images, the figure looks fine, but if I then plot 5 images, the layout is completely messed up.
See the following code to reproduce:
```python
from dash import Dash, dcc, html
from dash.dependencies import Input, Output
import dash.exceptions
import numpy as np
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots


def create_fig(n_rows):
    # One column of n_rows subplots, one random 60x60 RGB image in each
    fig = make_subplots(rows=n_rows, cols=1)
    for i in range(n_rows):
        fig.add_trace(go.Image(z=np.random.randint(0, 255, size=(60, 60, 3))),
                      row=i + 1, col=1)
    fig.update_layout(autosize=True,
                      width=100,
                      height=n_rows * 100,
                      margin=dict(l=0, r=0, b=0, t=0))
    return fig


def create_fig_using_pximshow(n_rows):
    # Same layout built with px.imshow facets: one facet per image, one per row
    images = np.random.randint(0, 255, size=(n_rows, 60, 60, 3))
    fig = px.imshow(images, facet_col=0, facet_col_wrap=1,
                    facet_row_spacing=0.3 / n_rows)
    fig.update_layout(autosize=True,
                      width=100,
                      height=n_rows * 100,
                      margin=dict(l=0, r=0, b=0, t=0))
    return fig


app = Dash(__name__)
app.layout = html.Div(
    [dcc.Graph(id='graph', style={'overflow': 'scroll'}),
     dcc.Input(value=0, id='input', type='number')])


@app.callback(Output('graph', 'figure'), Input('input', 'value'))
def create_graph(n_rows):
    # Do not update until the user has entered a positive number of rows
    if n_rows is None or n_rows < 1:
        raise dash.exceptions.PreventUpdate("Prevent update")
    return create_fig(int(n_rows))
    # return create_fig_using_pximshow(int(n_rows))


if __name__ == '__main__':
    app.run_server(debug=True)
```
Note that changing go.Image to go.Scatter makes the code work, and the graph then adapts to the number of rows. Is there a better or easier way to plot a grid of images?