package smb
import (
"context"
"fmt"
"time"
"github.com/praetorian-inc/fingerprintx/pkg/plugins"
"github.com/projectdiscovery/go-smb2"
"github.com/projectdiscovery/nuclei/v3/pkg/protocols/common/protocolstate"
"github.com/zmap/zgrab2/lib/smb/smb"
)
// SMBClient is the receiver for this package's SMB helper methods. It holds
// no state of its own; every method takes what it needs from its arguments
// and from the context it is given.
type SMBClient struct{}
// ConnectSMBInfoMode probes the SMB service at host:port and returns the
// collected SMB log.
//
// ctx must carry the current execution ID as a string under the
// "executionId" key. If it is missing or not a string, an error is returned
// rather than panicking. The work is delegated to memoizedconnectSMBInfoMode
// (presumably a cached wrapper around connectSMBInfoMode — confirm).
func (c *SMBClient) ConnectSMBInfoMode(ctx context.Context, host string, port int) (*smb.SMBLog, error) {
	executionId, ok := ctx.Value("executionId").(string)
	if !ok {
		// A bare type assertion would panic here; report a normal error instead.
		return nil, fmt.Errorf("executionId not found in context")
	}
	return memoizedconnectSMBInfoMode(executionId, host, port)
}
// connectSMBInfoMode probes host:port over TCP and returns the SMB log
// produced by getSMBInfo.
//
// The host must be allowed by the protocol state for executionId, and
// dialers must have been initialized for that execution. The first probe
// calls getSMBInfo(conn, true, false). If that fails, a fresh connection is
// dialed and the probe is retried with the last argument set to true
// (presumably selecting an alternate SMB mode/dialect — confirm against
// getSMBInfo). The error from the first probe is intentionally discarded in
// favour of the retry.
//
// Dial errors are returned unchanged. If the retry also fails, the result
// is nil and the error is returned wrapped with the target address.
func connectSMBInfoMode(executionId string, host string, port int) (*smb.SMBLog, error) {
	if !protocolstate.IsHostAllowed(executionId, host) {
		return nil, protocolstate.ErrHostDenied.Msgf(host)
	}
	dialer := protocolstate.GetDialersWithId(executionId)
	if dialer == nil {
		return nil, fmt.Errorf("dialers not initialized for %s", executionId)
	}
	addr := fmt.Sprintf("%s:%d", host, port)

	conn, err := dialer.Fastdialer.Dial(context.TODO(), "tcp", addr)
	if err != nil {
		return nil, err
	}
	result, err := getSMBInfo(conn, true, false)
	// Close immediately: the retry (if any) uses a brand-new connection.
	_ = conn.Close()
	if err == nil {
		return result, nil
	}

	// First probe failed; retry once on a fresh connection in the alternate mode.
	conn, err = dialer.Fastdialer.Dial(context.TODO(), "tcp", addr)
	if err != nil {
		return nil, err
	}
	defer func() {
		_ = conn.Close()
	}()
	result, err = getSMBInfo(conn, true, true)
	if err != nil {
		// Surface the failure instead of reporting success with a nil or
		// partial result.
		return nil, fmt.Errorf("getting smb info from %s: %w", addr, err)
	}
	return result, nil
}
// ListSMBv2Metadata returns SMBv2 service metadata for host:port.
//
// ctx must carry the current execution ID as a string under the
// "executionId" key. If it is missing or not a string, an error is returned
// rather than panicking. The host is checked against the execution's
// allow-list before delegating to memoizedcollectSMBv2Metadata with a
// 5-second duration (presumably the probe timeout — confirm).
func (c *SMBClient) ListSMBv2Metadata(ctx context.Context, host string, port int) (*plugins.ServiceSMB, error) {
	executionId, ok := ctx.Value("executionId").(string)
	if !ok {
		// A bare type assertion would panic here; report a normal error instead.
		return nil, fmt.Errorf("executionId not found in context")
	}
	if !protocolstate.IsHostAllowed(executionId, host) {
		return nil, protocolstate.ErrHostDenied.Msgf(host)
	}
	return memoizedcollectSMBv2Metadata(executionId, host, port, 5*time.Second)
}
// ListShares returns the share names exposed by the SMB server at host:port,
// authenticating with the given user and password.
//
// ctx must carry the current execution ID as a string under the
// "executionId" key. If it is missing or not a string, an error is returned
// rather than panicking. The work is delegated to memoizedlistShares
// (presumably a cached wrapper around listShares — confirm).
func (c *SMBClient) ListShares(ctx context.Context, host string, port int, user, password string) ([]string, error) {
	executionId, ok := ctx.Value("executionId").(string)
	if !ok {
		// A bare type assertion would panic here; report a normal error instead.
		return nil, fmt.Errorf("executionId not found in context")
	}
	return memoizedlistShares(executionId, host, port, user, password)
}
// listShares dials host:port, opens an SMB2 session using NTLM credentials,
// and returns the share names the server reports.
//
// The host must be allowed for executionId, and dialers must be initialized
// for that execution. The session is logged off and the TCP connection
// closed before returning. Cleanup errors are ignored on purpose, because
// the result has already been obtained or the call has already failed.
func listShares(executionId string, host string, port int, user string, password string) ([]string, error) {
	if allowed := protocolstate.IsHostAllowed(executionId, host); !allowed {
		return nil, protocolstate.ErrHostDenied.Msgf(host)
	}

	dialers := protocolstate.GetDialersWithId(executionId)
	if dialers == nil {
		return nil, fmt.Errorf("dialers not initialized for %s", executionId)
	}

	address := fmt.Sprintf("%s:%d", host, port)
	netConn, err := dialers.Fastdialer.Dial(context.TODO(), "tcp", address)
	if err != nil {
		return nil, err
	}
	// Registered first, so it runs last: after the session has logged off.
	defer func() { _ = netConn.Close() }()

	ntlm := &smb2.NTLMInitiator{User: user, Password: password}
	sessionDialer := &smb2.Dialer{Initiator: ntlm}
	session, err := sessionDialer.Dial(netConn)
	if err != nil {
		return nil, err
	}
	defer func() { _ = session.Logoff() }()

	shares, err := session.ListSharenames()
	if err != nil {
		return nil, err
	}
	return shares, nil
}