Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
gitpod-io
GitHub Repository: gitpod-io/gitpod
Path: blob/main/components/common-go/grpc/ratelimit.go
2498 views
1
// Copyright (c) 2020 Gitpod GmbH. All rights reserved.
2
// Licensed under the GNU Affero General Public License (AGPL).
3
// See License.AGPL.txt in the project root for license information.
4
5
package grpc
6
7
import (
8
"context"
9
"strconv"
10
"strings"
11
"time"
12
13
"golang.org/x/time/rate"
14
"google.golang.org/grpc"
15
"google.golang.org/grpc/codes"
16
"google.golang.org/grpc/status"
17
"google.golang.org/protobuf/proto"
18
"google.golang.org/protobuf/reflect/protoreflect"
19
20
"github.com/gitpod-io/gitpod/common-go/util"
21
lru "github.com/hashicorp/golang-lru"
22
"github.com/prometheus/client_golang/prometheus"
23
)
24
25
// keyFunc derives the rate limit bucket key from an incoming request.
// A non-nil error means no key could be derived from the request.
type keyFunc func(req interface{}) (string, error)
26
27
// RateLimit configures the reate limit for a function
28
type RateLimit struct {
29
Block bool `json:"block"`
30
BucketSize uint `json:"bucketSize"`
31
// RefillInterval is the rate at which a new token gets added to the bucket.
32
// Note that this does _not_ completely refill the bucket, only one token gets added,
33
// so effectively this is the rate at which requests can be made.
34
RefillInterval util.Duration `json:"refillInterval"`
35
36
// Key is the proto field name to rate limit on. Each unique value of this
37
// field gets its own rate limit bucket. Must be a String, Enum, or Boolean field.
38
// Can be a composite key by separating fields by comma, e.g. `foo.bar,foo.baz`
39
Key string `json:"key,omitempty"`
40
// KeyCacheSize is the max number of buckets kept in a LRU cache.
41
KeyCacheSize uint `json:"keyCacheSize,omitempty"`
42
}
43
44
// Limiter builds a fresh limiter from this configuration: it allows one event
// every RefillInterval with a burst of BucketSize.
func (r RateLimit) Limiter() *rate.Limiter {
	interval := time.Duration(r.RefillInterval)
	burst := int(r.BucketSize)
	return rate.NewLimiter(rate.Every(interval), burst)
}
47
48
// NewRatelimitingInterceptor creates a new rate limiting interceptor
49
func NewRatelimitingInterceptor(f map[string]RateLimit) RatelimitingInterceptor {
50
callCounter := prometheus.NewCounterVec(prometheus.CounterOpts{
51
Namespace: "grpc",
52
Subsystem: "server",
53
Name: "rate_limiter_calls_total",
54
}, []string{"grpc_method", "rate_limited"})
55
cacheHitCounter := prometheus.NewCounterVec(prometheus.CounterOpts{
56
Namespace: "grpc",
57
Subsystem: "server",
58
Name: "rate_limiter_cache_hit_total",
59
}, []string{"grpc_method"})
60
61
funcs := make(map[string]*ratelimitedFunction, len(f))
62
for name, fnc := range f {
63
var (
64
keyedLimit *lru.Cache
65
key keyFunc
66
)
67
if fnc.Key != "" && fnc.KeyCacheSize > 0 {
68
keyedLimit, _ = lru.New(int(fnc.KeyCacheSize))
69
key = fieldAccessKey(fnc.Key)
70
}
71
72
funcs[name] = &ratelimitedFunction{
73
RateLimit: fnc,
74
GlobalLimit: fnc.Limiter(),
75
Key: key,
76
KeyedLimit: keyedLimit,
77
RateLimitedTotal: callCounter.WithLabelValues(name, "true"),
78
NotRateLimitedTotal: callCounter.WithLabelValues(name, "false"),
79
CacheMissTotal: cacheHitCounter.WithLabelValues(name),
80
}
81
}
82
return RatelimitingInterceptor{
83
functions: funcs,
84
collectors: []prometheus.Collector{callCounter, cacheHitCounter},
85
}
86
}
87
88
// fieldAccessKey builds a keyFunc that reads the fields named in key from a
// protobuf request and joins their values into a single bucket key.
//
// key is a comma-separated list of field paths, where each path uses dots to
// descend into nested messages (e.g. `foo.bar,foo.baz`). Whitespace around the
// individual field paths is ignored.
//
// The returned function fails with codes.Internal if the request is not a
// protobuf message or if one of the configured fields cannot be read from it.
func fieldAccessKey(key string) keyFunc {
	fields := strings.Split(key, ",")
	paths := make([][]string, len(fields))
	for i, field := range fields {
		// Field names never contain whitespace, so a configuration such as
		// "foo.bar, foo.baz" can only mean the trimmed names.
		fields[i] = strings.TrimSpace(field)
		paths[i] = strings.Split(fields[i], ".")
	}
	return func(req interface{}) (string, error) {
		msg, ok := req.(proto.Message)
		if !ok {
			return "", status.Errorf(codes.Internal, "request was not a protobuf message")
		}

		reflected := msg.ProtoReflect()
		var composite strings.Builder
		for i, field := range fields {
			val, ok := getFieldValue(reflected, paths[i])
			if !ok {
				return "", status.Errorf(codes.Internal, "Field %s does not exist in message. This is a rate limiting configuration error.", field)
			}
			// It's technically possible that `|` is part of one of the field values, and therefore could cause collisions
			// in composite keys, e.g. values (`a|`, `b`), and (`a`, `|b`) would result in the same composite key `a||b`
			// and share the rate limit. This is highly unlikely though given the current fields we rate limit on and
			// otherwise unlikely to cause issues.
			composite.WriteByte('|')
			composite.WriteString(val)
		}

		return composite.String(), nil
	}
}
116
117
func getFieldValue(msg protoreflect.Message, path []string) (val string, ok bool) {
118
if len(path) == 0 {
119
return "", false
120
}
121
122
field := msg.Descriptor().Fields().ByName(protoreflect.Name(path[0]))
123
if field == nil {
124
return "", false
125
}
126
if len(path) > 1 {
127
if field.Kind() != protoreflect.MessageKind {
128
// we should go deeper but the field is not a message
129
return "", false
130
}
131
child := msg.Get(field).Message()
132
return getFieldValue(child, path[1:])
133
}
134
135
switch field.Kind() {
136
case protoreflect.StringKind:
137
return msg.Get(field).String(), true
138
case protoreflect.EnumKind:
139
enumNum := msg.Get(field).Enum()
140
return strconv.Itoa(int(enumNum)), true
141
case protoreflect.BoolKind:
142
if msg.Get(field).Bool() {
143
return "t", true
144
} else {
145
return "f", true
146
}
147
148
default:
149
// we only support string and enum fields
150
return "", false
151
}
152
}
153
154
// RatelimitingInterceptor limits how often a gRPC function may be called. If the limit has been
155
// exceeded, we'll return resource exhausted.
156
type RatelimitingInterceptor struct {
157
functions map[string]*ratelimitedFunction
158
collectors []prometheus.Collector
159
}
160
161
var _ prometheus.Collector = RatelimitingInterceptor{}
162
163
// Describe implements prometheus.Collector by forwarding d to every collector
// held by the interceptor, in order.
func (r RatelimitingInterceptor) Describe(d chan<- *prometheus.Desc) {
	for i := range r.collectors {
		r.collectors[i].Describe(d)
	}
}
168
169
// Collect implements prometheus.Collector by forwarding m to every collector
// held by the interceptor, in order.
func (r RatelimitingInterceptor) Collect(m chan<- prometheus.Metric) {
	for i := range r.collectors {
		r.collectors[i].Collect(m)
	}
}
174
175
// counter is the minimal metric interface the interceptor needs: something
// that can be incremented by one.
type counter interface {
	Inc()
}
178
179
type ratelimitedFunction struct {
180
RateLimit RateLimit
181
182
GlobalLimit *rate.Limiter
183
Key keyFunc
184
KeyedLimit *lru.Cache
185
186
RateLimitedTotal counter
187
NotRateLimitedTotal counter
188
CacheMissTotal counter
189
}
190
191
// UnaryInterceptor creates a unary interceptor that implements the rate limiting
192
func (r RatelimitingInterceptor) UnaryInterceptor() grpc.UnaryServerInterceptor {
193
return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp interface{}, err error) {
194
f, ok := r.functions[info.FullMethod]
195
if !ok {
196
return handler(ctx, req)
197
}
198
199
var limit *rate.Limiter
200
if f.Key == nil {
201
limit = f.GlobalLimit
202
} else {
203
key, err := f.Key(req)
204
if err != nil {
205
return nil, err
206
}
207
208
found, _ := f.KeyedLimit.ContainsOrAdd(key, f.RateLimit.Limiter())
209
if !found && f.CacheMissTotal != nil {
210
f.CacheMissTotal.Inc()
211
}
212
v, _ := f.KeyedLimit.Get(key)
213
limit = v.(*rate.Limiter)
214
}
215
216
var blocked bool
217
defer func() {
218
if blocked && f.RateLimitedTotal != nil {
219
f.RateLimitedTotal.Inc()
220
} else if !blocked && f.NotRateLimitedTotal != nil {
221
f.NotRateLimitedTotal.Inc()
222
}
223
}()
224
if f.RateLimit.Block {
225
err := limit.Wait(ctx)
226
if err == context.Canceled {
227
blocked = true
228
return nil, err
229
}
230
if err != nil {
231
blocked = true
232
return nil, status.Error(codes.ResourceExhausted, err.Error())
233
}
234
} else if !limit.Allow() {
235
blocked = true
236
return nil, status.Error(codes.ResourceExhausted, "too many requests")
237
}
238
239
return handler(ctx, req)
240
}
241
}
242
243