Book a Demo!
CoCalc Logo Icon
Store | Features | Docs | Share | Support | News | About | Policies | Sign Up | Sign In
labmlai
GitHub Repository: labmlai/annotated_deep_learning_paper_implementations
Path: blob/master/labml_nn/cfr/analytics.py
4918 views
1
import os
from typing import List

import altair as alt
import numpy as np

from labml import analytics
from labml.analytics import IndicatorCollection
8
9
10
def calculate_percentages(means: List[np.ndarray], names: List[List[str]]) -> List[np.ndarray]:
    """Normalize each series by the total of its sibling series.

    Two series are siblings when the last element of their names is equal
    after dropping its final character (presumably the per-action series of
    one information set -- confirm against the indicator naming). Every series
    is divided element-wise by the sum of all of its siblings (itself
    included), so siblings with a non-zero total add up to ~1 at every step.

    Args:
        means: one array of values per series; siblings must share a shape.
        names: the name of each series, parallel to ``means``; only
            ``name[-1]`` is used.

    Returns:
        A new list with one normalized array per series, in input order.
        The input arrays are not modified.

    Raises:
        ValueError: if ``means`` and ``names`` have different lengths.
    """
    if len(means) != len(names):
        raise ValueError(
            f"means and names must have the same length, "
            f"got {len(means)} and {len(names)}")

    # eps keeps an all-zero group from dividing by zero (the result is 0).
    eps = np.finfo(float).eps

    # Sum each group once instead of rescanning every name for every series.
    # Members are accumulated in input order, so the totals are the same
    # values a per-series rescan would produce. A fresh zeros array is used
    # as the accumulator so the caller's arrays are never modified in place.
    totals = {}
    for mean, name in zip(means, names):
        key = name[-1][:-1]
        if key not in totals:
            totals[key] = np.zeros_like(mean)
        totals[key] += mean

    return [mean / (totals[name[-1][:-1]] + eps)
            for mean, name in zip(means, names)]
21
22
23
def plot_infosets(indicators: IndicatorCollection, *,
                  is_normalize: bool = True,
                  width: int = 600,
                  height: int = 300):
    """Plot every series of an indicator collection as a line in an Altair chart.

    Series data comes from ``analytics.indicator_data(indicators)``. Column 0
    of each data array is used as the x-axis step. Column 5 is used as the
    value (named ``means`` here -- presumably the per-step mean; TODO confirm
    against ``labml.analytics.indicator_data``).

    Args:
        indicators: the indicators to load and plot.
        is_normalize: when true, each series is divided by the sum of all
            series whose last name element matches it except for the final
            character (via ``calculate_percentages``), so sibling series are
            shown as fractions of their group total.
        width: chart width, passed to ``Chart.properties``.
        height: chart height, passed to ``Chart.properties``.

    Returns:
        An ``alt.Chart`` whose legend toggles the visibility of each series.
    """
    data, names = analytics.indicator_data(indicators)
    steps = [d[:, 0] for d in data]
    means = [d[:, 5] for d in data]

    values = calculate_percentages(means, names) if is_normalize else means

    # Drop the prefix shared by every series name so the legend only shows
    # the part that distinguishes the series. With a single series, that
    # prefix would be the whole name and leave an empty label, so keep it.
    series_names = [n[-1] for n in names]
    common = os.path.commonprefix(series_names) if len(series_names) > 1 else ''

    # tolist() turns NumPy scalars into plain Python numbers, which serialize
    # to JSON whatever the array dtype (e.g. np.int64 is not JSON-serializable).
    table = [
        {
            'series': series_name[len(common):],
            'step': step,
            'value': value,
        }
        for series_name, series_steps, series_values in zip(series_names, steps, values)
        for step, value in zip(series_steps.tolist(), series_values.tolist())
    ]

    table = alt.Data(values=table)

    # NOTE(review): alt.selection_multi / add_selection are the Altair 4 API;
    # Altair 5 deprecates them in favour of alt.selection_point / add_params.
    # Kept as-is because the installed Altair version is not visible here.
    selection = alt.selection_multi(fields=['series'], bind='legend')

    # Deselected series stay in the chart but are drawn almost transparent.
    return alt.Chart(table).mark_line().encode(
        alt.X('step:Q'),
        alt.Y('value:Q'),
        alt.Color('series:N', scale=alt.Scale(scheme='tableau20')),
        opacity=alt.condition(selection, alt.value(1), alt.value(0.0001))
    ).add_selection(
        selection
    ).properties(width=width, height=height)
67
68