Unfortunately, it is not possible to create live plots in a Google Colab notebook using %matplotlib notebook, as it is in an offline Jupyter notebook on my PC.

I found two similar questions explaining how to achieve this for plotly plots (link_1, link_2). However, I cannot manage to adapt that approach to matplotlib, and I do not know whether it is possible at all.

I am following code from this tutorial: GitHub link. In particular, I would like to run this code, which creates a callback that plots the reward per step over the training steps:

import os

import matplotlib.pyplot as plt
import numpy as np
# Assumed stable-baselines 2.x import paths (the tutorial uses PPO2).
from stable_baselines import PPO2
from stable_baselines.common.callbacks import BaseCallback
from stable_baselines.common.cmd_util import make_vec_env
from stable_baselines.results_plotter import load_results, ts2xy
%matplotlib notebook


class PlottingCallback(BaseCallback):
    """
    Callback for plotting the performance in realtime.

    :param verbose: (int)
    """
    def __init__(self, verbose=1):
        super(PlottingCallback, self).__init__(verbose)
        self._plot = None

    def _on_step(self) -> bool:
        # get the monitor's data
        x, y = ts2xy(load_results(log_dir), 'timesteps')
        if self._plot is None:  # make the plot
            plt.ion()
            fig = plt.figure(figsize=(6, 3))
            ax = fig.add_subplot(111)
            line, = ax.plot(x, y)
            self._plot = (line, ax, fig)
            plt.show()
        else:  # update and rescale the plot
            line, ax, fig = self._plot
            line.set_data(x, y)
            ax.relim()
            ax.set_xlim([self.locals["total_timesteps"] * -0.02,
                         self.locals["total_timesteps"] * 1.02])
            ax.autoscale_view(True, True, True)
            fig.canvas.draw()
        # Returning False would stop training early, so keep going.
        return True


# Create log dir
log_dir = "/tmp/gym/"
os.makedirs(log_dir, exist_ok=True)

# Create and wrap the environment
env = make_vec_env('MountainCarContinuous-v0', n_envs=1, monitor_dir=log_dir)

plotting_callback = PlottingCallback()

model = PPO2('MlpPolicy', env, verbose=0)
model.learn(20000, callback=plotting_callback)
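For reference, `ts2xy(load_results(log_dir), 'timesteps')` is what turns the Monitor log into plot data. As far as I know, x is the cumulative number of timesteps at the end of each finished episode and y is that episode's reward. The stdlib-only sketch below mimics that conversion, assuming the Monitor CSV layout of one `#{...}` JSON comment line followed by an `r,l,t` header; `monitor_to_xy` is a hypothetical helper, not part of stable-baselines.

```python
import csv
import io
from itertools import accumulate


def monitor_to_xy(csv_text):
    """Hypothetical helper: mimic ts2xy(load_results(...), 'timesteps').

    Assumes the Monitor layout: an optional '#{...}' JSON comment line,
    then an 'r,l,t' header (episode reward, episode length, wall time).
    """
    lines = csv_text.splitlines()
    # Skip the leading JSON metadata comment written by the Monitor wrapper.
    if lines and lines[0].startswith("#"):
        lines = lines[1:]
    rows = list(csv.DictReader(io.StringIO("\n".join(lines))))
    # Order episodes by completion time (assumed to match load_results).
    rows.sort(key=lambda row: float(row["t"]))
    lengths = [int(row["l"]) for row in rows]
    rewards = [float(row["r"]) for row in rows]
    # x: total timesteps elapsed when each episode ended.
    x = list(accumulate(lengths))
    return x, rewards


sample = (
    '#{"t_start": 0.0, "env_id": "MountainCarContinuous-v0"}\n'
    "r,l,t\n"
    "-12.5,999,3.1\n"
    "93.2,310,4.0\n"
)
print(monitor_to_xy(sample))  # ([999, 1309], [-12.5, 93.2])
```

Under that assumption, each point of the live plot corresponds to one completed episode, so the curve only gains a point when an episode ends.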
Philipp

1 Answer

One hack you can use is to run the same code you would use in a Jupyter notebook, create a button, and use JavaScript to click that button repeatedly. This fools the frontend into requesting an update, so the displayed values keep updating.

Here is an example that uses ipywidgets:

import asyncio

import ipywidgets
from IPython.display import display

progress = ipywidgets.FloatProgress(value=0.0, min=0.0, max=1.0)


async def work(progress):
    # Advance the progress bar in the background, one step every 0.2 s.
    total = 100
    for i in range(total):
        await asyncio.sleep(0.2)
        progress.value = float(i + 1) / total


display(progress)
# The kernel's event loop is already running, so this returns immediately
# and work() keeps running in the background.
asyncio.get_event_loop().create_task(work(progress))

# The button only exists so that clicking it makes the frontend ask the
# server for the updated widget state.
button = ipywidgets.Button(description="This button does nothing... except send a "
                           "socket request to google servers to receive updated "
                           "information since the frontend wants to change..")

display(button, ipywidgets.HTML(
    value="""<script>
    var b = setInterval(a => {
      // Hopefully this is the first button
      document.querySelector('#output-body button')?.click()},
      1000);
    // Stops clicking the button after 1 minute
    setTimeout(c => clearInterval(b), 1000*60*1);
    </script>"""
))
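In Colab, as in Jupyter, the kernel's event loop is presumably already running, which is why `create_task` schedules `work` in the background and the cell returns right away. To try the same pattern in a plain Python script, where no loop is running yet, a sketch can start the loop with `asyncio.run` and poll the value the way the frontend would; `FakeProgress` is a hypothetical stand-in for the `FloatProgress` widget, and the shorter sleep is only there so the demo finishes quickly.

```python
import asyncio


class FakeProgress:
    """Hypothetical stand-in for ipywidgets.FloatProgress (only .value)."""

    def __init__(self):
        self.value = 0.0


async def work(progress, total=5, delay=0.01):
    # Same loop as the answer, with fewer steps and a shorter sleep.
    for i in range(total):
        await asyncio.sleep(delay)
        progress.value = float(i + 1) / total


async def main():
    progress = FakeProgress()
    # create_task schedules work() on the running loop and returns at once,
    # like the notebook cell finishing while the bar keeps moving.
    task = asyncio.create_task(work(progress))
    seen = []
    while not task.done():
        seen.append(progress.value)  # the "frontend" polling for updates
        await asyncio.sleep(0.005)
    await task
    return progress.value, seen


final, seen = asyncio.run(main())
print(final)  # 1.0
```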

Dealing specifically with matplotlib is a bit more complicated. I thought I could simply call matplotlib's plotting functions inside the asyncio function, but that slows the updates down considerably, apparently because it does unnecessary rendering in the background where no one sees the plot. So another workaround is to update the plot in the button's click handler. This code is also inspired by Add points to matlibplot scatter plot live and Matplotlib graphic image to base64. The idea is that there is no need to create a new figure for every plot; you can just modify the figure you already have. This, of course, means more code.
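The base64 part of this idea boils down to embedding the PNG bytes directly in an `<img>` tag as a data URI, so the HTML widget needs no file or URL. A stdlib-only sketch, where `png_bytes_to_img_tag` is a hypothetical name; in the notebook the bytes would come from `fig.savefig(buf, format="png")` into an `io.BytesIO` buffer, and here the 8-byte PNG signature stands in for a real image.

```python
import base64


def png_bytes_to_img_tag(png_bytes):
    """Hypothetical helper: wrap PNG bytes in an <img> tag via a data URI."""
    # b64encode never inserts newlines, so no extra cleanup is needed.
    encoded = base64.b64encode(png_bytes).decode("ascii")
    return '<img align="left" src="data:image/png;base64,%s">' % encoded


tag = png_bytes_to_img_tag(b"\x89PNG\r\n\x1a\n")
print(tag)
```

Assigning such a string to an `ipywidgets.HTML` widget's `value` replaces the displayed image in place, which is what lets the answer reuse one figure instead of creating a new one per update.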

import asyncio
import base64
import io

import ipywidgets
import matplotlib.pyplot as plt
import numpy as np
from IPython.display import display


def pltToImg(fig):
    # Render the figure to PNG in memory and embed it as a base64 data URI.
    s = io.BytesIO()
    fig.savefig(s, format='png', bbox_inches="tight")
    s = base64.b64encode(s.getvalue()).decode("utf-8")
    return '<img align="left" src="data:image/png;base64,%s">' % s


progress = ipywidgets.FloatProgress(value=0.0, min=0.0, max=1.0)


async def work(progress):
    # Advance the progress bar in the background, one step every 0.5 s.
    total = 100
    for i in range(total):
        await asyncio.sleep(0.5)
        progress.value = float(i + 1) / total


display(progress)
asyncio.get_event_loop().create_task(work(progress))

button = ipywidgets.Button(description="Update =D")
a = ipywidgets.HTML(value="image here")

# Build the figure once. plt.close() keeps it from being shown inline,
# but the Figure object can still be rendered with savefig().
plt.ion()
fig, ax = plt.subplots()
plot = ax.scatter([], [])
point = np.random.normal(0, 1, 2)
array = plot.get_offsets()
array = np.append(array, [point], axis=0)
plot.set_offsets(array)
plt.close()
ii = 0


def on_button_clicked(b):
    # Add one point to the existing scatter plot and re-render it as an image.
    global ii
    ii += 1
    point = np.r_[ii, np.random.normal(0, 1, 1)]
    array = plot.get_offsets()
    array = np.append(array, [point], axis=0)
    plot.set_offsets(array)
    ax.set_xlim(array[:, 0].min() - 0.5, array[:, 0].max() + 0.5)
    ax.set_ylim(array[:, 1].min() - 0.5, array[:, 1].max() + 0.5)
    # Assign the widget value once so only one update is sent per click.
    a.value = pltToImg(fig) + str(progress.value) + "<br>" + str(ii)


button.on_click(on_button_clicked)
display(button, ipywidgets.HTML(
    value="""<script>
    var b = setInterval(a => {
      // Hopefully this is the first button
      document.querySelector('#output-body button')?.click()},
      500);
    // Stops clicking the button after 1 minute
    setTimeout(c => clearInterval(b), 1000*60*1);
    </script>"""
), a)
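The injected script is the same in both snippets apart from its timing (1000 ms vs. 500 ms between clicks, stopping after one minute). If you reuse the hack, a small helper that builds the `<script>` string from those two numbers may be convenient. `autoclick_script` and its parameters are hypothetical; the `#output-body button` selector is copied from the answer and depends on Colab's output DOM, which may change.

```python
def autoclick_script(interval_ms=500, duration_ms=60_000,
                     selector="#output-body button"):
    """Hypothetical helper: build the answer's auto-click <script> snippet.

    Clicks the first element matching `selector` every `interval_ms`
    milliseconds and stops after `duration_ms` milliseconds. The selector
    must not contain single quotes, since it is embedded in a JS string.
    """
    if interval_ms <= 0 or duration_ms <= 0:
        raise ValueError("interval_ms and duration_ms must be positive")
    return (
        "<script>\n"
        "var timer = setInterval(() => {\n"
        f"  document.querySelector('{selector}')?.click();\n"
        f"}}, {interval_ms});\n"
        f"setTimeout(() => clearInterval(timer), {duration_ms});\n"
        "</script>"
    )


print(autoclick_script(500, 60_000))
```

In the notebook this would be passed as `ipywidgets.HTML(value=autoclick_script(500, 60_000))` in place of the hand-written string.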
Rainb
I suppose you could use something other than matplotlib to draw the plots; that would be nice, wouldn't it? https://colab.research.google.com/drive/1t_wcE-NqoPO-dpnrB9VMQ0KUxR5e1rML?usp=sharing – Rainb Jan 31 '22 at 18:20