package vault
import (
"context"
"fmt"
"sync"
"time"
"github.com/go-kit/log"
"github.com/go-kit/log/level"
"github.com/grafana/agent/component"
"github.com/grafana/agent/pkg/river/rivertypes"
"github.com/oklog/run"
vault "github.com/hashicorp/vault/api"
)
// init registers the remote.vault component with the component registry.
func init() {
	component.Register(component.Registration{
		Name:    "remote.vault",
		Args:    Arguments{},
		Exports: Exports{},
		Build: func(opts component.Options, args component.Arguments) (component.Component, error) {
			c, err := New(opts, args.(Arguments))
			if err != nil {
				// Return a literal nil: returning the nil *Component directly
				// would produce a non-nil component.Component interface value.
				return nil, err
			}
			return c, nil
		},
	})
}
// Arguments holds the River configuration of a remote.vault component.
type Arguments struct {
	// Server is the Vault server address; client() assigns it to the Vault
	// client config's Address.
	Server string `river:"server,attr"`

	// Namespace is the optional Vault namespace.
	Namespace string `river:"namespace,attr,optional"`

	// Path identifies the secret to read. It is interpreted by the secret
	// store (kvStore); its expected format is not visible here.
	Path string `river:"path,attr"`

	// RereadFrequency is used as the secret token manager's refresh interval.
	// NOTE(review): how a zero value is treated is decided by tokenManager —
	// confirm.
	RereadFrequency time.Duration `river:"reread_frequency,attr,optional"`

	// ClientOptions tunes the Vault API client built by client().
	ClientOptions ClientOptions `river:"client_options,block,optional"`

	// Auth holds the configured auth.* blocks; UnmarshalRiver requires
	// exactly one.
	Auth []AuthArguments `river:"auth,enum,optional"`
}
// DefaultArguments is applied by UnmarshalRiver before a River block is
// decoded. It sets a minimum retry wait of 1s, a maximum retry wait of 1.5s,
// at most 2 retries, and a 60s client timeout.
var DefaultArguments = Arguments{
	ClientOptions: ClientOptions{
		MinRetryWait: time.Second,
		MaxRetryWait: 1500 * time.Millisecond,
		MaxRetries:   2,
		Timeout:      time.Minute,
	},
}
// client builds a Vault API client from the arguments. It starts from the
// Vault library's default config, applies the server address and the client
// options, and then sets the configured namespace, if any.
func (a *Arguments) client() (*vault.Client, error) {
	cfg := vault.DefaultConfig()
	cfg.Address = a.Server
	cfg.MinRetryWait = a.ClientOptions.MinRetryWait
	cfg.MaxRetryWait = a.ClientOptions.MaxRetryWait
	cfg.MaxRetries = a.ClientOptions.MaxRetries
	cfg.Timeout = a.ClientOptions.Timeout

	cli, err := vault.NewClient(cfg)
	if err != nil {
		return nil, err
	}

	// Namespace is optional. Only override it when it is set, so an empty
	// value does not clear a namespace the Vault library may have derived on
	// its own (e.g. from its environment).
	if a.Namespace != "" {
		cli.SetNamespace(a.Namespace)
	}
	return cli, nil
}
// UnmarshalRiver applies DefaultArguments, decodes the River block on top of
// them, and validates the result.
//
// It returns the decoder's error, or a validation error if the number of
// auth.* blocks is not exactly one or client_options.timeout is not positive.
func (a *Arguments) UnmarshalRiver(f func(interface{}) error) error {
	*a = DefaultArguments

	// arguments has Arguments' fields but none of its methods, so decoding
	// into it cannot re-enter UnmarshalRiver.
	type arguments Arguments
	if err := f((*arguments)(a)); err != nil {
		return err
	}

	switch n := len(a.Auth); {
	case n == 0:
		return fmt.Errorf("exactly one auth.* block must be specified; found none")
	case n > 1:
		return fmt.Errorf("exactly one auth.* block must be specified; found %d", n)
	}

	// Reject negative values as well as zero, matching the error message.
	if a.ClientOptions.Timeout <= 0 {
		return fmt.Errorf("client_options.timeout must be greater than 0")
	}
	return nil
}
// authMethod returns the auth method of the single configured auth block.
// UnmarshalRiver rejects any other number of blocks, so a different count
// here is a programming error and panics.
func (a *Arguments) authMethod() authMethod {
	if n := len(a.Auth); n != 1 {
		panic(fmt.Sprintf("remote.vault: found %d auth types, expected 1", n))
	}
	return a.Auth[0].authMethod()
}
// secretStore returns the store used to read the configured secret through
// cli. It always returns a kvStore; the receiver's fields are not consulted.
func (a *Arguments) secretStore(cli *vault.Client) secretStore {
	store := &kvStore{c: cli}
	return store
}
// ClientOptions tunes the Vault API client. client() copies each field onto
// the Vault client config field of the same name. UnmarshalRiver rejects a
// zero Timeout.
type ClientOptions struct {
	MinRetryWait time.Duration `river:"min_retry_wait,attr,optional"`
	MaxRetryWait time.Duration `river:"max_retry_wait,attr,optional"`
	MaxRetries   int           `river:"max_retries,attr,optional"`
	Timeout      time.Duration `river:"timeout,attr,optional"`
}
// Exports holds the values exported by remote.vault.
type Exports struct {
	// Data maps each field of the fetched secret to its value as a River
	// secret. Only string and []byte fields are included (see exportSecret).
	Data map[string]rivertypes.Secret `river:"data,attr"`
}
// Component implements the remote.vault component. Run drives two token
// managers: one for the Vault auth token and one for the configured secret.
type Component struct {
	opts    component.Options
	log     log.Logger
	metrics *metrics

	// mut guards args: Update replaces it under the write lock and the token
	// getters read it under the read lock.
	mut  sync.RWMutex
	args Arguments

	// Both managers are created by the first Update (called from New) and
	// afterwards only reconfigured.
	secretManager *tokenManager
	authManager   *tokenManager
}
// Compile-time checks that *Component satisfies component.Component,
// component.HealthComponent and component.DebugComponent.
var (
	_ component.Component       = (*Component)(nil)
	_ component.HealthComponent = (*Component)(nil)
	_ component.DebugComponent  = (*Component)(nil)
)
// New creates a remote.vault component and applies the initial arguments
// through Update. Update builds the Vault client and creates the auth and
// secret token managers. If Update fails, its error is returned and no
// component is produced.
func New(opts component.Options, args Arguments) (*Component, error) {
	comp := &Component{
		opts:    opts,
		log:     opts.Logger,
		metrics: newMetrics(opts.Registerer),
	}

	err := comp.Update(args)
	if err != nil {
		return nil, err
	}
	return comp, nil
}
// Run runs the secret and auth token managers concurrently in a run.Group
// that shares one cancellable context.
//
// When either manager's Run returns, the group's interrupt functions cancel
// that context so the other manager is asked to stop. Both actors return nil,
// so Run returns nil.
func (c *Component) Run(ctx context.Context) error {
	runCtx, stop := context.WithCancel(ctx)
	defer stop()

	var group run.Group
	for _, mgr := range []*tokenManager{c.secretManager, c.authManager} {
		mgr := mgr // per-iteration copy for the closure (pre-Go 1.22 semantics)
		group.Add(
			func() error {
				mgr.Run(runCtx)
				return nil
			},
			func(error) { stop() },
		)
	}
	return group.Run()
}
// Update builds a Vault client from the new arguments and publishes the
// arguments for the token getters. On the first call (from New) it then
// creates the auth and secret token managers. On later calls it points the
// existing managers at the new client and updates the secret manager's
// refresh interval from RereadFrequency.
func (c *Component) Update(args component.Arguments) error {
	newArgs := args.(Arguments)

	cli, err := newArgs.client()
	if err != nil {
		return err
	}

	// Publish the arguments and release the lock before touching the
	// managers: getAuthToken and getSecret read c.args under the read lock.
	c.mut.Lock()
	c.args = newArgs
	c.mut.Unlock()

	if c.authManager != nil {
		c.authManager.SetClient(cli)
	} else {
		authMgr, authErr := newTokenManager(tokenManagerOptions{
			Log:            log.With(c.log, "token_type", "auth"),
			Client:         cli,
			Getter:         c.getAuthToken,
			ReadCounter:    c.metrics.authTotal,
			RefreshCounter: c.metrics.authLeaseRenewalTotal,
		})
		if authErr != nil {
			return authErr
		}
		c.authManager = authMgr
	}

	if c.secretManager != nil {
		c.secretManager.SetClient(cli)
		c.secretManager.SetRefreshInterval(newArgs.RereadFrequency)
		return nil
	}

	secretMgr, secretErr := newTokenManager(tokenManagerOptions{
		Log:             log.With(c.log, "token_type", "secret"),
		Client:          cli,
		Getter:          c.getSecret,
		RefreshInterval: newArgs.RereadFrequency,
		ReadCounter:     c.metrics.secretReadTotal,
		RefreshCounter:  c.metrics.secretLeaseRenewalTotal,
	})
	if secretErr != nil {
		return secretErr
	}
	c.secretManager = secretMgr
	return nil
}
// getAuthToken is the Getter of the auth token manager (see Update). It
// resolves the configured auth method and uses it to authenticate cli
// against Vault.
//
// The arguments are copied under the read lock, and the lock is released
// before authenticating. A slow or retrying login therefore does not block
// Update's write lock, nor other readers queued behind it.
func (c *Component) getAuthToken(ctx context.Context, cli *vault.Client) (*vault.Secret, error) {
	c.mut.RLock()
	args := c.args
	c.mut.RUnlock()

	return args.authMethod().vaultAuthenticate(ctx, cli)
}
// getSecret is the Getter of the secret token manager (see Update). It reads
// the configured secret through cli, publishes its fields as the component's
// exports, and returns the raw secret.
//
// The arguments are copied under the read lock, and the lock is released
// before the network read, so a slow read does not block Update.
// A nil secret with a nil error is reported as an error rather than
// dereferenced.
func (c *Component) getSecret(ctx context.Context, cli *vault.Client) (*vault.Secret, error) {
	c.mut.RLock()
	args := c.args
	c.mut.RUnlock()

	secret, err := args.secretStore(cli).Read(ctx, &args)
	if err != nil {
		return nil, err
	}
	if secret == nil {
		return nil, fmt.Errorf("no secret found at path %q", args.Path)
	}

	c.exportSecret(secret)
	return secret, nil
}
// exportSecret converts the fields of secret.Data into River secrets and
// publishes them through OnStateChange. Fields holding a string or []byte are
// exported. Any other field is skipped with a warning that names the key and
// its Go type; the value itself is not logged.
func (c *Component) exportSecret(secret *vault.Secret) {
	data := make(map[string]rivertypes.Secret)

	for key, raw := range secret.Data {
		var converted rivertypes.Secret
		switch v := raw.(type) {
		case string:
			converted = rivertypes.Secret(v)
		case []byte:
			converted = rivertypes.Secret(v)
		default:
			level.Warn(c.log).Log(
				"msg", "found field in secret which cannot be converted into a string",
				"key", key,
				"type", fmt.Sprintf("%T", v),
			)
			continue
		}
		data[key] = converted
	}

	c.opts.OnStateChange(Exports{Data: data})
}
// CurrentHealth reports the least healthy of the auth token manager's and
// the secret manager's health.
func (c *Component) CurrentHealth() component.Health {
	authHealth := c.authManager.CurrentHealth()
	secretHealth := c.secretManager.CurrentHealth()
	return component.LeastHealthy(authHealth, secretHealth)
}
// DebugInfo returns a debugInfo value holding the debug state of the auth
// token manager and the secret manager.
func (c *Component) DebugInfo() interface{} {
	var info debugInfo
	info.AuthToken = c.authManager.DebugInfo()
	info.Secret = c.secretManager.DebugInfo()
	return info
}
// debugInfo is the value returned by DebugInfo.
type debugInfo struct {
	// AuthToken is the auth token manager's debug state.
	AuthToken secretInfo `river:"auth_token,block"`
	// Secret is the secret manager's debug state.
	Secret secretInfo `river:"secret,block"`
}