package db
import (
"crypto/tls"
"crypto/x509"
"fmt"
"net"
"os"
"time"
"github.com/gitpod-io/gitpod/common-go/log"
driver_mysql "github.com/go-sql-driver/mysql"
"github.com/sirupsen/logrus"
"gorm.io/driver/mysql"
"gorm.io/gorm"
"gorm.io/gorm/logger"
"gorm.io/plugin/opentelemetry/tracing"
)
type ConnectionParams struct {
User string
Password string
Host string
Database string
CaCert string
}
func ConnectionParamsFromEnv() ConnectionParams {
return ConnectionParams{
User: os.Getenv("DB_USERNAME"),
Password: os.Getenv("DB_PASSWORD"),
Host: net.JoinHostPort(os.Getenv("DB_HOST"), os.Getenv("DB_PORT")),
Database: "gitpod",
CaCert: os.Getenv("DB_CA_CERT"),
}
}
func Connect(p ConnectionParams) (*gorm.DB, error) {
loc, err := time.LoadLocation("UTC")
if err != nil {
return nil, fmt.Errorf("Failed to load UTC location: %w", err)
}
cfg := driver_mysql.Config{
User: p.User,
Passwd: p.Password,
Net: "tcp",
Addr: p.Host,
DBName: p.Database,
Loc: loc,
AllowNativePasswords: true,
ParseTime: true,
}
if p.CaCert != "" {
rootCertPool := x509.NewCertPool()
if ok := rootCertPool.AppendCertsFromPEM([]byte(p.CaCert)); !ok {
return nil, fmt.Errorf("failed to append custom certificate for database connection")
}
tlsConfigName := "custom"
err = driver_mysql.RegisterTLSConfig(tlsConfigName, &tls.Config{
RootCAs: rootCertPool,
MinVersion: tls.VersionTLS12,
})
if err != nil {
return nil, fmt.Errorf("failed to register custom DB CA cert: %w", err)
}
cfg.TLSConfig = tlsConfigName
}
conn, err := gorm.Open(mysql.Open(cfg.FormatDSN()), &gorm.Config{
Logger: logger.New(log.Log, logger.Config{
SlowThreshold: 200 * time.Millisecond,
Colorful: false,
IgnoreRecordNotFoundError: true,
LogLevel: (func() logger.LogLevel {
switch log.Log.Level {
case logrus.PanicLevel, logrus.FatalLevel, logrus.ErrorLevel:
return logger.Error
case logrus.WarnLevel:
return logger.Warn
default:
return logger.Info
}
})(),
}),
})
if err != nil {
return nil, fmt.Errorf("failed to open db connection: %w", err)
}
err = conn.Use(tracing.NewPlugin())
if err != nil {
return nil, fmt.Errorf("failed to setup db tracing: %w", err)
}
return conn, nil
}