package aws
import (
"io"
"reflect"
"slices"
"strings"
"testing"
"github.com/pkg/errors"
)
// TestCatalog_GetTemplatePath checks that GetTemplatePath, run against the
// mock S3 service, returns the expected keys for a directory target, a
// wildcard target, and a target that matches nothing.
func TestCatalog_GetTemplatePath(t *testing.T) {
	type args struct {
		target string
	}
	tests := []struct {
		name    string
		args    args
		want    []string
		wantErr bool
	}{
		{
			"get all ssl files",
			args{
				target: "ssl",
			},
			[]string{
				"ssl/deprecated-tls.yaml",
				"ssl/detect-ssl-issuer.yaml",
				"ssl/expired-ssl.yaml",
				"ssl/mismatched-ssl.yaml",
			},
			false,
		},
		{
			"get all ssl files with wildcard",
			args{
				target: "ssl*",
			},
			[]string{
				"ssl/deprecated-tls.yaml",
				"ssl/detect-ssl-issuer.yaml",
				"ssl/expired-ssl.yaml",
				"ssl/mismatched-ssl.yaml",
			},
			false,
		},
		{
			"non-matching target",
			args{
				target: "I-DONT-EXIST",
			},
			[]string{},
			false,
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			// Fail fast on a broken constructor instead of letting it surface
			// as a confusing failure in the call under test.
			c, err := NewCatalog("bucket", withMockS3Service())
			if err != nil {
				t.Fatalf("NewCatalog() error = %v", err)
			}

			got, err := c.GetTemplatePath(tt.args.target)
			if (err != nil) != tt.wantErr {
				t.Errorf("GetTemplatePath() error = %v, wantErr %v", err, tt.wantErr)
				return
			}

			// An empty expectation matches both a nil and an empty result, so
			// only the length is compared in that case.
			var mismatch bool
			if len(tt.want) == 0 {
				mismatch = len(got) > 0
			} else {
				mismatch = !reflect.DeepEqual(got, tt.want)
			}
			if mismatch {
				t.Errorf("GetTemplatePath() got = %v, want %v", got, tt.want)
			}
		})
	}
}
// TestCatalog_GetTemplatesPath checks that GetTemplatesPath returns every key
// known to the mock S3 service, with or without definitions, and that any
// reported error is stored under the single map key "aws".
func TestCatalog_GetTemplatesPath(t *testing.T) {
	// A separate mock instance supplies the expected key list.
	tmp := newMockS3Service()
	keys, err := tmp.getAllKeys()
	if err != nil {
		t.Fatalf("getAllKeys() error = %v", err)
	}

	type args struct {
		definitions []string
	}
	tests := []struct {
		name    string
		args    args
		want    []string
		wantErr bool
	}{
		{
			"without definitions",
			args{
				definitions: nil,
			},
			keys,
			false,
		},
		{
			"with definitions",
			args{
				definitions: []string{"ssl/deprecated-tls.yaml"},
			},
			keys,
			false,
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			c, err := NewCatalog("bucket", withMockS3Service())
			if err != nil {
				t.Fatalf("NewCatalog() error = %v", err)
			}

			got, errs := c.GetTemplatesPath(tt.args.definitions)

			// Indexing a nil map is safe, so the expectation is checked even
			// when no error map is returned; a missing expected error is then
			// reported instead of being silently accepted.
			awsErr, hasAWSErr := errs["aws"]
			if hasAWSErr != tt.wantErr {
				t.Errorf("GetTemplatesPath() error = %v, wantErr %v", awsErr, tt.wantErr)
			}
			if !hasAWSErr && len(errs) > 0 {
				t.Errorf("GetTemplatesPath() should only return one key 'aws': %v", errs)
			}

			if !reflect.DeepEqual(got, tt.want) {
				t.Errorf("GetTemplatesPath() got = %v, want %v", got, tt.want)
			}
		})
	}
}
// TestCatalog_OpenFile checks that OpenFile succeeds for a key known to the
// mock S3 service and fails for an unknown key and for a folder-like path.
func TestCatalog_OpenFile(t *testing.T) {
	tests := []struct {
		name     string
		filename string
		wantErr  bool
	}{
		{
			"valid key",
			"ssl/deprecated-tls.yaml",
			false,
		},
		{
			"nonexistent key",
			"something/that-doesnt-exist.yaml",
			true,
		},
		{
			"path to folder",
			"cves/2023",
			true,
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			c, err := NewCatalog("bucket", withMockS3Service())
			if err != nil {
				t.Fatalf("NewCatalog() error = %v", err)
			}

			got, err := c.OpenFile(tt.filename)
			if (err != nil) != tt.wantErr {
				t.Errorf("OpenFile() error = %v, wantErr %v", err, tt.wantErr)
				return
			}
			if err != nil {
				// Expected failure; there is nothing to inspect or close.
				return
			}
			if got == nil {
				t.Error("OpenFile() didn't return error but io.ReadCloser is nil")
				return
			}

			// Release the reader so the test does not leak it.
			if cerr := got.Close(); cerr != nil {
				t.Errorf("OpenFile() returned reader failed to close: %v", cerr)
			}
		})
	}
}
// TestCatalog_ResolvePath checks ResolvePath against the mock S3 service: a
// full key resolves to itself, a file name plus a directory in the second
// argument resolves to the joined key, and a path that is not a known key
// yields an error.
func TestCatalog_ResolvePath(t *testing.T) {
	type args struct {
		templateName string
		second       string
	}
	tests := []struct {
		name    string
		args    args
		want    string
		wantErr bool
	}{
		{
			"absolute path",
			args{
				"ssl/deprecated-tls.yaml",
				"",
			},
			"ssl/deprecated-tls.yaml",
			false,
		},
		{
			"relative path with second param",
			args{
				"deprecated-tls.yaml",
				"ssl/",
			},
			"ssl/deprecated-tls.yaml",
			false,
		},
		{
			"relative path and no second param",
			args{
				"cves/2023",
				"",
			},
			"",
			true,
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			// Fail fast on a broken constructor instead of continuing with an
			// unusable catalog.
			c, err := NewCatalog("bucket", withMockS3Service())
			if err != nil {
				t.Fatalf("NewCatalog() error = %v", err)
			}

			got, err := c.ResolvePath(tt.args.templateName, tt.args.second)
			if (err != nil) != tt.wantErr {
				t.Errorf("ResolvePath() error = %v, wantErr %v", err, tt.wantErr)
				return
			}
			if got != tt.want {
				t.Errorf("ResolvePath() got = %v, want %v", got, tt.want)
			}
		})
	}
}
// withMockS3Service returns a Catalog option that sets the catalog's svc
// field to a fresh mock S3 service. The option itself never fails.
func withMockS3Service() func(*Catalog) error {
	useMock := func(catalog *Catalog) error {
		catalog.svc = newMockS3Service()
		return nil
	}
	return useMock
}
// mocks3svc is an in-memory stand-in for the catalog's S3 service, used by
// the tests in this file via withMockS3Service.
type mocks3svc struct {
	// keys lists the object keys the mock reports as present.
	keys []string
}
// newMockS3Service builds a mock S3 service holding a fixed list of template
// keys: four under "ssl/" and three under "cves/2023/".
func newMockS3Service() mocks3svc {
	fixtureKeys := []string{
		"ssl/deprecated-tls.yaml",
		"ssl/detect-ssl-issuer.yaml",
		"ssl/expired-ssl.yaml",
		"ssl/mismatched-ssl.yaml",
		"cves/2023/CVE-2023-0669.yaml",
		"cves/2023/CVE-2023-23488.yaml",
		"cves/2023/CVE-2023-23489.yaml",
	}
	return mocks3svc{keys: fixtureKeys}
}
// getAllKeys returns the mock's stored key list. It never returns an error.
func (m mocks3svc) getAllKeys() ([]string, error) {
	allKeys := m.keys
	return allKeys, nil
}
// downloadKey returns a reader over a fixed sample template when name is one
// of the mock's keys, and a "key not found" error otherwise.
func (m mocks3svc) downloadKey(name string) (io.ReadCloser, error) {
	if !slices.Contains(m.keys, name) {
		return nil, errors.New("key not found")
	}

	// The same template body is served for every known key.
	const sample = `
id: git-config
info:
name: Git Config File
author: Ice3man
severity: medium
description: Searches for the pattern /.git/config on passed URLs.
requests:
- method: GET
path:
- "{{BaseURL}}/.git/config"
matchers:
- type: word
words:
- "[core]"
`
	body := strings.NewReader(sample)
	return io.NopCloser(body), nil
}
// setBucket is a no-op: the mock ignores the bucket name it is given.
func (m mocks3svc) setBucket(bucket string) {}