package grpc
import (
"context"
"crypto/tls"
"crypto/x509"
"os"
"path/filepath"
"runtime/debug"
"time"
"github.com/gitpod-io/gitpod/common-go/log"
grpc_middleware "github.com/grpc-ecosystem/go-grpc-middleware"
grpc_recovery "github.com/grpc-ecosystem/go-grpc-middleware/recovery"
grpc_opentracing "github.com/grpc-ecosystem/go-grpc-middleware/tracing/opentracing"
grpc_prometheus "github.com/grpc-ecosystem/go-grpc-prometheus"
"github.com/opentracing/opentracing-go"
"github.com/prometheus/client_golang/prometheus"
"github.com/sirupsen/logrus"
"golang.org/x/xerrors"
"google.golang.org/grpc"
"google.golang.org/grpc/backoff"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/grpclog"
"google.golang.org/grpc/keepalive"
"google.golang.org/grpc/status"
)
// maxMsgSize is the largest message, in bytes, that clients built with
// DefaultClientOptions and servers built with ServerOptionsWithInterceptors
// accept on receive: 16 MiB.
const maxMsgSize = 16 << 20
// defaultClientOptionsConfig is package-level state consulted by
// DefaultClientOptions. It starts out zero-valued; ClientMetrics fills in
// Metrics.
//
// NOTE(review): this is read and written without synchronization — callers
// presumably invoke ClientMetrics once during startup, before dialing. Confirm.
var defaultClientOptionsConfig = struct {
	// Metrics, when non-nil, provides the Prometheus unary and stream client
	// interceptors that DefaultClientOptions appends after the tracing ones.
	Metrics *grpc_prometheus.ClientMetrics
}{}
// ClientMetrics creates a fresh set of Prometheus gRPC client metrics and
// returns it as a collector for the caller to register.
//
// As a side effect it stores the metrics in the package-level config, so
// every DefaultClientOptions call made afterwards installs the matching
// client interceptors. A later call replaces the previously stored metrics.
func ClientMetrics() prometheus.Collector {
	collector := grpc_prometheus.NewClientMetrics()
	defaultClientOptionsConfig.Metrics = collector
	return collector
}
func DefaultClientOptions() []grpc.DialOption {
bfConf := backoff.DefaultConfig
bfConf.MaxDelay = 5 * time.Second
var (
unaryInterceptor = []grpc.UnaryClientInterceptor{
grpc_opentracing.UnaryClientInterceptor(grpc_opentracing.WithTracer(opentracing.GlobalTracer())),
}
streamInterceptor = []grpc.StreamClientInterceptor{
grpc_opentracing.StreamClientInterceptor(grpc_opentracing.WithTracer(opentracing.GlobalTracer())),
}
)
if defaultClientOptionsConfig.Metrics != nil {
unaryInterceptor = append(unaryInterceptor, defaultClientOptionsConfig.Metrics.UnaryClientInterceptor())
streamInterceptor = append(streamInterceptor, defaultClientOptionsConfig.Metrics.StreamClientInterceptor())
}
res := []grpc.DialOption{
grpc.WithUnaryInterceptor(grpc_middleware.ChainUnaryClient(unaryInterceptor...)),
grpc.WithStreamInterceptor(grpc_middleware.ChainStreamClient(streamInterceptor...)),
grpc.WithBlock(),
grpc.WithConnectParams(grpc.ConnectParams{
Backoff: bfConf,
}),
grpc.WithKeepaliveParams(keepalive.ClientParameters{
Time: 10 * time.Second,
Timeout: time.Second,
PermitWithoutStream: true,
}),
grpc.WithDefaultCallOptions(grpc.MaxCallRecvMsgSize(maxMsgSize)),
}
return res
}
// DefaultServerOptions returns the standard server options with no extra
// caller-supplied interceptors; see ServerOptionsWithInterceptors.
func DefaultServerOptions() []grpc.ServerOption {
	// The callee only appends to these, so nil behaves exactly like empty.
	var (
		noStream []grpc.StreamServerInterceptor
		noUnary  []grpc.UnaryServerInterceptor
	)
	return ServerOptionsWithInterceptors(noStream, noUnary)
}
// ServerOptionsWithInterceptors returns the server options shared by gitpod
// gRPC servers. The caller-supplied interceptors run first, in the given
// order. They are followed by an OpenTracing interceptor, which skips the
// "/grpc.health.v1.Health/Check" method, and a panic-recovery interceptor.
//
// The recovery interceptor, for both unary and stream calls, logs the panic
// value with the current stack and turns the panic into a codes.Internal
// status.
//
// Other options set here:
//   - keepalive pings are accepted at most every 10s, even without streams;
//   - idle connections are closed after 30 minutes;
//   - received messages are limited to maxMsgSize.
//
// The stream and unary slices are not modified. They are copied before the
// built-in interceptors are appended.
func ServerOptionsWithInterceptors(stream []grpc.StreamServerInterceptor, unary []grpc.UnaryServerInterceptor) []grpc.ServerOption {
	// Health checks are polled frequently; tracing them only adds noise.
	tracingFilterFunc := grpc_opentracing.WithFilterFunc(func(ctx context.Context, fullMethodName string) bool {
		return fullMethodName != "/grpc.health.v1.Health/Check"
	})

	// Shared by the unary and stream paths so both log panics the same way.
	// %v (not %s) keeps non-string panic values readable.
	recoveryHandler := grpc_recovery.WithRecoveryHandlerContext(
		func(ctx context.Context, p interface{}) error {
			log.WithField("stack", string(debug.Stack())).Errorf("[PANIC] %v", p)
			return status.Errorf(codes.Internal, "%v", p)
		},
	)

	// Build fresh slices: appending directly to the arguments could write into
	// spare capacity of the caller's backing arrays.
	streamChain := make([]grpc.StreamServerInterceptor, 0, len(stream)+2)
	streamChain = append(streamChain, stream...)
	streamChain = append(streamChain,
		grpc_opentracing.StreamServerInterceptor(tracingFilterFunc),
		grpc_recovery.StreamServerInterceptor(recoveryHandler),
	)

	unaryChain := make([]grpc.UnaryServerInterceptor, 0, len(unary)+2)
	unaryChain = append(unaryChain, unary...)
	unaryChain = append(unaryChain,
		grpc_opentracing.UnaryServerInterceptor(tracingFilterFunc),
		grpc_recovery.UnaryServerInterceptor(recoveryHandler),
	)

	return []grpc.ServerOption{
		grpc.KeepaliveEnforcementPolicy(keepalive.EnforcementPolicy{
			MinTime:             10 * time.Second,
			PermitWithoutStream: true,
		}),
		grpc.MaxRecvMsgSize(maxMsgSize),
		grpc.KeepaliveParams(keepalive.ServerParameters{
			MaxConnectionIdle: 30 * time.Minute,
		}),
		grpc.StreamInterceptor(grpc_middleware.ChainStreamServer(streamChain...)),
		grpc.UnaryInterceptor(grpc_middleware.ChainUnaryServer(unaryChain...)),
	}
}
func SetupLogging() {
grpclog.SetLoggerV2(grpclog.NewLoggerV2(
log.WithField("component", "grpc").WriterLevel(logrus.DebugLevel),
log.WithField("component", "grpc").WriterLevel(logrus.WarnLevel),
log.WithField("component", "grpc").WriterLevel(logrus.ErrorLevel),
))
}
type (
	// TLSConfigOption customizes the tls.Config produced by
	// ClientAuthTLSConfig. A non-nil error aborts config construction.
	TLSConfigOption func(*tlsConfigOptions) error

	// tlsConfigOptions collects the settings applied by TLSConfigOption values.
	tlsConfigOptions struct {
		// ClientAuth becomes tls.Config.ClientAuth.
		ClientAuth tls.ClientAuthType
		// ServerName becomes tls.Config.ServerName.
		ServerName string
		// RootCAs, when true, installs the loaded CA pool as tls.Config.RootCAs.
		RootCAs bool
		// ClientCAs, when true, installs the loaded CA pool as tls.Config.ClientCAs.
		ClientCAs bool
	}
)
// WithClientAuth selects the client-certificate policy. ClientAuthTLSConfig
// copies it into tls.Config.ClientAuth and logs when it is not
// tls.NoClientCert.
func WithClientAuth(authType tls.ClientAuthType) TLSConfigOption {
	apply := func(o *tlsConfigOptions) error {
		o.ClientAuth = authType
		return nil
	}
	return apply
}
// WithServerName sets the tls.Config.ServerName used by ClientAuthTLSConfig.
func WithServerName(serverName string) TLSConfigOption {
	apply := func(o *tlsConfigOptions) error {
		o.ServerName = serverName
		return nil
	}
	return apply
}
// WithSetRootCAs controls whether ClientAuthTLSConfig installs the CA pool
// loaded from its authority file as tls.Config.RootCAs.
func WithSetRootCAs(rootCAs bool) TLSConfigOption {
	apply := func(o *tlsConfigOptions) error {
		o.RootCAs = rootCAs
		return nil
	}
	return apply
}
// WithSetClientCAs controls whether ClientAuthTLSConfig installs the CA pool
// loaded from its authority file as tls.Config.ClientCAs.
func WithSetClientCAs(clientCAs bool) TLSConfigOption {
	apply := func(o *tlsConfigOptions) error {
		o.ClientCAs = clientCAs
		return nil
	}
	return apply
}
func ClientAuthTLSConfig(authority, certificate, privateKey string, opts ...TLSConfigOption) (*tls.Config, error) {
if root := os.Getenv("TELEPRESENCE_ROOT"); root != "" {
authority = filepath.Join(root, authority)
certificate = filepath.Join(root, certificate)
privateKey = filepath.Join(root, privateKey)
}
tlsCertificate, err := tls.LoadX509KeyPair(certificate, privateKey)
if err != nil {
return nil, xerrors.Errorf("cannot load TLS certificate: %w", err)
}
certPool := x509.NewCertPool()
ca, err := os.ReadFile(authority)
if err != nil {
return nil, xerrors.Errorf("cannot not read ca certificate: %w", err)
}
if ok := certPool.AppendCertsFromPEM(ca); !ok {
return nil, xerrors.Errorf("failed to append ca certs")
}
options := tlsConfigOptions{}
for _, o := range opts {
err := o(&options)
if err != nil {
return nil, err
}
}
tlsConfig := &tls.Config{
Certificates: []tls.Certificate{tlsCertificate},
CipherSuites: []uint16{
tls.TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305,
tls.TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305,
tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384,
tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384,
tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256,
tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
},
CurvePreferences: []tls.CurveID{tls.X25519, tls.CurveP256},
MinVersion: tls.VersionTLS12,
MaxVersion: tls.VersionTLS12,
NextProtos: []string{"h2"},
}
tlsConfig.ServerName = options.ServerName
if options.ClientAuth != tls.NoClientCert {
log.WithField("clientAuth", options.ClientAuth).Info("enabling client authentication")
tlsConfig.ClientAuth = options.ClientAuth
}
if options.ClientCAs {
tlsConfig.ClientCAs = certPool
}
if options.RootCAs {
tlsConfig.RootCAs = certPool
}
return tlsConfig, nil
}