Book a Demo!
CoCalc Logo Icon
Store · Features · Docs · Share · Support · News · About · Policies · Sign Up · Sign In
labmlai
GitHub Repository: labmlai/annotated_deep_learning_paper_implementations
Path: blob/master/labml_nn/helpers/schedule.py
4918 views
1
from typing import List, Optional, Tuple
2
3
4
class Schedule:
    """
    ## Schedule base class

    A schedule is a callable that maps an input `x` to a value.
    Concrete schedules override `__call__`; calling this base class
    directly raises `NotImplementedError`.
    """

    def __call__(self, x):
        # No default behaviour: subclasses must define the mapping
        raise NotImplementedError
7
8
9
class Flat(Schedule):
    """
    ## Flat schedule

    Returns the value given at construction time for every `x`.
    """

    def __init__(self, value):
        # Stored under a name-mangled private attribute (`_Flat__value`)
        self.__value = value

    def __call__(self, x):
        # `x` is deliberately ignored: the schedule is constant
        constant = self.__value
        return constant

    def __str__(self):
        return "Schedule({})".format(self.__value)
18
19
20
class Dynamic(Schedule):
    """
    ## Dynamic schedule

    Returns whatever value it currently holds, for every `x`.
    The value is set at construction and can be replaced with `update`.
    """

    def __init__(self, value):
        # Stored under a name-mangled private attribute (`_Dynamic__value`)
        self.__value = value

    def __call__(self, x):
        # `x` is ignored; only the most recently stored value matters
        current = self.__value
        return current

    def update(self, value):
        """Replace the value returned by subsequent calls."""
        self.__value = value

    def __str__(self):
        label = "Dynamic"
        return label
32
33
34
class Piecewise(Schedule):
    """
    ## Piecewise schedule

    A piecewise-linear schedule defined by a list of `(x, y)` endpoints
    sorted by `x`.
    """

    def __init__(self, endpoints: List[Tuple[float, float]], outside_value: Optional[float] = None):
        """
        ### Initialize

        `endpoints` is a list of `(x, y)` pairs, sorted by `x`.
        Values between consecutive endpoints are linearly interpolated.
        For `x` outside the closed range `[x_first, x_last]` covered by the
        endpoints, the schedule returns `outside_value`.
        """

        # Materialise once: the input is scanned here and stored below, and
        # copying stops later mutation of the caller's list from silently
        # changing this schedule.
        endpoints = list(endpoints)

        # `(x, y)` pairs should be sorted by `x`
        indexes = [e[0] for e in endpoints]
        assert indexes == sorted(indexes), "endpoints must be sorted by x"

        self._outside_value = outside_value
        self._endpoints = endpoints

    def __call__(self, x):
        """
        ### Find `y` for given `x`
        """

        # Iterate through each segment. Segments are half-open `[x1, x2)`, so
        # a repeated `x` (a step change) never causes a division by zero.
        for (x1, y1), (x2, y2) in zip(self._endpoints[:-1], self._endpoints[1:]):
            # interpolate if `x` is within the segment
            if x1 <= x < x2:
                dx = float(x - x1) / (x2 - x1)
                return y1 + dx * (y2 - y1)

        # The half-open segments exclude the final endpoint. It is still inside
        # the covered range, so return its `y` rather than `outside_value`.
        if self._endpoints and x == self._endpoints[-1][0]:
            return self._endpoints[-1][1]

        # return outside value otherwise
        return self._outside_value

    def __str__(self):
        endpoints = ", ".join([f"({e[0]}, {e[1]})" for e in self._endpoints])
        return f"Schedule[{endpoints}, {self._outside_value}]"
74
75
76
class RelativePiecewise(Piecewise):
    """
    ## Relative piecewise schedule

    A `Piecewise` schedule whose endpoint `x` values are given as fractions
    of `total_steps` instead of absolute step indexes.
    """

    def __init__(self, relative_endpoits: List[Tuple[float, float]], total_steps: int):
        """
        ### Initialize

        `relative_endpoits` (spelling kept for keyword-argument compatibility)
        is a list of `(fraction, y)` pairs. Each fraction is multiplied by
        `total_steps` and truncated with `int()` to give an absolute index.

        The `y` of the last pair becomes the outside value. It is returned for
        any `x` outside the covered index range, before the first endpoint as
        well as after the last. An empty list yields a schedule whose outside
        value is `None`.
        """

        # Materialise once: the input is iterated and then indexed with `[-1]`
        relative_endpoits = list(relative_endpoits)

        endpoints = []
        for e in relative_endpoits:
            index = int(total_steps * e[0])
            assert index >= 0, "relative endpoint maps to a negative index"
            endpoints.append((index, e[1]))

        # Guard the empty case, which the parent class accepts but `[-1]` would not
        outside_value = relative_endpoits[-1][1] if relative_endpoits else None
        super().__init__(endpoints, outside_value=outside_value)
85
86