# Path: blob/master/labml_nn/helpers/schedule.py
# 4918 views
from typing import Iterable, List, Optional, Tuple


class Schedule:
    """
    ## Base schedule

    A schedule is a callable that maps `x` to a value.
    Subclasses must override `__call__`.
    """

    def __call__(self, x):
        raise NotImplementedError()


class Flat(Schedule):
    """
    ## Flat schedule

    Returns the same `value` for every `x`.
    """

    def __init__(self, value):
        self.__value = value

    def __call__(self, x):
        # `x` is ignored; the value is fixed at construction time
        return self.__value

    def __str__(self):
        return f"Schedule({self.__value})"


class Dynamic(Schedule):
    """
    ## Dynamic schedule

    Returns the most recently stored value for every `x`;
    the value can be replaced at any time with `update`.
    """

    def __init__(self, value):
        self.__value = value

    def __call__(self, x):
        # `x` is ignored; the value is whatever was last stored
        return self.__value

    def update(self, value):
        """
        ### Replace the value returned by subsequent calls
        """
        self.__value = value

    def __str__(self):
        return "Dynamic"


class Piecewise(Schedule):
    """
    ## Piecewise schedule
    """

    def __init__(self, endpoints: Iterable[Tuple[float, float]],
                 outside_value: Optional[float] = None):
        """
        ### Initialize

        `endpoints` is a sequence of pairs `(x, y)`, sorted by `x`
        (repeated `x` values are allowed and produce a step).
        The values between endpoints are linearly interpolated.
        For `x` outside the closed range `[x_first, x_last]` covered by
        the endpoints, the schedule returns `outside_value`.

        Raises `AssertionError` if the endpoints are not sorted by `x`.
        """

        # Take a private copy: this accepts any iterable (e.g. a generator,
        # which would otherwise be exhausted by the check below) and keeps
        # later mutation of the caller's list from breaking sortedness.
        endpoints = list(endpoints)

        # `(x, y)` pairs should be sorted
        indexes = [e[0] for e in endpoints]
        assert indexes == sorted(indexes), f"endpoints must be sorted by x, got {indexes}"

        self._outside_value = outside_value
        self._endpoints = endpoints

    def __call__(self, x):
        """
        ### Find `y` for given `x`
        """
        endpoints = self._endpoints

        # iterate through each half-open segment `[x1, x2)`
        for (x1, y1), (x2, y2) in zip(endpoints[:-1], endpoints[1:]):
            # interpolate if `x` is within the segment; zero-width segments
            # (`x1 == x2`) never match, so there is no division by zero
            if x1 <= x < x2:
                dx = float(x - x1) / (x2 - x1)
                return y1 + dx * (y2 - y1)

        # Segments are half-open, so the last endpoint itself is never matched
        # above. It is still inside the covered range, so return its `y`
        # (this also handles a single-endpoint schedule).
        if endpoints and x == endpoints[-1][0]:
            return endpoints[-1][1]

        # return outside value otherwise
        return self._outside_value

    def __str__(self):
        endpoints = ", ".join([f"({e[0]}, {e[1]})" for e in self._endpoints])
        return f"Schedule[{endpoints}, {self._outside_value}]"


class RelativePiecewise(Piecewise):
    """
    ## Piecewise schedule with relative `x` positions

    `relative_endpoits` are `(fraction, y)` pairs. Each fraction is multiplied
    by `total_steps` and truncated with `int` to get an absolute step.
    Outside the covered range (before the first or after the last endpoint)
    the schedule returns the `y` of the last endpoint.
    """

    def __init__(self, relative_endpoits: List[Tuple[float, float]], total_steps: int):
        # NOTE: the parameter keeps its original (misspelled) name so that
        # existing keyword-argument callers keep working.
        endpoints = []
        for e in relative_endpoits:
            index = int(total_steps * e[0])
            assert index >= 0, f"relative position {e[0]} maps to negative step {index}"
            endpoints.append((index, e[1]))

        super().__init__(endpoints, outside_value=relative_endpoits[-1][1])