package s3
import (
"crypto/tls"
"fmt"
"net/http"
"strings"
"sync"
"time"
"context"
"github.com/aws/aws-sdk-go-v2/aws"
aws_config "github.com/aws/aws-sdk-go-v2/config"
"github.com/aws/aws-sdk-go-v2/service/s3"
"github.com/grafana/agent/component"
"github.com/grafana/agent/pkg/river/rivertypes"
"github.com/prometheus/client_golang/prometheus"
)
func init() {
component.Register(component.Registration{
Name: "remote.s3",
Args: Arguments{},
Exports: Exports{},
Build: func(opts component.Options, args component.Arguments) (component.Component, error) {
return New(opts, args.(Arguments))
},
})
}
type S3 struct {
mut sync.Mutex
opts component.Options
args Arguments
health component.Health
content string
watcher *watcher
updateChan chan result
s3Errors prometheus.Counter
lastAccessed prometheus.Gauge
}
var (
_ component.Component = (*S3)(nil)
_ component.HealthComponent = (*S3)(nil)
)
func New(o component.Options, args Arguments) (*S3, error) {
s3cfg, err := generateS3Config(args)
if err != nil {
return nil, err
}
s3Client := s3.NewFromConfig(*s3cfg, func(s3o *s3.Options) {
s3o.UsePathStyle = args.Options.UsePathStyle
})
bucket, file := getPathBucketAndFile(args.Path)
s := &S3{
opts: o,
args: args,
health: component.Health{},
updateChan: make(chan result),
s3Errors: prometheus.NewCounter(prometheus.CounterOpts{
Name: "agent_remote_s3_errors_total",
Help: "The number of errors while accessing s3",
}),
lastAccessed: prometheus.NewGauge(prometheus.GaugeOpts{
Name: "agent_remote_s3_timestamp_last_accessed_unix_seconds",
Help: "The last successful access in unix seconds",
}),
}
w := newWatcher(bucket, file, s.updateChan, args.PollFrequency, s3Client)
s.watcher = w
err = o.Registerer.Register(s.s3Errors)
if err != nil {
return nil, err
}
err = o.Registerer.Register(s.lastAccessed)
if err != nil {
return nil, err
}
content, err := w.downloadSynchronously()
s.handleContentPolling(content, err)
return s, nil
}
func (s *S3) Run(ctx context.Context) error {
go s.handleContentUpdate(ctx)
go s.watcher.run(ctx)
<-ctx.Done()
return nil
}
func (s *S3) Update(args component.Arguments) error {
newArgs := args.(Arguments)
s3cfg, err := generateS3Config(newArgs)
if err != nil {
return nil
}
s3Client := s3.NewFromConfig(*s3cfg, func(s3o *s3.Options) {
s3o.UsePathStyle = newArgs.Options.UsePathStyle
})
bucket, file := getPathBucketAndFile(newArgs.Path)
s.mut.Lock()
defer s.mut.Unlock()
s.args = newArgs
s.watcher.updateValues(bucket, file, newArgs.PollFrequency, s3Client)
return nil
}
func (s *S3) CurrentHealth() component.Health {
s.mut.Lock()
defer s.mut.Unlock()
return s.health
}
func generateS3Config(args Arguments) (*aws.Config, error) {
configOptions := make([]func(*aws_config.LoadOptions) error, 0)
if args.Options.Endpoint != "" {
endFunc := aws.EndpointResolverWithOptionsFunc(func(service, region string, _ ...interface{}) (aws.Endpoint, error) {
return aws.Endpoint{URL: args.Options.Endpoint}, nil
})
endResolver := aws_config.WithEndpointResolverWithOptions(endFunc)
configOptions = append(configOptions, endResolver)
}
if args.Options.DisableSSL {
httpOverride := aws_config.WithHTTPClient(
&http.Client{
Transport: &http.Transport{
TLSClientConfig: &tls.Config{
InsecureSkipVerify: args.Options.DisableSSL,
},
},
},
)
configOptions = append(configOptions, httpOverride)
}
if args.Options.AccessKey != "" {
if args.Options.Secret == "" {
return nil, fmt.Errorf("if accesskey or secret are specified then the other must also be specified")
}
credFunc := aws.CredentialsProviderFunc(func(ctx context.Context) (aws.Credentials, error) {
return aws.Credentials{
AccessKeyID: args.Options.AccessKey,
SecretAccessKey: string(args.Options.Secret),
}, nil
})
credProvider := aws_config.WithCredentialsProvider(credFunc)
configOptions = append(configOptions, credProvider)
}
cfg, err := aws_config.LoadDefaultConfig(context.TODO(), configOptions...)
if err != nil {
return nil, err
}
if args.Options.Region != "" {
cfg.Region = args.Options.Region
}
return &cfg, nil
}
func (s *S3) handleContentUpdate(ctx context.Context) {
for {
select {
case r := <-s.updateChan:
s.handleContentPolling(string(r.result), r.err)
case <-ctx.Done():
return
}
}
}
func (s *S3) handleContentPolling(newContent string, err error) {
s.mut.Lock()
defer s.mut.Unlock()
if err == nil {
s.opts.OnStateChange(Exports{
Content: rivertypes.OptionalSecret{
IsSecret: s.args.IsSecret,
Value: newContent,
},
})
s.lastAccessed.SetToCurrentTime()
s.content = newContent
s.health.Health = component.HealthTypeHealthy
s.health.Message = "s3 file updated"
} else {
s.s3Errors.Inc()
s.health.Health = component.HealthTypeUnhealthy
s.health.Message = err.Error()
}
s.health.UpdateTime = time.Now()
}
func getPathBucketAndFile(path string) (bucket, file string) {
parts := strings.Split(path, "/")
file = strings.Join(parts[3:], "/")
bucket = strings.Join(parts[:3], "/")
bucket = strings.ReplaceAll(bucket, "s3://", "")
return
}