Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
projectdiscovery
GitHub Repository: projectdiscovery/nuclei
Path: blob/dev/pkg/js/utils/pgwrap/pgwrap.go
2070 views
1
package pgwrap
2
3
import (
4
"context"
5
"database/sql"
6
"database/sql/driver"
7
"fmt"
8
"net"
9
"net/url"
10
"time"
11
12
"github.com/lib/pq"
13
"github.com/projectdiscovery/fastdialer/fastdialer"
14
"github.com/projectdiscovery/nuclei/v3/pkg/protocols/common/protocolstate"
15
)
16
17
// PGWrapDriver is the driver name under which PgDriver is registered with
// database/sql; use it as the driverName argument to sql.Open.
const PGWrapDriver = "pgwrap"
20
21
// pgDial is the dialer handed to lib/pq (via pq.DialOpen). Every connection
// it opens goes through the fastdialer belonging to one specific execution.
type pgDial struct {
	// executionId is the key passed to protocolstate.GetDialersWithId to
	// look up this execution's dialers.
	executionId string
}
24
25
// Dial connects to address over network using this execution's fastdialer.
// It imposes no deadline of its own: it is exactly DialContext with
// context.TODO(), including the "dialers not initialized" error when no
// dialers are registered for the execution.
func (p *pgDial) Dial(network, address string) (net.Conn, error) {
	return p.DialContext(context.TODO(), network, address)
}
32
33
// DialTimeout connects to address over network using this execution's
// fastdialer, giving up once timeout elapses. If the deadline is hit, the
// context's cause is fastdialer.ErrDialTimeout.
//
// Timeout semantics mirror net.DialTimeout: a zero timeout means "no
// timeout" (previously it produced an already-expired context and every dial
// failed immediately), while a negative timeout still fails right away.
//
// It returns an error if no dialers are registered for the execution.
func (p *pgDial) DialTimeout(network, address string, timeout time.Duration) (net.Conn, error) {
	dialers := protocolstate.GetDialersWithId(p.executionId)
	if dialers == nil {
		return nil, fmt.Errorf("dialers not initialized for %s", p.executionId)
	}

	ctx := context.Background()
	if timeout != 0 {
		var cancel context.CancelFunc
		ctx, cancel = context.WithTimeoutCause(ctx, timeout, fastdialer.ErrDialTimeout)
		defer cancel()
	}
	return dialers.Fastdialer.Dial(ctx, network, address)
}
42
43
// DialContext connects to address over network using the fastdialer
// registered for this execution, forwarding ctx to the fastdialer.
// It fails if protocolstate has no dialers for p.executionId.
func (p *pgDial) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
	if d := protocolstate.GetDialersWithId(p.executionId); d != nil {
		return d.Fastdialer.Dial(ctx, network, address)
	}
	return nil, fmt.Errorf("dialers not initialized for %s", p.executionId)
}
50
51
// lib/pq offers no easy way to customize or replace its dialer, so we hijack
// it by wrapping pq in our own driver and registering that driver with
// database/sql under the PGWrapDriver ("pgwrap") name.
54
55
// PgDriver is the Postgres database driver. It wraps lib/pq so that
// connections are dialed through the per-execution fastdialer instead of
// pq's own dialer. It carries no state; the zero value is ready to use.
type PgDriver struct{}
57
58
// Open opens a new connection to the database. name is a connection string.
59
// Most users should only use it through database/sql package from the standard
60
// library.
61
func (d PgDriver) Open(name string) (driver.Conn, error) {
62
// Parse the connection string to get executionId
63
u, err := url.Parse(name)
64
if err != nil {
65
return nil, fmt.Errorf("invalid connection string: %v", err)
66
}
67
values := u.Query()
68
executionId := values.Get("executionId")
69
// Remove executionId from the connection string
70
values.Del("executionId")
71
u.RawQuery = values.Encode()
72
73
return pq.DialOpen(&pgDial{executionId: executionId}, u.String())
74
}
75
76
func init() {
77
sql.Register(PGWrapDriver, &PgDriver{})
78
}
79
80