Path: blob/dev/pkg/protocols/http/httputils/spm.go
2072 views
package httputils

import (
	"context"
	"sync"

	syncutil "github.com/projectdiscovery/utils/sync"
	"golang.org/x/exp/maps"
)

// WorkPoolType is the type of work pool to use
type WorkPoolType uint

const (
	// Blocking blocks addition of new work when the pool is full
	Blocking WorkPoolType = iota
	// NonBlocking does not block addition of new work when the pool is full
	NonBlocking
)

// StopAtFirstMatchHandler is a handler that executes
// request and stops on first match.
//
// Workers bracket their work with Acquire/Release and send results on
// ResultChan. A background collector goroutine (started by the constructors)
// consumes ResultChan, invokes the optional onResult callback and stores up to
// maxResults distinct values. Wait drains the work pool, closes ResultChan and
// waits for the collector to exit.
type StopAtFirstMatchHandler[T comparable] struct {
	// once guards MatchCallback so fn runs at most once when spm is enabled.
	once sync.Once

	// ResultChan receives results from workers. It is consumed by the
	// collector goroutine and closed by Wait.
	ResultChan chan T

	// work pool and its type: sgPool is set for Blocking, wgPool for NonBlocking.
	poolType WorkPoolType
	sgPool   *syncutil.AdaptiveWaitGroup
	wgPool   *sync.WaitGroup

	// internal / unexported
	ctx         context.Context
	cancel      context.CancelFunc
	internalWg  *sync.WaitGroup
	results     map[T]struct{}
	onResult    func(T)
	stopEnabled bool
	maxResults  int

	// closeOnce makes closing ResultChan in Wait idempotent, so calling Wait
	// more than once does not panic with "close of closed channel".
	closeOnce sync.Once
}

// NewBlockingSPMHandler creates a StopAtFirstMatchHandler backed by a bounded
// (Blocking) work pool.
//
//   - ctx is the parent context; the handler's own context is derived from it
//     and is cancelled by Trigger/Cancel.
//   - size is the work pool size passed to syncutil.WithSize.
//   - maxResults caps how many distinct results are stored for CombinedResults.
//   - spm enables stop-at-first-match behaviour (see Trigger and MatchCallback).
//
// A collector goroutine is started; call Wait to shut it down.
func NewBlockingSPMHandler[T comparable](ctx context.Context, size int, maxResults int, spm bool) *StopAtFirstMatchHandler[T] {
	ctx1, cancel := context.WithCancel(ctx)

	// NOTE(review): the error returned by syncutil.New is discarded; if it
	// fails (e.g. for an invalid size) the returned pool may be unusable and
	// Acquire/Release/Wait would misbehave. Confirm valid sizes with callers.
	awg, _ := syncutil.New(syncutil.WithSize(size))

	s := &StopAtFirstMatchHandler[T]{
		ResultChan:  make(chan T, 1),
		poolType:    Blocking,
		sgPool:      awg,
		internalWg:  &sync.WaitGroup{},
		ctx:         ctx1,
		cancel:      cancel,
		stopEnabled: spm,
		results:     make(map[T]struct{}),
		maxResults:  maxResults,
	}
	s.internalWg.Add(1)
	// The collector watches the parent ctx rather than ctx1, so it keeps
	// draining ResultChan after Trigger/Cancel. It exits when Wait closes
	// ResultChan or when the parent ctx is done.
	go s.run(ctx)
	return s
}

// NewNonBlockingSPMHandler creates a StopAtFirstMatchHandler whose work pool
// is a plain sync.WaitGroup, so Acquire never blocks. Parameters have the
// same meaning as in NewBlockingSPMHandler. Resize and Size are no-ops for
// handlers created here.
func NewNonBlockingSPMHandler[T comparable](ctx context.Context, maxResults int, spm bool) *StopAtFirstMatchHandler[T] {
	ctx1, cancel := context.WithCancel(ctx)
	s := &StopAtFirstMatchHandler[T]{
		ResultChan:  make(chan T, 1),
		poolType:    NonBlocking,
		wgPool:      &sync.WaitGroup{},
		internalWg:  &sync.WaitGroup{},
		ctx:         ctx1,
		cancel:      cancel,
		stopEnabled: spm,
		results:     make(map[T]struct{}),
		maxResults:  maxResults,
	}
	s.internalWg.Add(1)
	// See NewBlockingSPMHandler: the collector is bound to the parent ctx.
	go s.run(ctx)
	return s
}

// Trigger cancels the handler's context when stop-at-first-match is enabled,
// which closes Done and makes Cancelled/FoundFirstMatch report true. When spm
// is disabled it does nothing.
func (h *StopAtFirstMatchHandler[T]) Trigger() {
	if h.stopEnabled {
		h.cancel()
	}
}

// Cancel unconditionally cancels the handler's context, regardless of
// whether stop-at-first-match is enabled.
func (h *StopAtFirstMatchHandler[T]) Cancel() {
	h.cancel()
}

// SetOnResultCallback registers fn to be invoked by the collector goroutine
// for every value received on ResultChan. If a callback is already set, fn is
// chained to run after it.
//
// This is not thread safe: the collector reads the callback concurrently, so
// register callbacks before any results are sent.
func (h *StopAtFirstMatchHandler[T]) SetOnResultCallback(fn func(T)) {
	if h.onResult != nil {
		tmp := h.onResult
		h.onResult = func(t T) {
			tmp(t)
			fn(t)
		}
	} else {
		h.onResult = fn
	}
}

// MatchCallback is called when a match is found
// input fn should be the callback that is intended to be called
// if stop at first is enabled and other conditions are met
// if it does not meet above conditions, use of this function is discouraged.
//
// When spm is disabled fn runs on every call; when enabled fn runs at most
// once per handler.
func (h *StopAtFirstMatchHandler[T]) MatchCallback(fn func()) {
	if !h.stopEnabled {
		fn()
		return
	}
	h.once.Do(fn)
}

// run is the collector loop. For every value received from ResultChan it
// invokes onResult (if set) and stores the value in results while fewer than
// maxResults distinct values are stored. It returns when ResultChan is closed
// or ctx is done.
func (h *StopAtFirstMatchHandler[T]) run(ctx context.Context) {
	defer h.internalWg.Done()

	for {
		select {
		case <-ctx.Done():
			return
		case val, ok := <-h.ResultChan:
			if !ok {
				return
			}
			// The callback sees every value, including ones not stored below.
			if h.onResult != nil {
				h.onResult(val)
			}
			if len(h.results) >= h.maxResults {
				// skip or do not store the result
				continue
			}
			h.results[val] = struct{}{}
		}
	}
}

// Done returns a channel that is closed when the handler's context is
// cancelled: by Trigger (spm enabled), by Cancel, or by the parent context.
func (h *StopAtFirstMatchHandler[T]) Done() <-chan struct{} {
	return h.ctx.Done()
}

// Cancelled returns true if the handler's context is cancelled.
func (h *StopAtFirstMatchHandler[T]) Cancelled() bool {
	return h.ctx.Err() != nil
}

// FoundFirstMatch returns true if stop-at-first-match is enabled and the
// handler's context has been cancelled. Note that cancellation via Cancel or
// the parent context also counts, not only Trigger.
func (h *StopAtFirstMatchHandler[T]) FoundFirstMatch() bool {
	return h.stopEnabled && h.ctx.Err() != nil
}

// Acquire reserves a slot in the work pool before starting a unit of work.
// For a Blocking pool this may wait until a slot is available; for a
// NonBlocking pool it only increments the wait group.
func (h *StopAtFirstMatchHandler[T]) Acquire() {
	switch h.poolType {
	case Blocking:
		h.sgPool.Add()
	case NonBlocking:
		h.wgPool.Add(1)
	}
}

// Release releases a slot previously obtained with Acquire.
func (h *StopAtFirstMatchHandler[T]) Release() {
	switch h.poolType {
	case Blocking:
		h.sgPool.Done()
	case NonBlocking:
		h.wgPool.Done()
	}
}

// Resize changes the size of the blocking work pool to size when it differs
// from the current size. It is a no-op returning nil for non-blocking
// handlers, which have no bounded pool (previously this panicked on a nil
// pool).
func (h *StopAtFirstMatchHandler[T]) Resize(ctx context.Context, size int) error {
	if h.sgPool == nil || h.sgPool.Size == size {
		return nil
	}
	return h.sgPool.Resize(ctx, size)
}

// Size returns the current size of the blocking work pool, or 0 for
// non-blocking handlers, which have no bounded pool.
func (h *StopAtFirstMatchHandler[T]) Size() int {
	if h.sgPool == nil {
		return 0
	}
	return h.sgPool.Size
}

// Wait blocks until all acquired work has been released, then closes
// ResultChan and waits for the collector goroutine to exit. Once Wait
// returns, CombinedResults can be read without racing the collector.
// Calling Wait more than once is safe.
func (h *StopAtFirstMatchHandler[T]) Wait() {
	switch h.poolType {
	case Blocking:
		h.sgPool.Wait()
	case NonBlocking:
		h.wgPool.Wait()
	}
	// Assuming senders hold an Acquire'd slot while sending, no more sends
	// can happen once the pool drains, so ResultChan can be closed to stop
	// the collector. closeOnce prevents a double-close panic on repeat calls.
	h.closeOnce.Do(func() { close(h.ResultChan) })
	h.internalWg.Wait()
}

// CombinedResults returns the distinct stored results (at most maxResults of
// them) in unspecified order. The results map is written by the collector
// goroutine, so call this only after Wait has returned.
func (h *StopAtFirstMatchHandler[T]) CombinedResults() []T {
	return maps.Keys(h.results)
}