Path: blob/master/labml_nn/neox/utils/cache.py
4921 views
"""1---2title: Cache for Intermediate Activations3summary: >4Cache for intermediate activations for faster inference.5---67# Cache for Intermediate Activations89During inference the model outputs token by token.10We use this simple cache to store key's and value's attention layers,11so that we don't have to recompute them for previous tokens.12"""1314from typing import Any151617class Cache:18"""19## Cache2021This maintains a key-value cache and queues push values and pop them in the same order.22The queues are useful since we have multiple attention layers.23"""2425def __init__(self):26self._cache = {}2728def clear_all(self):29"""30### Clear cache31"""32self._cache = {}3334def push(self, name: str, value: Any):35"""36### Push a value to a queue3738:param name: is the name of the queue39:param value: is the value to be pushed40"""4142# Create an empty queue if it's not present43if name not in self._cache:44self._cache[name] = []4546# Push to the queue47self._cache[name].append(value)4849def q_size(self, name):50"""51### Return the size of the queue5253:param name: is the name of the queue54:return: size of the queue if exists else None55"""5657if name not in self._cache:58return None5960if type(self._cache[name]) != list:61return None6263return len(self._cache[name])6465def pop(self, name: str):66"""67### Pop from a queue6869:param name: is the name of the queue70:return: the value71"""72return self._cache[name].pop(0)7374def set(self, key: str, value: Any):75"""76### Cache a value7778:param key: is the name of the value to be cached79:param value: is the value80"""81self._cache[key] = value8283def get(self, key: str, default: Any = None):84"""85### Retrieve a value from cache8687:param key: is the name used when caching88:param default: is the default value if the cache is empty89:return: the cached value90"""91return self._cache.get(key, default)9293def clear(self, key: str):94"""95### Clear a cache value9697:param key: is the name used when caching98"""99del self._cache[key]100101102# Singleton for cache103_INSTANCE = None104105106def get_cache() -> Cache:107"""108### Get the cache instance109110:return: the cache instance111"""112global _INSTANCE113114if _INSTANCE is None:115_INSTANCE = Cache()116117return _INSTANCE118119120