Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
labmlai
GitHub Repository: labmlai/annotated_deep_learning_paper_implementations
Path: blob/master/labml_nn/neox/utils/cache.py
4921 views
"""
---
title: Cache for Intermediate Activations
summary: >
  Cache for intermediate activations for faster inference.
---

# Cache for Intermediate Activations

During inference the model outputs token by token.
We use this simple cache to store the keys and values of the attention layers,
so that we don't have to recompute them for previous tokens.
"""
14
15
from typing import Any, Optional
16
17
18
class Cache:
    """
    ## Cache

    This maintains a key-value cache together with named queues.
    Values pushed to a queue are popped in the same order (first in, first out).
    The queues are useful since we have multiple attention layers.
    """

    def __init__(self):
        # Maps a name either to a queue (a list consumed from the front)
        # or to a single value stored with `set`
        self._cache = {}

    def clear_all(self) -> None:
        """
        ### Clear cache

        Drops every queue and every cached value.
        """
        self._cache = {}

    def push(self, name: str, value: Any) -> None:
        """
        ### Push a value to a queue

        :param name: is the name of the queue
        :param value: is the value to be pushed
        """

        # Create an empty queue if it's not present, then push to it
        self._cache.setdefault(name, []).append(value)

    def q_size(self, name: str) -> Optional[int]:
        """
        ### Return the size of the queue

        :param name: is the name of the queue
        :return: size of the queue if exists else None
        """

        # A missing name yields `None` here, which is not a list either
        queue = self._cache.get(name)

        # Entries stored with `set` that are not lists are not queues
        if not isinstance(queue, list):
            return None

        return len(queue)

    def pop(self, name: str) -> Any:
        """
        ### Pop from a queue

        Removes and returns the oldest value in the queue.

        :param name: is the name of the queue
        :return: the value
        :raises KeyError: if there is no queue with the given name
        :raises IndexError: if the queue is empty
        """
        return self._cache[name].pop(0)

    def set(self, key: str, value: Any) -> None:
        """
        ### Cache a value

        Overwrites whatever was stored under the same key.

        :param key: is the name of the value to be cached
        :param value: is the value
        """
        self._cache[key] = value

    def get(self, key: str, default: Any = None) -> Any:
        """
        ### Retrieve a value from cache

        :param key: is the name used when caching
        :param default: is the value returned if nothing is cached under the key
        :return: the cached value
        """
        return self._cache.get(key, default)

    def clear(self, key: str) -> None:
        """
        ### Clear a cache value

        :param key: is the name used when caching
        :raises KeyError: if nothing is cached under the key
        """
        del self._cache[key]
101
102
103
# The single shared cache; created the first time `get_cache` is called
_INSTANCE = None


def get_cache() -> Cache:
    """
    ### Get the cache instance

    Creates the shared `Cache` on the first call and returns
    that same object on every later call.

    :return: the cache instance
    """
    global _INSTANCE

    # Reuse the existing instance when there is one
    if _INSTANCE is not None:
        return _INSTANCE

    _INSTANCE = Cache()
    return _INSTANCE
119
120