package pgwrap
import (
"context"
"database/sql"
"database/sql/driver"
"fmt"
"net"
"net/url"
"time"
"github.com/lib/pq"
"github.com/projectdiscovery/fastdialer/fastdialer"
"github.com/projectdiscovery/nuclei/v3/pkg/protocols/common/protocolstate"
)
const (
PGWrapDriver = "pgwrap"
)
type pgDial struct {
executionId string
}
func (p *pgDial) Dial(network, address string) (net.Conn, error) {
dialers := protocolstate.GetDialersWithId(p.executionId)
if dialers == nil {
return nil, fmt.Errorf("dialers not initialized for %s", p.executionId)
}
return dialers.Fastdialer.Dial(context.TODO(), network, address)
}
func (p *pgDial) DialTimeout(network, address string, timeout time.Duration) (net.Conn, error) {
dialers := protocolstate.GetDialersWithId(p.executionId)
if dialers == nil {
return nil, fmt.Errorf("dialers not initialized for %s", p.executionId)
}
ctx, cancel := context.WithTimeoutCause(context.Background(), timeout, fastdialer.ErrDialTimeout)
defer cancel()
return dialers.Fastdialer.Dial(ctx, network, address)
}
func (p *pgDial) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
dialers := protocolstate.GetDialersWithId(p.executionId)
if dialers == nil {
return nil, fmt.Errorf("dialers not initialized for %s", p.executionId)
}
return dialers.Fastdialer.Dial(ctx, network, address)
}
type PgDriver struct{}
func (d PgDriver) Open(name string) (driver.Conn, error) {
u, err := url.Parse(name)
if err != nil {
return nil, fmt.Errorf("invalid connection string: %v", err)
}
values := u.Query()
executionId := values.Get("executionId")
values.Del("executionId")
u.RawQuery = values.Encode()
return pq.DialOpen(&pgDial{executionId: executionId}, u.String())
}
func init() {
sql.Register(PGWrapDriver, &PgDriver{})
}