# Path: blob/master/labml_nn/cfr/analytics.py
# 4918 views
import os
from typing import List

import altair as alt
import numpy as np

from labml import analytics
from labml.analytics import IndicatorCollection


def calculate_percentages(means: List[np.ndarray],
                          names: List[List[str]]) -> List[np.ndarray]:
    """Normalize each series by the summed values of its group.

    Two series belong to the same group when the last component of their
    names is identical once its final character is dropped.

    Args:
        means: One array of values per series.
        names: Name components of each series, parallel to ``means``.

    Returns:
        For every series, ``means[i]`` divided element-wise by the total of
        its group. Machine epsilon is added to the total so that an all-zero
        group does not divide by zero.

    Raises:
        ValueError: If ``means`` and ``names`` have different lengths.
    """
    if len(means) != len(names):
        raise ValueError(
            f"means and names must have the same length, "
            f"got {len(means)} and {len(names)}"
        )

    # Group key: last name component without its final character.
    group_keys = [name[-1][:-1] for name in names]

    # Accumulate each group's total once instead of re-scanning every series
    # for every series. Starting from a fresh zero array keeps the inputs
    # from being modified by the in-place addition.
    totals = {}
    for key, mean in zip(group_keys, means):
        if key not in totals:
            totals[key] = np.zeros_like(mean)
        totals[key] += mean

    eps = np.finfo(float).eps
    return [mean / (totals[key] + eps)
            for key, mean in zip(group_keys, means)]


def plot_infosets(indicators: IndicatorCollection, *,
                  is_normalize: bool = True,
                  width: int = 600,
                  height: int = 300) -> alt.Chart:
    """Plot one line per indicator, with a legend that toggles series.

    Args:
        indicators: Indicators whose data is fetched with
            ``analytics.indicator_data``.
        is_normalize: When true, values are passed through
            ``calculate_percentages`` before plotting; otherwise they are
            plotted as they are.
        width: Chart width.
        height: Chart height.

    Returns:
        An Altair line chart of value against step, coloured by series.
    """
    data, names = analytics.indicator_data(indicators)
    # Column 0 is used as the step and column 5 as the plotted value.
    # NOTE(review): the original code calls column 5 the mean; confirm the
    # column layout against labml's `indicator_data`.
    steps = [d[:, 0] for d in data]
    means = [d[:, 5] for d in data]

    values = calculate_percentages(means, names) if is_normalize else means

    # Strip the prefix shared by every series name so legend labels show only
    # the part that differs. `commonprefix` compares character by character
    # and returns '' for an empty list, so an empty collection yields an
    # empty chart instead of an IndexError.
    labels = [name[-1] for name in names]
    common = os.path.commonprefix(labels)

    table = []
    for label, series_steps, series_values in zip(labels, steps, values):
        series = label[len(common):]
        # `.tolist()` turns NumPy scalars into plain Python numbers before
        # they are handed to Altair as inline records.
        for step, value in zip(series_steps.tolist(),
                               series_values.tolist()):
            table.append({
                'series': series,
                'step': step,
                'value': value
            })

    table = alt.Data(values=table)

    # The selection is bound to the legend; series that are not selected are
    # drawn with an opacity of 0.0001 rather than being removed.
    selection = alt.selection_multi(fields=['series'], bind='legend')

    return alt.Chart(table).mark_line().encode(
        alt.X('step:Q'),
        alt.Y('value:Q'),
        alt.Color('series:N', scale=alt.Scale(scheme='tableau20')),
        opacity=alt.condition(selection, alt.value(1), alt.value(0.0001))
    ).add_selection(
        selection
    ).properties(width=width, height=height)