package provider
import (
"errors"
"fmt"
"strings"
"github.com/projectdiscovery/gologger"
"github.com/projectdiscovery/nuclei/v3/pkg/input/formats"
"github.com/projectdiscovery/nuclei/v3/pkg/input/formats/openapi"
"github.com/projectdiscovery/nuclei/v3/pkg/input/formats/swagger"
"github.com/projectdiscovery/nuclei/v3/pkg/input/provider/http"
"github.com/projectdiscovery/nuclei/v3/pkg/input/provider/list"
"github.com/projectdiscovery/nuclei/v3/pkg/input/types"
"github.com/projectdiscovery/nuclei/v3/pkg/protocols/common/contextargs"
"github.com/projectdiscovery/nuclei/v3/pkg/protocols/common/generators"
"github.com/projectdiscovery/nuclei/v3/pkg/protocols/common/protocolstate"
configTypes "github.com/projectdiscovery/nuclei/v3/pkg/types"
"github.com/projectdiscovery/retryablehttp-go"
"github.com/projectdiscovery/utils/errkit"
stringsutil "github.com/projectdiscovery/utils/strings"
)
var (
	// ErrNotImplemented is returned by providers that do not support a given
	// method. IsErrNotImplemented detects it by matching on the error text,
	// so it also recognizes messages that embed this one.
	ErrNotImplemented = errkit.New("provider does not implement method")
	// ErrInactiveInput reports that an input is inactive. NOTE(review):
	// presumably produced when a liveness probe rejects an input — confirm
	// against the SetWithProbe implementations.
	ErrInactiveInput = errors.New("input is inactive")
)
// Identifiers naming the kind of input provider in use.
const (
	// MultiFormatInputProvider names the HTTP multi-format provider.
	MultiFormatInputProvider = "MultiFormatInputProvider"
	// ListInputProvider names the list-based provider.
	ListInputProvider = "ListInputProvider"
	// SimpleListInputProvider names the simple provider. Its value is
	// "SimpleInputProvider" (not "SimpleListInputProvider"); keep it as-is
	// because other code may compare against the string.
	SimpleListInputProvider = "SimpleInputProvider"
)
// IsErrNotImplemented reports whether err's message contains both "provider"
// and "does not implement". It matches on the message text rather than error
// identity, so any error whose text contains ErrNotImplemented's message
// (for example one wrapped via fmt.Errorf) is recognized. A nil err yields
// false.
func IsErrNotImplemented(err error) bool {
	return err != nil && stringsutil.ContainsAll(err.Error(), "provider", "does not implement")
}
// Compile-time checks that each concrete provider satisfies InputProvider.
var (
	_ InputProvider = (*SimpleInputProvider)(nil)
	_ InputProvider = (*http.HttpInputProvider)(nil)
	_ InputProvider = (*list.ListInputProvider)(nil)
)
// InputProvider is the common interface for input sources such as
// list.ListInputProvider and http.HttpInputProvider.
type InputProvider interface {
	// Count returns the number of inputs. NOTE(review): presumably the
	// number Iterate visits — confirm against implementations.
	Count() int64
	// Iterate calls callback for each input. NOTE(review): presumably
	// returning false from callback stops iteration — confirm.
	Iterate(callback func(value *contextargs.MetaInput) bool)
	// Set adds value as an input for the given execution.
	Set(executionId, value string)
	// SetWithProbe is like Set but is given a liveness probe and can fail.
	// NOTE(review): presumably returns ErrInactiveInput for dead inputs.
	SetWithProbe(executionId, value string, probe types.InputLivenessProbe) error
	// SetWithExclusions is like Set but can fail. NOTE(review): exclusion
	// rules are defined by the implementation — confirm.
	SetWithExclusions(executionId, value string) error
	// InputType returns a string identifying the provider kind.
	// NOTE(review): presumably one of the *InputProvider constants.
	InputType() string
	// Close releases whatever the provider holds.
	Close()
}
// InputOptions carries the configuration consumed by NewInputProvider.
type InputOptions struct {
	// Options supplies the input mode, targets and variables. NewInputProvider
	// may overwrite Options.TargetsFilePath with a downloaded spec's path.
	Options *configTypes.Options
	// TempDir is handed to the openapi/swagger spec downloader when the single
	// target is an http(s) URL.
	TempDir string
	// NotFoundCallback is forwarded unchanged to the list input provider.
	NotFoundCallback func(template string) bool
}
// NewInputProvider builds the InputProvider selected by
// opts.Options.InputFileMode.
//
// "list" mode (case-insensitive) returns the list-based provider. Any other
// mode returns the HTTP multi-format provider reading
// opts.Options.TargetsFilePath. In "openapi"/"swagger" mode with exactly one
// http(s) target, the spec is downloaded first and
// opts.Options.TargetsFilePath is overwritten with the downloaded file's path.
// Because Options is a pointer, that change is visible to the caller.
//
// Variables from the OpenAPI vars dump file are merged with
// opts.Options.Vars and passed to the HTTP provider.
//
// An error is returned when opts.Options is nil, when more than one target
// is given in openapi/swagger mode, when the spec download fails, or when
// the underlying provider constructor fails.
func NewInputProvider(opts InputOptions) (InputProvider, error) {
	if opts.Options == nil {
		return nil, errors.New("input provider options must not be nil")
	}

	if strings.EqualFold(opts.Options.InputFileMode, "list") {
		provider, err := list.New(&list.Options{
			Options:          opts.Options,
			NotFoundCallback: opts.NotFoundCallback,
		})
		if err != nil {
			// Return a literal nil so callers never receive a non-nil
			// interface that wraps a nil pointer.
			return nil, err
		}
		return provider, nil
	}

	// The dump-file variables are consumed only by the HTTP provider, so the
	// file is not read (and no read errors are logged) in list mode.
	extraVars := readVarsDumpFileVariables()

	if err := downloadRemoteSpecIfNeeded(opts); err != nil {
		return nil, err
	}

	provider, err := http.NewHttpInputProvider(&http.HttpMultiFormatOptions{
		InputFile: opts.Options.TargetsFilePath,
		InputMode: opts.Options.InputFileMode,
		Options: formats.InputFormatOptions{
			Variables:            generators.MergeMaps(extraVars, opts.Options.Vars.AsMap()),
			SkipFormatValidation: opts.Options.SkipFormatValidation,
			RequiredOnly:         opts.Options.FormatUseRequiredOnly,
			VarsTextTemplating:   opts.Options.VarsTextTemplating,
			VarsFilePaths:        opts.Options.VarsFilePaths,
		},
	})
	if err != nil {
		return nil, err
	}
	return provider, nil
}

// readVarsDumpFileVariables parses the "key=value" entries of the OpenAPI
// vars dump file into a map.
//
// Each entry has its surrounding whitespace trimmed and is split on the
// first "="; entries without "=" are skipped. A missing dump file
// (formats.ErrNoVarsDumpFile) is silent. Other read errors are logged, not
// returned, and any entries the reader did return are still used.
func readVarsDumpFileVariables() map[string]interface{} {
	vars := make(map[string]interface{})
	dump, err := formats.ReadOpenAPIVarDumpFile()
	if err != nil && !errors.Is(err, formats.ErrNoVarsDumpFile) {
		gologger.Error().Msgf("Could not read vars dump file: %s\n", err)
	}
	if dump == nil {
		return vars
	}
	for _, entry := range dump.Var {
		parts := strings.SplitN(strings.TrimSpace(entry), "=", 2)
		if len(parts) == 2 {
			vars[parts[0]] = parts[1]
		}
	}
	return vars
}

// downloadRemoteSpecIfNeeded handles openapi/swagger mode when the single
// target is an http(s) URL.
//
// In that case it downloads the spec into opts.TempDir and stores the
// resulting path in opts.Options.TargetsFilePath. In every other situation
// it does nothing. More than one target in openapi/swagger mode is an error.
func downloadRemoteSpecIfNeeded(opts InputOptions) error {
	mode := opts.Options.InputFileMode
	targets := opts.Options.Targets
	isSpecMode := strings.EqualFold(mode, "openapi") || strings.EqualFold(mode, "swagger")
	if !isSpecMode || len(targets) == 0 {
		return nil
	}
	if len(targets) > 1 {
		return fmt.Errorf("only one target URL is supported in %s input mode", mode)
	}

	target := targets[0]
	if !strings.HasPrefix(target, "http://") && !strings.HasPrefix(target, "https://") {
		return nil
	}

	// Reuse the execution-scoped HTTP client when one is registered. A nil
	// client is passed through to the downloader otherwise.
	var httpClient *retryablehttp.Client
	if opts.Options.ExecutionId != "" {
		if dialers := protocolstate.GetDialersWithId(opts.Options.ExecutionId); dialers != nil {
			httpClient = dialers.DefaultHTTPClient
		}
	}

	var downloader formats.SpecDownloader
	switch strings.ToLower(mode) {
	case "openapi":
		downloader = openapi.NewDownloader()
	case "swagger":
		downloader = swagger.NewDownloader()
	default:
		// Reachable only for exotic Unicode spellings: EqualFold can accept
		// inputs that ToLower does not map to these exact literals.
		return fmt.Errorf("unsupported input mode: %s", mode)
	}

	tempFile, err := downloader.Download(target, opts.TempDir, httpClient)
	if err != nil {
		return fmt.Errorf("failed to download %s spec from url %s: %w", mode, target, err)
	}
	opts.Options.TargetsFilePath = tempFile
	return nil
}
// SupportedInputFormats returns "list, " followed by the formats reported by
// http.SupportedFormats(), i.e. every InputFileMode NewInputProvider accepts.
func SupportedInputFormats() string {
	names := []string{"list", http.SupportedFormats()}
	return strings.Join(names, ", ")
}