package sshutil
import (
"bytes"
"context"
"encoding/base64"
"encoding/binary"
"errors"
"fmt"
"io/fs"
"os"
"os/exec"
"path/filepath"
"regexp"
"runtime"
"slices"
"strings"
"sync"
"time"
"github.com/coreos/go-semver/semver"
"github.com/mattn/go-shellwords"
"github.com/sirupsen/logrus"
"golang.org/x/sys/cpu"
"github.com/lima-vm/lima/v2/pkg/ioutilx"
"github.com/lima-vm/lima/v2/pkg/limatype/dirnames"
"github.com/lima-vm/lima/v2/pkg/limatype/filenames"
"github.com/lima-vm/lima/v2/pkg/lockutil"
"github.com/lima-vm/lima/v2/pkg/osutil"
)
// EnvShellSSH names the environment variable that, when set, replaces the
// ssh command line (executable plus leading arguments).
const EnvShellSSH = "SSH"

// SSHExe is an ssh client command: the executable to run and the arguments
// that precede any per-invocation arguments.
type SSHExe struct {
	Exe  string
	Args []string
}

// NewSSHExe resolves the ssh client to use.
//
// If $SSH is set and splits (shell-word rules) into at least one token, the
// first token becomes Exe and the remaining tokens, if any, become Args. If
// $SSH cannot be split, a warning is logged and the lookup falls back to
// "ssh" on PATH. The error from that PATH lookup is returned as-is.
func NewSSHExe() (SSHExe, error) {
	if override := os.Getenv(EnvShellSSH); override != "" {
		tokens, err := shellwords.Parse(override)
		if err != nil {
			logrus.WithError(err).Warnf("Failed to split %s variable into shell tokens. "+
				"Falling back to 'ssh' command", EnvShellSSH)
		} else if len(tokens) > 0 {
			custom := SSHExe{Exe: tokens[0]}
			// Leave Args nil when $SSH holds only the executable.
			if len(tokens) > 1 {
				custom.Args = tokens[1:]
			}
			return custom, nil
		}
	}

	found, err := exec.LookPath("ssh")
	if err != nil {
		return SSHExe{}, err
	}
	return SSHExe{Exe: found}, nil
}
// PubKey is a public key file and its content with surrounding whitespace
// trimmed.
type PubKey struct {
	Filename string
	Content  string
}

// readPublicKey reads the public key stored in file f.
//
// The returned PubKey always carries Filename. On a read failure Content is
// empty and the error wraps the underlying os.ReadFile error (so errors.Is
// works against e.g. os.ErrNotExist).
func readPublicKey(f string) (PubKey, error) {
	raw, readErr := os.ReadFile(f)
	if readErr != nil {
		return PubKey{Filename: f}, fmt.Errorf("failed to read ssh public key %q: %w", f, readErr)
	}
	return PubKey{
		Filename: f,
		Content:  strings.TrimSpace(string(raw)),
	}, nil
}
// DefaultPubKeys returns the SSH public keys Lima should authorize.
//
// The first element is always the Lima user key (filenames.UserPublicKey in
// the Lima config dir). If the matching private key (filenames.UserPrivateKey)
// does not exist yet, the config dir is created with mode 0700. An ed25519 key
// pair with an empty passphrase and comment "lima" is then generated with
// ssh-keygen while holding lockutil.WithDirLock on the config dir.
//
// When loadDotSSH is true, every ~/.ssh/*.pub file is read as well. Files whose
// content does not pass detectValidPublicKey are skipped with a warning. Files
// that no longer exist when read are ignored. Any other read error is returned.
func DefaultPubKeys(ctx context.Context, loadDotSSH bool) ([]PubKey, error) {
	configDir, err := dirnames.LimaConfigDir()
	if err != nil {
		return nil, err
	}
	privPath := filepath.Join(configDir, filenames.UserPrivateKey)
	if _, err := os.Stat(privPath); err != nil {
		if !errors.Is(err, os.ErrNotExist) {
			return nil, err
		}
		if err := os.MkdirAll(configDir, 0o700); err != nil {
			return nil, fmt.Errorf("could not create %q directory: %w", configDir, err)
		}
		if err := lockutil.WithDirLock(configDir, func() error {
			// Re-check under the lock: another process may have generated the
			// key while we were waiting. ssh-keygen would then find an existing
			// file and ask before overwriting it, which cannot succeed here
			// because no interactive input is provided.
			if _, err := os.Stat(privPath); err == nil {
				return nil
			} else if !errors.Is(err, os.ErrNotExist) {
				return err
			}
			keygenPath := privPath
			if runtime.GOOS == "windows" {
				converted, err := ioutilx.WindowsSubsystemPath(ctx, privPath)
				if err != nil {
					return err
				}
				keygenPath = converted
			}
			// Empty passphrase (-N ""), fixed comment "lima" (-C).
			keygenCmd := exec.CommandContext(ctx, "ssh-keygen", "-t", "ed25519", "-q", "-N", "",
				"-C", "lima", "-f", keygenPath)
			logrus.Debugf("executing %v", keygenCmd.Args)
			if out, err := keygenCmd.CombinedOutput(); err != nil {
				return fmt.Errorf("failed to run %v: %q: %w", keygenCmd.Args, string(out), err)
			}
			return nil
		}); err != nil {
			return nil, err
		}
	}

	entry, err := readPublicKey(filepath.Join(configDir, filenames.UserPublicKey))
	if err != nil {
		return nil, err
	}
	res := []PubKey{entry}
	if !loadDotSSH {
		return res, nil
	}

	homeDir, err := os.UserHomeDir()
	if err != nil {
		return nil, err
	}
	sshDir := filepath.Join(homeDir, ".ssh")
	// List the directory instead of globbing "*.pub". filepath.Glob would
	// interpret metacharacters in homeDir ('[', '*', '?') as part of the
	// pattern and could fail with ErrBadPattern. As with Glob, a missing or
	// unreadable ~/.ssh contributes no keys. ReadDir returns entries sorted
	// by name, matching Glob's order.
	dirEntries, err := os.ReadDir(sshDir)
	if err != nil && !errors.Is(err, os.ErrNotExist) {
		logrus.WithError(err).Debugf("failed to read %q", sshDir)
	}
	for _, de := range dirEntries {
		if !strings.HasSuffix(de.Name(), ".pub") {
			continue
		}
		entry, err := readPublicKey(filepath.Join(sshDir, de.Name()))
		if err != nil {
			// A file that disappeared (or a dangling symlink) is not an error.
			if errors.Is(err, os.ErrNotExist) {
				continue
			}
			return nil, err
		}
		if !detectValidPublicKey(entry.Content) {
			logrus.Warnf("public key %q doesn't seem to be in ssh format", entry.Filename)
			continue
		}
		res = append(res, entry)
	}
	return res, nil
}
// openSSHInfo is what detectOpenSSHInfo learns about an ssh client by running
// it with -V.
type openSSHInfo struct {
	// Version is parsed from the client's stderr by ParseOpenSSHVersion. It
	// is the zero Version when the client could not be run.
	Version semver.Version
	// GSSAPISupported is false when the client's stderr reported
	// `Unsupported option "gssapiauthentication"` (see
	// parseOpenSSHGSSAPISupported). It is also false when the client could
	// not be run.
	GSSAPISupported bool
}

// sshInfo holds the host-wide detection results consumed by CommonOpts. The
// embedded sync.Once makes the detection run only on the first CommonOpts
// call; later calls reuse these values.
var sshInfo struct {
	sync.Once
	// aesAccelerated is the result of detectAESAcceleration.
	aesAccelerated bool
	// openSSH is the result of detectOpenSSHInfo for the first sshExe seen.
	openSSH openSSHInfo
}
// CommonOpts returns the ssh option values (each meant to follow a "-o" flag)
// shared by Lima's ssh invocations.
//
// It always includes an IdentityFile for the Lima user private key. That key
// must already exist in the Lima config dir; otherwise its os.Stat error is
// returned. When useDotSSH is true, an IdentityFile is also added for every
// ~/.ssh/*.pub whose private counterpart (same path without ".pub") exists.
//
// Fixed options disable host-key checking, restrict authentication to public
// keys, disable compression and enable batch mode. GSSAPI and cipher
// preference options depend on the detected OpenSSH client and CPU features.
// That detection runs once per process (via sshInfo), using the ctx and sshExe
// of the first call.
func CommonOpts(ctx context.Context, sshExe SSHExe, useDotSSH bool) ([]string, error) {
	configDir, err := dirnames.LimaConfigDir()
	if err != nil {
		return nil, err
	}
	privateKeyPath := filepath.Join(configDir, filenames.UserPrivateKey)
	if _, err := os.Stat(privateKeyPath); err != nil {
		return nil, err
	}
	idf, err := identityFileEntry(ctx, privateKeyPath)
	if err != nil {
		return nil, err
	}
	opts := []string{idf}

	if useDotSSH {
		homeDir, err := os.UserHomeDir()
		if err != nil {
			return nil, err
		}
		sshDir := filepath.Join(homeDir, ".ssh")
		// List the directory instead of globbing "*.pub". filepath.Glob would
		// interpret metacharacters in homeDir ('[', '*', '?') as part of the
		// pattern and could fail with ErrBadPattern. As with Glob, a missing
		// or unreadable ~/.ssh contributes no keys. ReadDir returns entries
		// sorted by name, matching Glob's order.
		dirEntries, err := os.ReadDir(sshDir)
		if err != nil && !errors.Is(err, fs.ErrNotExist) {
			logrus.WithError(err).Debugf("failed to read %q", sshDir)
		}
		for _, de := range dirEntries {
			stem, isPub := strings.CutSuffix(de.Name(), ".pub")
			// A file named just ".pub" has no private-key counterpart.
			if !isPub || stem == "" {
				continue
			}
			keyPath := filepath.Join(sshDir, stem)
			if _, err := os.Stat(keyPath); err != nil {
				if errors.Is(err, fs.ErrNotExist) {
					continue
				}
				return nil, err
			}
			keyOpt, err := identityFileEntry(ctx, keyPath)
			if err != nil {
				return nil, err
			}
			opts = append(opts, keyOpt)
		}
	}

	opts = append(opts,
		"StrictHostKeyChecking=no",
		"UserKnownHostsFile=/dev/null",
		"NoHostAuthenticationForLocalhost=yes",
		"PreferredAuthentications=publickey",
		"Compression=no",
		"BatchMode=yes",
		"IdentitiesOnly=yes",
	)

	sshInfo.Do(func() {
		sshInfo.aesAccelerated = detectAESAcceleration()
		sshInfo.openSSH = detectOpenSSHInfo(ctx, sshExe)
	})

	// Only pass GSSAPIAuthentication when the client accepts the option.
	if sshInfo.openSSH.GSSAPISupported {
		opts = append(opts, "GSSAPIAuthentication=no")
	}

	if !sshInfo.openSSH.Version.LessThan(*semver.New("8.1.0")) {
		// A leading '^' places the listed ciphers at the head of OpenSSH's
		// default cipher list instead of replacing it. Prefer AES-GCM when
		// the CPU appears to accelerate AES, and ChaCha20-Poly1305 otherwise.
		if sshInfo.aesAccelerated {
			logrus.Debugf("AES accelerator seems available, prioritizing aes128-gcm@openssh.com and aes256-gcm@openssh.com")
			if runtime.GOOS == "windows" {
				opts = append(opts, "Ciphers=^aes128-gcm@openssh.com,aes256-gcm@openssh.com")
			} else {
				opts = append(opts, "Ciphers=\"^aes128-gcm@openssh.com,aes256-gcm@openssh.com\"")
			}
		} else {
			logrus.Debugf("AES accelerator does not seem available, prioritizing chacha20-poly1305@openssh.com")
			if runtime.GOOS == "windows" {
				opts = append(opts, "Ciphers=^chacha20-poly1305@openssh.com")
			} else {
				opts = append(opts, "Ciphers=\"^chacha20-poly1305@openssh.com\"")
			}
		}
	}
	return opts, nil
}
// identityFileEntry renders privateKeyPath as an `IdentityFile=...` ssh option.
//
// On non-Windows hosts the native path is wrapped in double quotes. On Windows
// the path is first converted with ioutilx.WindowsSubsystemPath, whose error
// is returned as-is, and the result is wrapped in single quotes.
func identityFileEntry(ctx context.Context, privateKeyPath string) (string, error) {
	if runtime.GOOS != "windows" {
		return fmt.Sprintf(`IdentityFile="%s"`, privateKeyPath), nil
	}
	subsystemPath, convErr := ioutilx.WindowsSubsystemPath(ctx, privateKeyPath)
	if convErr != nil {
		return "", convErr
	}
	return fmt.Sprintf(`IdentityFile='%s'`, subsystemPath), nil
}
// DisableControlMasterOptsFromSSHArgs returns a new argument list that starts
// with "-o ControlMaster=no -o ControlPath=none -o ControlPersist=no".
// That prefix is followed by sshArgs with every "-o Control{Master,Path,Persist}..."
// pair removed. sshArgs itself is not modified.
func DisableControlMasterOptsFromSSHArgs(sshArgs []string) []string {
	filtered := removeOptsFromSSHArgs(sshArgs, "ControlMaster", "ControlPath", "ControlPersist")
	overrides := []string{
		"-o", "ControlMaster=no",
		"-o", "ControlPath=none",
		"-o", "ControlPersist=no",
	}
	return append(overrides, filtered...)
}

// removeOptsFromSSHArgs returns a copy of sshArgs without the "-o VALUE" pairs
// whose VALUE starts with any of removeOpts (case-sensitive prefix match).
//
// Every "-o" consumes the next argument as its value, even if that argument is
// itself "-o". A trailing "-o" with no value is dropped. All other arguments
// are kept in order.
func removeOptsFromSSHArgs(sshArgs []string, removeOpts ...string) []string {
	kept := make([]string, 0, len(sshArgs))
	for i := 0; i < len(sshArgs); i++ {
		if sshArgs[i] != "-o" {
			kept = append(kept, sshArgs[i])
			continue
		}
		i++
		if i >= len(sshArgs) {
			break
		}
		value := sshArgs[i]
		drop := slices.ContainsFunc(removeOpts, func(prefix string) bool {
			return strings.HasPrefix(value, prefix)
		})
		if !drop {
			kept = append(kept, "-o", value)
		}
	}
	return kept
}
// IsControlMasterExisting reports whether the ssh control socket
// (filenames.SSHSock) can be stat'ed inside instDir. Any os.Stat error,
// not only "does not exist", yields false.
func IsControlMasterExisting(instDir string) bool {
	if _, statErr := os.Stat(filepath.Join(instDir, filenames.SSHSock)); statErr != nil {
		return false
	}
	return true
}
// SSHOpts returns the ssh option values for connecting to the instance in
// instDir as username.
//
// It builds on CommonOpts and adds:
//   - the User option;
//   - connection multiplexing: ControlMaster=auto, ControlPersist=yes and a
//     ControlPath pointing at filenames.SSHSock in instDir;
//   - ForwardAgent, ForwardX11 and ForwardX11Trusted, each only when its flag
//     is set.
//
// It fails before doing anything else when the control socket path length
// reaches osutil.UnixPathMax. On Windows the socket path is converted with
// ioutilx.WindowsSubsystemPath and single-quoted; elsewhere it is
// double-quoted.
func SSHOpts(ctx context.Context, sshExe SSHExe, instDir, username string, useDotSSH, forwardAgent, forwardX11, forwardX11Trusted bool) ([]string, error) {
	controlSock := filepath.Join(instDir, filenames.SSHSock)
	if len(controlSock) >= osutil.UnixPathMax {
		return nil, fmt.Errorf("socket path %q is too long: >= UNIX_PATH_MAX=%d", controlSock, osutil.UnixPathMax)
	}

	opts, err := CommonOpts(ctx, sshExe, useDotSSH)
	if err != nil {
		return nil, err
	}

	var controlPath string
	if runtime.GOOS == "windows" {
		subsystemSock, convErr := ioutilx.WindowsSubsystemPath(ctx, controlSock)
		if convErr != nil {
			return nil, convErr
		}
		controlPath = fmt.Sprintf(`ControlPath='%s'`, subsystemSock)
	} else {
		controlPath = fmt.Sprintf(`ControlPath="%s"`, controlSock)
	}

	opts = append(opts,
		fmt.Sprintf("User=%s", username),
		"ControlMaster=auto",
		controlPath,
		"ControlPersist=yes",
	)

	forwarding := []struct {
		enabled bool
		option  string
	}{
		{forwardAgent, "ForwardAgent=yes"},
		{forwardX11, "ForwardX11=yes"},
		{forwardX11Trusted, "ForwardX11Trusted=yes"},
	}
	for _, fw := range forwarding {
		if fw.enabled {
			opts = append(opts, fw.option)
		}
	}
	return opts, nil
}
// SSHArgsFromOpts turns option values into an ssh argument list. The list
// starts with "-F /dev/null", which ignores the user's ssh config file, and is
// followed by one "-o VALUE" pair per element of opts, in order.
func SSHArgsFromOpts(opts []string) []string {
	args := make([]string, 0, 2+2*len(opts))
	args = append(args, "-F", "/dev/null")
	for i := range opts {
		args = append(args, "-o", opts[i])
	}
	return args
}
// SSHOptsRemovingControlPath returns a copy of opts without any element that
// starts with "ControlMaster", "ControlPath" or "ControlPersist". opts itself is
// never modified. A nil opts yields nil, and a non-nil opts yields a non-nil
// (possibly empty) slice.
func SSHOptsRemovingControlPath(opts []string) []string {
	if opts == nil {
		return nil
	}
	controlPrefixes := []string{"ControlMaster", "ControlPath", "ControlPersist"}
	kept := make([]string, 0, len(opts))
	for _, o := range opts {
		isControl := slices.ContainsFunc(controlPrefixes, func(prefix string) bool {
			return strings.HasPrefix(o, prefix)
		})
		if !isControl {
			kept = append(kept, o)
		}
	}
	return kept
}
// openSSHVersionRegexp matches an "OpenSSH_X.Y" token, optionally followed by
// "pZ", at the start of any line of `ssh -V` output. It is compiled once at
// package scope rather than on every ParseOpenSSHVersion call.
var openSSHVersionRegexp = regexp.MustCompile(`(?m)^OpenSSH_(\d+\.\d+)(?:p(\d+))?\b`)

// ParseOpenSSHVersion extracts the OpenSSH version from `ssh -V` output.
//
// The result is "X.Y.Z" in semver form, where Z is the number after "p" (0
// when absent). For example, "OpenSSH_8.9p1 ..." yields 8.9.1. A zero Version
// is returned when no OpenSSH token is found, or when the numbers cannot be
// represented by semver (e.g. too large for int64). Previously that second
// case panicked via semver.New.
func ParseOpenSSHVersion(version []byte) *semver.Version {
	matches := openSSHVersionRegexp.FindSubmatch(version)
	if len(matches) != 3 {
		return &semver.Version{}
	}
	patch := matches[2]
	if len(patch) == 0 {
		patch = []byte("0")
	}
	v, err := semver.NewVersion(fmt.Sprintf("%s.%s", matches[1], patch))
	if err != nil {
		return &semver.Version{}
	}
	return v
}
// parseOpenSSHGSSAPISupported reports whether the ssh client accepted the
// GSSAPIAuthentication option. It inspects the stderr text of a
// `ssh -o GSSAPIAuthentication=no -V` run and returns false only if that text
// contains the client's "Unsupported option" complaint for it.
func parseOpenSSHGSSAPISupported(version string) bool {
	const unsupportedMarker = `Unsupported option "gssapiauthentication"`
	return strings.Index(version, unsupportedMarker) == -1
}
// sshExecutable is the cache key for detectOpenSSHInfo: an ssh binary path
// together with its size and modification time from os.Stat. Because Size and
// ModTime are part of the key, a binary that changes at the same path is not
// served a stale cache entry.
type sshExecutable struct {
	Path    string
	Size    int64
	ModTime time.Time
}

var (
	// openSSHInfos caches detectOpenSSHInfo results per ssh executable.
	openSSHInfos = map[sshExecutable]*openSSHInfo{}
	// openSSHInfosRW guards openSSHInfos: read-locked for lookups,
	// write-locked for insertions.
	openSSHInfosRW sync.RWMutex
)
// detectOpenSSHInfo runs sshExe with its Args followed by
// "-o GSSAPIAuthentication=no -V". It then parses the client version and
// GSSAPI support from the command's stderr.
//
// Results are cached in openSSHInfos, keyed by the executable's path, size and
// modification time. The cache is only consulted and populated when
// sshExe.Exe can be stat'ed; previously a failed stat stored results under a
// zero key that was never read. If the command fails, a warning including the
// error is logged and the zero openSSHInfo is returned. That result is not
// cached, so a later call retries.
func detectOpenSSHInfo(ctx context.Context, sshExe SSHExe) openSSHInfo {
	var (
		key       sshExecutable
		cacheable bool
	)
	if st, err := os.Stat(sshExe.Exe); err == nil {
		key = sshExecutable{Path: sshExe.Exe, Size: st.Size(), ModTime: st.ModTime()}
		cacheable = true
		openSSHInfosRW.RLock()
		cached := openSSHInfos[key]
		openSSHInfosRW.RUnlock()
		if cached != nil {
			return *cached
		}
	}

	var stderr bytes.Buffer
	// Copy Args so appending never writes into the caller's backing array.
	sshArgs := append([]string{}, sshExe.Args...)
	sshArgs = append(sshArgs, "-o", "GSSAPIAuthentication=no", "-V")
	cmd := exec.CommandContext(ctx, sshExe.Exe, sshArgs...)
	cmd.Stderr = &stderr
	if err := cmd.Run(); err != nil {
		logrus.WithError(err).Warnf("failed to run %v: stderr=%q", cmd.Args, stderr.String())
		return openSSHInfo{}
	}

	info := openSSHInfo{
		Version:         *ParseOpenSSHVersion(stderr.Bytes()),
		GSSAPISupported: parseOpenSSHGSSAPISupported(stderr.String()),
	}
	logrus.Debugf("OpenSSH version %s detected, is GSSAPI supported: %t", info.Version, info.GSSAPISupported)
	if cacheable {
		openSSHInfosRW.Lock()
		openSSHInfos[key] = &info
		openSSHInfosRW.Unlock()
	}
	return info
}
// DetectOpenSSHVersion returns the OpenSSH version of the sshExe client as
// determined by detectOpenSSHInfo. The result is the zero Version when the
// client could not be run.
func DetectOpenSSHVersion(ctx context.Context, sshExe SSHExe) semver.Version {
	info := detectOpenSSHInfo(ctx, sshExe)
	return info.Version
}
// detectValidPublicKey reports whether content looks like a single-line
// OpenSSH public key: "ALGO BASE64BLOB [comment]".
//
// The blob must be valid standard base64. It must start with a big-endian
// uint32 length N followed by at least N bytes, and those N bytes must equal
// ALGO. Any malformed input yields false; it never panics. Previously the
// bounds check ignored the 4-byte length prefix, so a crafted length between
// len(blob)-4 and len(blob) sliced out of range and panicked.
func detectValidPublicKey(content string) bool {
	if strings.ContainsRune(content, '\n') {
		return false
	}
	fields := strings.SplitN(content, " ", 3)
	if len(fields) < 2 {
		return false
	}
	algo, encoded := fields[0], fields[1]
	blob, err := base64.StdEncoding.DecodeString(encoded)
	if err != nil || len(blob) < 4 {
		return false
	}
	nameLen := binary.BigEndian.Uint32(blob)
	rest := blob[4:]
	// Compare as uint64 so neither side can overflow or wrap.
	if uint64(nameLen) > uint64(len(rest)) {
		return false
	}
	return algo == string(rest[:nameLen])
}
// detectAESAcceleration reports whether the CPU appears to have AES
// instructions, according to golang.org/x/sys/cpu.
//
// When cpu.Initialized is false (feature detection unavailable):
//   - linux/arm64 returns cpu.ARM64.HasAES as-is;
//   - darwin/arm64 (Apple silicon) assumes true and logs at debug level;
//   - every other platform assumes false and logs a warning.
func detectAESAcceleration() bool {
	if cpu.Initialized {
		return cpu.ARM.HasAES || cpu.ARM64.HasAES || cpu.PPC64.IsPOWER8 || cpu.S390X.HasAES || cpu.X86.HasAES
	}
	switch runtime.GOOS + "/" + runtime.GOARCH {
	case "linux/arm64":
		return cpu.ARM64.HasAES
	case "darwin/arm64":
		logrus.Debug("Failed to detect CPU features. Assuming that AES acceleration is available on this Apple silicon.")
		return true
	default:
		logrus.Warn("Failed to detect CPU features. Assuming that AES acceleration is not available.")
		return false
	}
}