Path: blob/main/components/common-go/grpc/ratelimit.go
2498 views
// Copyright (c) 2020 Gitpod GmbH. All rights reserved.1// Licensed under the GNU Affero General Public License (AGPL).2// See License.AGPL.txt in the project root for license information.34package grpc56import (7"context"8"strconv"9"strings"10"time"1112"golang.org/x/time/rate"13"google.golang.org/grpc"14"google.golang.org/grpc/codes"15"google.golang.org/grpc/status"16"google.golang.org/protobuf/proto"17"google.golang.org/protobuf/reflect/protoreflect"1819"github.com/gitpod-io/gitpod/common-go/util"20lru "github.com/hashicorp/golang-lru"21"github.com/prometheus/client_golang/prometheus"22)2324type keyFunc func(req interface{}) (string, error)2526// RateLimit configures the reate limit for a function27type RateLimit struct {28Block bool `json:"block"`29BucketSize uint `json:"bucketSize"`30// RefillInterval is the rate at which a new token gets added to the bucket.31// Note that this does _not_ completely refill the bucket, only one token gets added,32// so effectively this is the rate at which requests can be made.33RefillInterval util.Duration `json:"refillInterval"`3435// Key is the proto field name to rate limit on. Each unique value of this36// field gets its own rate limit bucket. Must be a String, Enum, or Boolean field.37// Can be a composite key by separating fields by comma, e.g. `foo.bar,foo.baz`38Key string `json:"key,omitempty"`39// KeyCacheSize is the max number of buckets kept in a LRU cache.40KeyCacheSize uint `json:"keyCacheSize,omitempty"`41}4243func (r RateLimit) Limiter() *rate.Limiter {44return rate.NewLimiter(rate.Every(time.Duration(r.RefillInterval)), int(r.BucketSize))45}4647// NewRatelimitingInterceptor creates a new rate limiting interceptor48func NewRatelimitingInterceptor(f map[string]RateLimit) RatelimitingInterceptor {49callCounter := prometheus.NewCounterVec(prometheus.CounterOpts{50Namespace: "grpc",51Subsystem: "server",52Name: "rate_limiter_calls_total",53}, []string{"grpc_method", "rate_limited"})54cacheHitCounter := prometheus.NewCounterVec(prometheus.CounterOpts{55Namespace: "grpc",56Subsystem: "server",57Name: "rate_limiter_cache_hit_total",58}, []string{"grpc_method"})5960funcs := make(map[string]*ratelimitedFunction, len(f))61for name, fnc := range f {62var (63keyedLimit *lru.Cache64key keyFunc65)66if fnc.Key != "" && fnc.KeyCacheSize > 0 {67keyedLimit, _ = lru.New(int(fnc.KeyCacheSize))68key = fieldAccessKey(fnc.Key)69}7071funcs[name] = &ratelimitedFunction{72RateLimit: fnc,73GlobalLimit: fnc.Limiter(),74Key: key,75KeyedLimit: keyedLimit,76RateLimitedTotal: callCounter.WithLabelValues(name, "true"),77NotRateLimitedTotal: callCounter.WithLabelValues(name, "false"),78CacheMissTotal: cacheHitCounter.WithLabelValues(name),79}80}81return RatelimitingInterceptor{82functions: funcs,83collectors: []prometheus.Collector{callCounter, cacheHitCounter},84}85}8687func fieldAccessKey(key string) keyFunc {88fields := strings.Split(key, ",")89paths := make([][]string, len(fields))90for i, field := range fields {91paths[i] = strings.Split(field, ".")92}93return func(req interface{}) (string, error) {94msg, ok := req.(proto.Message)95if !ok {96return "", status.Errorf(codes.Internal, "request was not a protobuf message")97}9899var composite string100for i, field := range fields {101val, ok := getFieldValue(msg.ProtoReflect(), paths[i])102if !ok {103return "", status.Errorf(codes.Internal, "Field %s does not exist in message. This is a rate limiting configuration error.", field)104}105// It's technically possible that `|` is part of one of the field values, and therefore could cause collisions106// in composite keys, e.g. values (`a|`, `b`), and (`a`, `|b`) would result in the same composite key `a||b`107// and share the rate limit. This is highly unlikely though given the current fields we rate limit on and108// otherwise unlikely to cause issues.109composite += "|" + val110}111112return composite, nil113}114}115116func getFieldValue(msg protoreflect.Message, path []string) (val string, ok bool) {117if len(path) == 0 {118return "", false119}120121field := msg.Descriptor().Fields().ByName(protoreflect.Name(path[0]))122if field == nil {123return "", false124}125if len(path) > 1 {126if field.Kind() != protoreflect.MessageKind {127// we should go deeper but the field is not a message128return "", false129}130child := msg.Get(field).Message()131return getFieldValue(child, path[1:])132}133134switch field.Kind() {135case protoreflect.StringKind:136return msg.Get(field).String(), true137case protoreflect.EnumKind:138enumNum := msg.Get(field).Enum()139return strconv.Itoa(int(enumNum)), true140case protoreflect.BoolKind:141if msg.Get(field).Bool() {142return "t", true143} else {144return "f", true145}146147default:148// we only support string and enum fields149return "", false150}151}152153// RatelimitingInterceptor limits how often a gRPC function may be called. If the limit has been154// exceeded, we'll return resource exhausted.155type RatelimitingInterceptor struct {156functions map[string]*ratelimitedFunction157collectors []prometheus.Collector158}159160var _ prometheus.Collector = RatelimitingInterceptor{}161162func (r RatelimitingInterceptor) Describe(d chan<- *prometheus.Desc) {163for _, c := range r.collectors {164c.Describe(d)165}166}167168func (r RatelimitingInterceptor) Collect(m chan<- prometheus.Metric) {169for _, c := range r.collectors {170c.Collect(m)171}172}173174type counter interface {175Inc()176}177178type ratelimitedFunction struct {179RateLimit RateLimit180181GlobalLimit *rate.Limiter182Key keyFunc183KeyedLimit *lru.Cache184185RateLimitedTotal counter186NotRateLimitedTotal counter187CacheMissTotal counter188}189190// UnaryInterceptor creates a unary interceptor that implements the rate limiting191func (r RatelimitingInterceptor) UnaryInterceptor() grpc.UnaryServerInterceptor {192return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp interface{}, err error) {193f, ok := r.functions[info.FullMethod]194if !ok {195return handler(ctx, req)196}197198var limit *rate.Limiter199if f.Key == nil {200limit = f.GlobalLimit201} else {202key, err := f.Key(req)203if err != nil {204return nil, err205}206207found, _ := f.KeyedLimit.ContainsOrAdd(key, f.RateLimit.Limiter())208if !found && f.CacheMissTotal != nil {209f.CacheMissTotal.Inc()210}211v, _ := f.KeyedLimit.Get(key)212limit = v.(*rate.Limiter)213}214215var blocked bool216defer func() {217if blocked && f.RateLimitedTotal != nil {218f.RateLimitedTotal.Inc()219} else if !blocked && f.NotRateLimitedTotal != nil {220f.NotRateLimitedTotal.Inc()221}222}()223if f.RateLimit.Block {224err := limit.Wait(ctx)225if err == context.Canceled {226blocked = true227return nil, err228}229if err != nil {230blocked = true231return nil, status.Error(codes.ResourceExhausted, err.Error())232}233} else if !limit.Allow() {234blocked = true235return nil, status.Error(codes.ResourceExhausted, "too many requests")236}237238return handler(ctx, req)239}240}241242243