package telnetmini
import (
"bufio"
"context"
"errors"
"fmt"
"io"
"net"
"strings"
"time"
)
// Telnet protocol bytes used by this package, listed in numeric order.
const (
	// SE closes a subnegotiation that was opened with SB.
	SE = 240
	// SB opens a subnegotiation; the byte after it is the option code.
	SB = 250

	// WILL, WONT, DO and DONT are the option negotiation commands. Each one
	// is followed by a single option byte.
	WILL = 251
	WONT = 252
	DO   = 253
	DONT = 254

	// IAC introduces every Telnet command sequence in the byte stream.
	IAC = 255

	// ENCRYPT is the option code that DetectEncryption probes for.
	ENCRYPT = 38
)
// EncryptionInfo is the result returned by DetectEncryption.
type EncryptionInfo struct {
	// SupportsEncryption is true when the peer answered with WILL or DO for
	// the ENCRYPT option.
	SupportsEncryption bool

	// Banner is the text gathered from what the peer sent during the probe.
	Banner string

	// Options is keyed by option code; each value lists the command codes
	// that were received for that option.
	Options map[int][]int
}
// Client drives a Telnet login dialogue over an existing connection.
//
// Use New to build one: it wires the buffered reader and writer and applies
// Defaults. The prompt lists are matched case-insensitively as substrings of
// the text received so far.
type Client struct {
	// Conn is the underlying connection. Deadlines are applied to it and
	// Close closes it.
	Conn net.Conn

	// rd and wr buffer reads from and writes to Conn.
	rd *bufio.Reader
	wr *bufio.Writer

	// LoginPrompts is what Defaults copies into UserPrompts when UserPrompts
	// is empty.
	LoginPrompts []string
	// UserPrompts are the prompts Auth waits for before sending the username.
	UserPrompts []string
	// PasswordPrompts are the prompts Auth waits for before sending the
	// password.
	PasswordPrompts []string
	// FailBanners are the messages that make Auth report a failed login when
	// one of them shows up after the password was sent.
	FailBanners []string
	// ShellPrompts are matched together with FailBanners after the password
	// was sent; seeing one ends the wait without an error.
	ShellPrompts []string

	// ReadCapBytes limits how much text a single prompt wait may accumulate
	// before it gives up with an error.
	ReadCapBytes int
}
// Defaults fills in every setting that is still at its zero value and leaves
// settings the caller already chose untouched.
//
// ReadCapBytes becomes 64 KiB. Empty prompt lists receive the built-in
// values below. UserPrompts is handled last because it falls back to
// whatever LoginPrompts holds at that point (the very same slice, not a
// copy).
func (c *Client) Defaults() {
	if c.ReadCapBytes == 0 {
		c.ReadCapBytes = 64 << 10
	}

	// orElse stores fallback in *list only when the list has no entries.
	orElse := func(list *[]string, fallback ...string) {
		if len(*list) == 0 {
			*list = fallback
		}
	}

	orElse(&c.LoginPrompts, "login:", "username:")
	orElse(&c.PasswordPrompts, "password:")
	orElse(&c.FailBanners, "login incorrect", "authentication failed", "login failed")
	orElse(&c.ShellPrompts, "$ ", "# ", "> ")
	orElse(&c.UserPrompts, c.LoginPrompts...)
}
// New wraps conn in a Client. The returned client reads and writes through
// buffered wrappers around conn and has all of its settings populated by
// Defaults.
func New(conn net.Conn) *Client {
	client := new(Client)
	client.Conn = conn
	client.rd = bufio.NewReader(conn)
	client.wr = bufio.NewWriter(conn)
	client.Defaults()
	return client
}
// Close closes the underlying connection and returns whatever error that
// produces. Buffered output is not flushed here.
func (c *Client) Close() error {
	closeErr := c.Conn.Close()
	return closeErr
}
// DetectEncryption probes the peer on conn for the Telnet ENCRYPT option.
//
// It sends "IAC DO ENCRYPT, IAC WILL ENCRYPT" and then collects the peer's
// replies until one of the following happens: a read fails (this includes a
// one second pause without data and the peer closing the connection), the
// overall timeout elapses, or roughly 64 KiB have been captured. Everything
// captured is parsed in one pass, so a command sequence that arrives split
// across two reads is still recognised.
//
// A timeout of zero or less selects the default of 7 seconds. The deadlines
// this function sets on conn are cleared again before it returns, so the
// connection is not handed back with an already expired deadline.
//
// The only error returned is a failure to send the probe. Read errors merely
// end the collection phase.
func DetectEncryption(conn net.Conn, timeout time.Duration) (*EncryptionInfo, error) {
	const (
		idleTimeout = time.Second // longest wait for the next chunk
		maxCapture  = 64 * 1024   // stop collecting once this much was read
	)

	if timeout <= 0 {
		timeout = 7 * time.Second
	}
	deadline := time.Now().Add(timeout)
	// Deadline errors are ignored on purpose: the probe is best effort and a
	// connection that cannot take deadlines will surface its problems below.
	_ = conn.SetDeadline(deadline)
	defer func() { _ = conn.SetDeadline(time.Time{}) }()

	probe := []byte{IAC, DO, ENCRYPT, IAC, WILL, ENCRYPT}
	if _, err := conn.Write(probe); err != nil {
		return nil, fmt.Errorf("failed to send encryption packet: %w", err)
	}

	var raw []byte
	chunk := make([]byte, 1024)
	for len(raw) < maxCapture {
		now := time.Now()
		if !now.Before(deadline) {
			break
		}
		// Wait at most idleTimeout for more data, but never past the
		// overall deadline.
		readBy := now.Add(idleTimeout)
		if readBy.After(deadline) {
			readBy = deadline
		}
		_ = conn.SetReadDeadline(readBy)

		n, err := conn.Read(chunk)
		// A Read may return data together with an error; keep the data.
		raw = append(raw, chunk[:n]...)
		if err != nil {
			break
		}
	}

	supportsEncryption, options, banner := parseTelnetStream(raw)
	return &EncryptionInfo{
		SupportsEncryption: supportsEncryption,
		Banner:             banner,
		Options:            options,
	}, nil
}

// processTelnetOptions scans data for Telnet option negotiations.
//
// It returns whether the peer sent WILL or DO for the ENCRYPT option, and a
// map from option code to the command codes (WILL, WONT, DO, DONT or SB)
// seen for that option, in order of appearance. The map is never nil.
func processTelnetOptions(data []byte) (bool, map[int][]int) {
	supportsEncryption, options, _ := parseTelnetStream(data)
	return supportsEncryption, options
}

// parseTelnetStream separates Telnet command sequences from plain data.
//
// It returns whether WILL or DO was seen for ENCRYPT, the commands recorded
// per option code, and the plain data with every command sequence removed.
// Each data byte becomes one character of the returned text (bytes of 0x80
// and above map to the code point of the same value), so the text is always
// valid UTF-8. Parsing stops quietly at a sequence that is cut off by the
// end of data.
func parseTelnetStream(data []byte) (bool, map[int][]int, string) {
	options := make(map[int][]int)
	text := make([]rune, 0, len(data))
	supportsEncryption := false

	for i := 0; i < len(data); i++ {
		b := data[i]
		if b != IAC {
			text = append(text, rune(b))
			continue
		}
		if i+1 >= len(data) {
			break // lone IAC at the very end
		}

		cmd := data[i+1]
		switch {
		case cmd == IAC:
			// "IAC IAC" is an escaped data byte with the value 255.
			text = append(text, rune(cmd))
			i++

		case cmd > SB && cmd < IAC:
			// 251-254: WILL, WONT, DO, DONT, each followed by an option byte.
			if i+2 >= len(data) {
				return supportsEncryption, options, string(text)
			}
			opt := data[i+2]
			options[int(opt)] = append(options[int(opt)], int(cmd))
			if opt == ENCRYPT && (cmd == WILL || cmd == DO) {
				supportsEncryption = true
			}
			i += 2

		case cmd == SB:
			// Subnegotiation: record the option, then skip the payload up
			// to and including the closing "IAC SE".
			if i+2 >= len(data) {
				return supportsEncryption, options, string(text)
			}
			opt := data[i+2]
			options[int(opt)] = append(options[int(opt)], int(cmd))
			i = subnegotiationEnd(data, i+3)

		default:
			// Any other command consists of IAC plus one command byte and
			// carries no option.
			i++
		}
	}
	return supportsEncryption, options, string(text)
}

// subnegotiationEnd returns the index of the SE byte that closes the
// subnegotiation whose payload starts at from, or len(data) when the
// terminating "IAC SE" is missing.
func subnegotiationEnd(data []byte, from int) int {
	for j := from; j+1 < len(data); j++ {
		if data[j] != IAC {
			continue
		}
		if data[j+1] == SE {
			return j + 1
		}
		j++ // step over the byte escaped or introduced by this IAC
	}
	return len(data)
}
// Auth performs a username/password login on the connection.
//
// It waits for one of UserPrompts, sends username, waits for one of
// PasswordPrompts, sends password, and then waits for either a fail banner
// or a shell prompt. A matched fail banner yields "authentication failed".
//
// The final wait is deliberately lenient: if it times out
// (context.DeadlineExceeded) or ends without matching anything, the login is
// treated as successful, because a missing shell prompt alone does not prove
// a rejected login. Any other read error is returned, wrapped, together with
// a short preview of what had been received.
func (c *Client) Auth(ctx context.Context, username, password string) error {
	if _, _, err := c.readUntil(ctx, c.UserPrompts...); err != nil {
		return fmt.Errorf("waiting for login/username prompt: %w", err)
	}
	if err := c.writeLine(ctx, username); err != nil {
		return fmt.Errorf("sending username: %w", err)
	}
	if _, _, err := c.readUntil(ctx, c.PasswordPrompts...); err != nil {
		return fmt.Errorf("waiting for password prompt: %w", err)
	}
	if err := c.writeLine(ctx, password); err != nil {
		return fmt.Errorf("sending password: %w", err)
	}

	// Fail banners come first so they win when both kinds of needle match
	// at the same moment.
	outcomes := make([]string, 0, len(c.FailBanners)+len(c.ShellPrompts))
	outcomes = append(outcomes, c.FailBanners...)
	outcomes = append(outcomes, c.ShellPrompts...)

	match, got, err := c.readUntil(ctx, outcomes...)
	if err != nil && !errors.Is(err, context.DeadlineExceeded) {
		// %q keeps control bytes sent by the remote side out of logs.
		return fmt.Errorf("post-auth read: %w (got: %q)", err, preview(got, 200))
	}

	for _, fb := range c.FailBanners {
		if strings.EqualFold(match, fb) {
			return errors.New("authentication failed")
		}
	}
	return nil
}
// Exec sends command as one line and then reads until one of the until
// strings shows up in the output. It returns everything read up to that
// point along with the read error, if any.
//
// At least one until string is required: with none, the command is still
// sent, but the following read is rejected with an error.
func (c *Client) Exec(ctx context.Context, command string, until ...string) (string, error) {
	sendErr := c.writeLine(ctx, command)
	if sendErr != nil {
		return "", sendErr
	}

	_, output, readErr := c.readUntil(ctx, until...)
	return output, readErr
}
// writeLine sends s terminated by CRLF and flushes the buffered writer so the
// line reaches the connection right away. When ctx carries a deadline it is
// applied to the connection before writing.
func (c *Client) writeLine(ctx context.Context, s string) error {
	c.setDeadlineFromCtx(ctx, true)

	line := s + "\r\n"
	_, err := io.WriteString(c.wr, line)
	if err == nil {
		err = c.wr.Flush()
	}
	return err
}
// readUntil reads from the connection one byte at a time until the text
// received so far contains one of needles (case-insensitive substring
// match).
//
// Telnet command sequences are taken out of the stream and never become part
// of the text: WILL is answered with DONT, DO is answered with WONT, and
// subnegotiations are skipped.
//
// Returns:
//   - matched: the needle that was found, spelled as the caller passed it,
//     or "" when nothing matched.
//   - bufStr: all text read during this call.
//   - err: an error when needles is empty, context.DeadlineExceeded when a
//     read timed out (on a data byte or in the middle of a command), an
//     error when more than ReadCapBytes of text accumulated, or the
//     underlying read error.
//
// NOTE(review): at most 20 bytes/commands are consumed per call. When that
// budget is used up the function returns no match and a nil error, so a
// caller must check matched as well as err. Auth relies on this lenient
// result; confirm with its callers before raising or removing the limit.
func (c *Client) readUntil(ctx context.Context, needles ...string) (matched string, bufStr string, err error) {
	if len(needles) == 0 {
		return "", "", errors.New("readUntil: no needles provided")
	}

	lowNeedles := make([]string, len(needles))
	for i, n := range needles {
		lowNeedles[i] = strings.ToLower(n)
	}

	const maxIterations = 20

	var b strings.Builder
	for iteration := 0; iteration < maxIterations; iteration++ {
		c.setDeadlineFromCtx(ctx, false)

		ch, readErr := c.rd.ReadByte()
		if readErr != nil {
			return "", b.String(), timeoutAsDeadline(readErr)
		}

		if ch == 255 { // IAC: a Telnet command follows
			if cmdErr := c.consumeTelnetCommand(ctx); cmdErr != nil {
				return "", b.String(), timeoutAsDeadline(cmdErr)
			}
			continue
		}

		b.WriteByte(ch)
		lower := strings.ToLower(b.String())
		for i, n := range lowNeedles {
			if strings.Contains(lower, n) {
				return needles[i], b.String(), nil
			}
		}
		if b.Len() > c.ReadCapBytes {
			return "", b.String(), errors.New("prompt not found (read cap reached)")
		}
	}
	return "", b.String(), nil
}

// consumeTelnetCommand handles the remainder of a Telnet command whose
// leading IAC byte has already been read. Option requests are refused (DONT
// in reply to WILL, WONT in reply to DO), subnegotiations are skipped up to
// their closing "IAC SE", and any other command byte is discarded.
func (c *Client) consumeTelnetCommand(ctx context.Context) error {
	cmd, err := c.rd.ReadByte()
	if err != nil {
		return err
	}

	switch cmd {
	case 251, 252, 253, 254: // WILL, WONT, DO, DONT
		opt, err := c.rd.ReadByte()
		if err != nil {
			return err
		}
		var reply []byte
		switch cmd {
		case 251: // WILL -> DONT
			reply = []byte{255, 254, opt}
		case 253: // DO -> WONT
			reply = []byte{255, 252, opt}
		}
		if reply != nil {
			c.setDeadlineFromCtx(ctx, true)
			// Refusals are best effort: a lost reply must not abort the
			// prompt search, and a dead connection fails the next read.
			_, _ = c.wr.Write(reply)
			_ = c.wr.Flush()
		}

	case 250: // SB: skip the payload up to and including IAC SE
		for {
			bb, err := c.rd.ReadByte()
			if err != nil {
				return err
			}
			if bb != 255 {
				continue
			}
			next, err := c.rd.ReadByte()
			if err != nil {
				return err
			}
			if next == 240 { // SE
				return nil
			}
		}
	}
	return nil
}

// timeoutAsDeadline converts a network timeout, wrapped or not, into
// context.DeadlineExceeded and returns every other error unchanged.
func timeoutAsDeadline(err error) error {
	var ne net.Error
	if errors.As(err, &ne) && ne.Timeout() {
		return context.DeadlineExceeded
	}
	return err
}
// setDeadlineFromCtx copies the deadline of ctx onto the connection: always
// as the read deadline, and also as the write deadline when write is true.
//
// A nil ctx or a ctx without a deadline leaves the connection untouched, so
// a deadline set earlier stays in force. Errors from setting deadlines are
// ignored.
func (c *Client) setDeadlineFromCtx(ctx context.Context, write bool) {
	if ctx == nil {
		return
	}
	deadline, hasDeadline := ctx.Deadline()
	if !hasDeadline {
		return
	}

	_ = c.Conn.SetReadDeadline(deadline)
	if !write {
		return
	}
	_ = c.Conn.SetWriteDeadline(deadline)
}
// preview shortens s for use in messages. A string of at most n bytes is
// returned unchanged; a longer one is cut to at most n bytes and gets "..."
// appended.
//
// The cut always lands on the start of a character, so a multi-byte UTF-8
// sequence is never split in half. A negative n is treated as 0.
func preview(s string, n int) string {
	if n < 0 {
		n = 0
	}
	if len(s) <= n {
		return s
	}

	// Ranging over a string yields the byte offset of each character start;
	// keep the last one that does not exceed n.
	cut := 0
	for i := range s {
		if i > n {
			break
		}
		cut = i
	}
	return s[:cut] + "..."
}