I am training a neural network for binary classification on Google Colab. Each epoch, I evaluate it on the validation dataset and calculate the percentages of true positives, false positives, true negatives, and false negatives. I want to see a live normalized stacked area chart (see https://altair-viz.github.io/gallery/normalized_stacked_area_chart.html for an explanation of what that is) of these four numbers. It should get updated with each epoch as the training process goes on. How do I achieve this? I am ready to use any 3rd party library.
Colab is similar to other interactive environments such as Jupyter, AFAIK. Maybe [this question](https://stackoverflow.com/questions/62630668/real-time-live-graphs-in-jupyter-notebook) could help. – gavin Apr 19 '22 at 03:13
I've found [this answer](https://stackoverflow.com/a/52866695/6797569) and tested it briefly in Google Colab, and it worked. It's not super fast, but depending on your case it could help. – Rodrigo Cava Apr 22 '22 at 22:49
I assume that you are saving your target metrics after every epoch:
- `tp`: True Positive
- `tn`: True Negative
- `fp`: False Positive
- `fn`: False Negative
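As a rough sketch of how these four values could be computed as percentages after each validation pass, assuming NumPy arrays of 0/1 ground-truth labels and predicted probabilities, a decision threshold of 0.5, and a non-empty validation set (the helper name `confusion_percentages` is hypothetical, not part of any library):

```python
import numpy as np


def confusion_percentages(y_true, y_prob, threshold=0.5):
    """Return tp/tn/fp/fn as percentages of all validation samples.

    Hypothetical helper: assumes binary 0/1 labels, probabilities in [0, 1],
    and at least one sample (otherwise the division below fails).
    """
    y_true = np.asarray(y_true).astype(bool)
    # Threshold the predicted probabilities into hard True/False predictions
    y_pred = np.asarray(y_prob) >= threshold
    n = y_true.size
    counts = {
        "tp": np.sum(y_true & y_pred),
        "tn": np.sum(~y_true & ~y_pred),
        "fp": np.sum(~y_true & y_pred),
        "fn": np.sum(y_true & ~y_pred),
    }
    # Convert raw counts to percentages so the four values sum to 100
    return {key: 100.0 * float(value) / n for key, value in counts.items()}
```

Because the chart below normalizes each epoch with `groupnorm='percent'`, storing raw counts instead of percentages would produce the same picture.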
Note: Ideally, you would create a list of dictionaries, where each dictionary consists of the metrics you want and the corresponding epoch number as follows:
```python
# call this list results
results = [
    ...
    {"epoch": 0, "tp": 200, "tn": 80, "fp": 18, "fn": 5},
    ...
]
```
Use this list to create a pandas dataframe and then use the custom plotting function as shown below.
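```python
import pandas as pd

# Each dict becomes one row, so the frame already has one row per epoch
df = pd.DataFrame(results)
```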
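For illustration, here is one way the `results` list might be accumulated inside a training loop and turned into a DataFrame. The counts are made-up placeholder values (the first row mirrors the example dictionary above) standing in for whatever your validation pass produces, and the training step is only a placeholder comment. Since each dictionary becomes one row, the result has one row per epoch and needs no transpose:

```python
import pandas as pd

# Placeholder (tp, tn, fp, fn) counts for three epochs; in practice these
# would come from evaluating the model on the validation set each epoch.
per_epoch_counts = [(200, 80, 18, 5), (205, 84, 14, 4), (210, 86, 10, 3)]

results = []
for epoch, (tp, tn, fp, fn) in enumerate(per_epoch_counts):
    # ... train for one epoch here ...
    results.append({"epoch": epoch, "tp": tp, "tn": tn, "fp": fp, "fn": fn})

# One row per epoch, columns: epoch, tp, tn, fp, fn
df = pd.DataFrame(results)
print(df)
```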
The solution below uses the Plotly library to create the desired chart. Here is a Jupyter notebook (with a Google Colab link) to quickly check out the proposed solution.
Make Interactive Stacked Normalized Area Chart
```python
figure_title = "Confusion Matrix Evolution over Model Training Epochs"
columns = ["tp", "tn", "fp", "fn"]
colors = ['#d7191c', '#fdae61', '#abdda4', '#2b83ba']
palette = dict(zip(columns, colors))  # map each metric to its color

# create interactive chart with user-defined function
make_stacked_normalized_chart(
    df,
    x="epoch",
    columns=columns,
    palette=palette,
    figure_title=figure_title,
)
```
Define Custom Plotting Function
Here we define a custom function, `make_stacked_normalized_chart()`, to create an interactive stacked normalized area chart.
```python
import pandas as pd
import plotly.graph_objects as go
from typing import List, Dict


def make_stacked_normalized_chart(df: pd.DataFrame, x: str,
                                  columns: List[str],
                                  palette: Dict[str, str],
                                  figure_title: str = "Figure Title"):
    """Create a stacked normalized interactive chart with Plotly library."""
    x_label = x
    x = df[x_label]
    fig = go.Figure()

    def add_trace(column: str):
        fig.add_trace(go.Scatter(
            x=x, y=df[column],
            text=column,  # set the name shown while hovering over
            name=column,  # set the name in the legend
            # fill='toself',
            hoveron='points+fills',  # select where hover is active
            hoverinfo='text+x+y',
            mode='lines',
            line=dict(width=0.5, color=palette.get(column)),
            stackgroup='one',  # define stack group
            groupnorm='percent',  # sets the normalization for the sum of the stackgroup
        ))

    for column in columns:
        add_trace(column)

    fig.update_layout(
        title_text=figure_title,
        showlegend=True,
        xaxis=dict(
            type="category",
            title=x_label,
        ),
        yaxis=dict(
            type='linear',
            range=[0, 100],  # show the full 0-100% range of the normalized stack
            ticksuffix='%',
        ),
    )
    fig.show()
```
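To see what `groupnorm='percent'` does, the same normalization can be reproduced in pandas: each epoch's four values are divided by that epoch's total and multiplied by 100, so every stacked column reaches exactly 100%. This is only a check on the numbers the chart displays, using a small made-up frame; the plotting function does not need it:

```python
import pandas as pd

columns = ["tp", "tn", "fp", "fn"]
# Small made-up example with two epochs whose totals are 160 and 200
df = pd.DataFrame({"epoch": [0, 1], "tp": [80, 90], "tn": [40, 60],
                   "fp": [20, 30], "fn": [20, 20]})

# Divide each row by its own total, then scale to percent
percent = df[columns].div(df[columns].sum(axis=1), axis=0) * 100
percent.insert(0, "epoch", df["epoch"])
print(percent)
```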
Dummy Data
We will use the following data to demonstrate the interactive stacked normalized area chart.
```python
import numpy as np
import pandas as pd

np.random.seed(42)
nrows = 100
x = np.arange(nrows)
tp = 60 + np.arange(nrows) + np.random.randint(0, 20, nrows)
tn = 25 + np.arange(nrows) + np.random.randint(0, 20, nrows)
fp = np.random.randint(2, 6, nrows) + np.random.randint(0, 8, nrows)
fn = np.random.randint(4, 7, nrows) + np.random.randint(3, 6, nrows)
df = pd.DataFrame({"epoch": x, "tp": tp, "tn": tn, "fp": fp, "fn": fn})
```

CypherX
@CrabMan Please take a look at this and let me know if you have any questions. – CypherX Apr 23 '22 at 09:21
Hi, I've looked at your solution. I'm sorry, but it doesn't do what I want it to do. I want the plot to be updated live as the training code runs (as the epochs pass). – CrabMan Apr 25 '22 at 16:15