package limatmpl

import (
	"bytes"
	"context"
	"encoding/base64"
	"fmt"
	"os"
	"path/filepath"
	"slices"
	"strings"
	"sync"
	"unicode"

	"github.com/coreos/go-semver/semver"
	"github.com/sirupsen/logrus"

	"github.com/lima-vm/lima/v2/pkg/limatype"
	"github.com/lima-vm/lima/v2/pkg/limatype/dirnames"
	"github.com/lima-vm/lima/v2/pkg/limatype/filenames"
	"github.com/lima-vm/lima/v2/pkg/limayaml"
	"github.com/lima-vm/lima/v2/pkg/version/versionutil"
	"github.com/lima-vm/lima/v2/pkg/yqutil"
)
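
// Embed merges all base templates and external provision/probe scripts into the
// template itself. Locators are first converted to absolute form so that relative
// references remain valid after embedding. When embedAll is false, bases that look
// like template URLs (see SeemsTemplateURL) are left as references; defaultBase
// prepends the user's default base template from the Lima config directory, if it exists.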
func (tmpl *Template) Embed(ctx context.Context, embedAll, defaultBase bool) error {
if err := tmpl.UseAbsLocators(); err != nil {
return err
}
seen := make(map[string]bool)
err := tmpl.embedAllBases(ctx, embedAll, defaultBase, seen)
if err == nil {
err = tmpl.combineListEntries()
}
return tmpl.ClearOnError(err)
}
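
// embedAllBases merges the entries of tmpl.Config.Base into the template one at a
// time until none remain (or only unembedded template URLs are left), and then
// embeds all external scripts. When defaultBase is true, the default base template
// is prepended to the base list first. The seen map detects base inclusion loops.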
func (tmpl *Template) embedAllBases(ctx context.Context, embedAll, defaultBase bool, seen map[string]bool) error {
logrus.Debugf("Embedding templates into %q", tmpl.Locator)
if defaultBase {
configDir, err := dirnames.LimaConfigDir()
if err != nil {
return err
}
defaultBaseFilename := filepath.Join(configDir, filenames.Base)
if _, err := os.Stat(defaultBaseFilename); err == nil {
tmpl.expr.WriteString("| ($a.base | select(type == \"!!str\")) |= [\"\" + .]\n")
tmpl.expr.WriteString("| ($a.base | select(type == \"!!map\")) |= [[] + .]\n")
tmpl.expr.WriteString(fmt.Sprintf("| $a.base = [%q, $a.base[]]\n", defaultBaseFilename))
if err := tmpl.evalExpr(); err != nil {
return err
}
}
}
for {
if err := tmpl.Unmarshal(); err != nil {
return err
}
if len(tmpl.Config.Base) == 0 {
break
}
baseLocator := tmpl.Config.Base[0]
if baseLocator.Digest != nil {
return fmt.Errorf("base %q in %q has specified a digest; digest support is not yet implemented", baseLocator.URL, tmpl.Locator)
}
isTemplate, _ := SeemsTemplateURL(baseLocator.URL)
if isTemplate && !embedAll {
for i := 1; i < len(tmpl.Config.Base); i++ {
isTemplate, _ = SeemsTemplateURL(tmpl.Config.Base[i].URL)
if !isTemplate {
return fmt.Errorf("cannot embed template %q after not embedding %q", tmpl.Config.Base[i].URL, baseLocator.URL)
}
}
break
}
if seen[baseLocator.URL] {
return fmt.Errorf("base template loop detected: template %q already included", baseLocator.URL)
}
seen[baseLocator.URL] = true
if err := tmpl.embedBase(ctx, baseLocator, embedAll, seen); err != nil {
return err
}
}
if err := tmpl.embedAllScripts(ctx, embedAll); err != nil {
return err
}
if len(tmpl.Bytes) > yBytesLimit {
return fmt.Errorf("template %q embedding exceeded the size limit (%d bytes)", tmpl.Locator, yBytesLimit)
}
return nil
}
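
// embedBase reads a single base template, recursively embeds its own bases and
// scripts (but never the default base), and merges the result into tmpl.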
func (tmpl *Template) embedBase(ctx context.Context, baseLocator limatype.LocatorWithDigest, embedAll bool, seen map[string]bool) error {
logrus.Debugf("Embedding base %q in template %q", baseLocator.URL, tmpl.Locator)
if err := tmpl.Unmarshal(); err != nil {
return err
}
base, err := Read(ctx, "", baseLocator.URL)
if err != nil {
return err
}
if err := base.UseAbsLocators(); err != nil {
return err
}
if err := base.embedAllBases(ctx, embedAll, false, seen); err != nil {
return err
}
if err := tmpl.merge(base); err != nil {
return err
}
if len(tmpl.Bytes) > yBytesLimit {
return fmt.Errorf("template %q embedding exceeded the size limit (%d bytes)", tmpl.Locator, yBytesLimit)
}
return nil
}
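
// evalExprImpl evaluates the accumulated yq expression (with the given prefix)
// against b and stores the result in tmpl.Bytes, normalized to end in a single
// newline. tmpl.Config and the expression buffer are reset afterwards.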
func (tmpl *Template) evalExprImpl(prefix string, b []byte) error {
var err error
expr := prefix + tmpl.expr.String() + "| $a"
tmpl.Bytes, err = yqutil.EvaluateExpression(expr, b)
tmpl.Bytes = append(bytes.TrimRight(tmpl.Bytes, "\n"), '\n')
tmpl.Config = nil
tmpl.expr.Reset()
return tmpl.ClearOnError(err)
}
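
// evalExpr evaluates the accumulated expression against tmpl.Bytes itself.
// Both $a and $b are bound to the single document, so expressions written for
// the two-document merge case also work here.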
func (tmpl *Template) evalExpr() error {
var err error
if tmpl.expr.Len() > 0 {
singleDocument := "select(document_index == 0) as $a | $a as $b\n"
err = tmpl.evalExprImpl(singleDocument, tmpl.Bytes)
}
return err
}
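
// merge merges the base template into tmpl by evaluating mergeDocuments against
// a two-document stream: tmpl is document 0 ($a) and base is document 1 ($b).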
func (tmpl *Template) merge(base *Template) error {
if err := tmpl.mergeBase(base); err != nil {
return tmpl.ClearOnError(err)
}
documents := fmt.Sprintf("%s\n---\n%s", string(tmpl.Bytes), string(base.Bytes))
return tmpl.evalExprImpl(mergeDocuments, []byte(documents))
}
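
// mergeBase handles fields that need special treatment before the yq merge:
// when both templates specify minimumLimaVersion or vmOpts.qemu.minimumVersion,
// the higher version wins.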
func (tmpl *Template) mergeBase(base *Template) error {
if err := tmpl.Unmarshal(); err != nil {
return err
}
if err := base.Unmarshal(); err != nil {
return err
}
if tmpl.Config.MinimumLimaVersion != nil && base.Config.MinimumLimaVersion != nil {
if versionutil.GreaterThan(*base.Config.MinimumLimaVersion, *tmpl.Config.MinimumLimaVersion) {
const minimumLimaVersion = "minimumLimaVersion"
tmpl.copyField(minimumLimaVersion, minimumLimaVersion)
}
}
var tmplOpts limatype.QEMUOpts
if err := limayaml.Convert(tmpl.Config.VMOpts[limatype.QEMU], &tmplOpts, "vmOpts.qemu"); err != nil {
return err
}
var baseOpts limatype.QEMUOpts
if err := limayaml.Convert(base.Config.VMOpts[limatype.QEMU], &baseOpts, "vmOpts.qemu"); err != nil {
return err
}
if tmplOpts.MinimumVersion != nil && baseOpts.MinimumVersion != nil {
tmplVersion := *semver.New(*tmplOpts.MinimumVersion)
baseVersion := *semver.New(*baseOpts.MinimumVersion)
if tmplVersion.LessThan(baseVersion) {
const minimumQEMUVersion = "vmOpts.qemu.minimumVersion"
tmpl.copyField(minimumQEMUVersion, minimumQEMUVersion)
}
}
return nil
}
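
// mergeDocuments is the yq expression that merges base template $b into template
// $a: fields missing from $a are copied over, existing lists are appended to
// (base, provision, and probes are prepended instead), and head/line comments
// are carried across where the destination has none.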
const mergeDocuments = `
select(document_index == 0) as $a
| select(document_index == 1) as $b
# $c will be mutilated to implement our own "merge only new fields" logic.
| $b as $c
# Delete the base that is being merged right now
| $a | select(.base | tag == "!!seq") | del(.base[0])
| $a | select(.base | (tag == "!!seq" and length == 0)) | del(.base)
| $a | select(.base | tag == "!!str") | del(.base)
# If $a.base is a list, then $b.base must be a list as well
# (note $b, not $c, because we merge lists from $b)
| $b | select((.base | tag == "!!str") and ($a.base | tag == "!!seq")) | .base = [ "" + .base ]
# Delete base DNS entries if the template list is not empty.
| $a | select(.dns) | del($b.dns, $c.dns)
# Mark all new list fields with a custom tag. This is needed to avoid appending
# newly copied lists to themselves again when we merge lists.
| $c | .. | select(tag == "!!seq") tag = "!!tag"
# Delete all nodes in $c that are in $a and not a map. This is necessary because
# the yq "*n" operator (merge only new fields) does not copy all comments across.
| $c | delpaths([$a | .. | select(tag != "!!map") | path])
# Merging with null returns null; use an empty map if $c has only comments
| $a * ($c // {}) as $a
# Find all elements that are existing lists. This will not match newly
# copied lists because they have a custom !!tag instead of !!seq.
# Append the elements from the same path in $b.
# Exception: base templates, provision scripts and probes are prepended instead.
| $a | (.. | select(tag == "!!seq" and (path[0] | test("^(base|provision|probes)$") | not))) |=
(. + (path[] as $p ireduce ($b; .[$p])))
| $a | (.. | select(tag == "!!seq" and (path[0] | test("^(base|provision|probes)$")))) |=
((path[] as $p ireduce ($b; .[$p])) + .)
# Copy head and line comments for existing lists that do not already have comments.
# New lists and existing maps already have their comments updated by the $a * $c merge.
| $a | (.. | select(tag == "!!seq" and (key | head_comment == "")) | key) head_comment |=
(((path[] as $p ireduce ($b; .[$p])) | key | head_comment) // "")
| $a | (.. | select(tag == "!!seq" and (key | line_comment == "")) | key) line_comment |=
(((path[] as $p ireduce ($b; .[$p])) | key | line_comment) // "")
# Make sure mountTypesUnsupported elements are unique.
| $a | (select(.mountTypesUnsupported) | .mountTypesUnsupported) |= unique
# Remove the custom tags again so they do not clutter up the YAML output.
| $a | .. | select(tag == "!!tag") tag = ""
`
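
// listFields returns the "list[idx]" (or "list[idx].field") paths for a
// destination and a source entry of the same list.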
func listFields(list string, dstIdx, srcIdx int, field string) (dst, src string) {
dst = fmt.Sprintf("%s[%d]", list, dstIdx)
src = fmt.Sprintf("%s[%d]", list, srcIdx)
if field != "" {
dst += "." + field
src += "." + field
}
return dst, src
}
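
// copyField copies the value and the head comment of field src in $b to field
// dst in $a.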
func (tmpl *Template) copyField(dst, src string) {
tmpl.expr.WriteString(fmt.Sprintf("| ($a.%s) = $b.%s\n", dst, src))
tmpl.expr.WriteString(fmt.Sprintf("| ($a.%s | key) head_comment = ($b.%s | key | head_comment)\n", dst, src))
}
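
// copyListEntryField copies a single field from one list entry to another.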
func (tmpl *Template) copyListEntryField(list string, dstIdx, srcIdx int, field string) {
tmpl.copyField(listFields(list, dstIdx, srcIdx, field))
}
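
// commentType selects which yq comment operator (head_comment or line_comment)
// copyComment works on.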
type commentType string
const (
headComment commentType = "head"
lineComment commentType = "line"
)
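
// copyComment copies a head or line comment from $b to $a when the destination
// comment is empty and the source comment is not. For map elements the comment
// is attached to the key instead of the value.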
func (tmpl *Template) copyComment(dst, src string, commentType commentType, isMapElement bool) {
onKey := ""
if isMapElement {
onKey = " | key"
}
tmpl.expr.WriteString(fmt.Sprintf("| $a | (select(.%s) | .%s%s | select(%s_comment == \"\" and ($b.%s%s | %s_comment != \"\"))) %s_comment |= ($b.%s%s | %s_comment)\n",
dst, dst, onKey, commentType, src, onKey, commentType, commentType, src, onKey, commentType))
}
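
// copyComments copies both the head and the line comment of a field.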
func (tmpl *Template) copyComments(dst, src string, isMapElement bool) {
for _, commentType := range []commentType{headComment, lineComment} {
tmpl.copyComment(dst, src, commentType, isMapElement)
}
}
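
// copyListEntryComments copies comments between two entries of the same list;
// when field is set, the entries are treated as map elements.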
func (tmpl *Template) copyListEntryComments(list string, dstIdx, srcIdx int, field string) {
dst, src := listFields(list, dstIdx, srcIdx, field)
isMapElement := field != ""
tmpl.copyComments(dst, src, isMapElement)
}
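
// deleteListEntry removes the list entry at idx from both $a and $b.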
func (tmpl *Template) deleteListEntry(list string, idx int) {
tmpl.expr.WriteString(fmt.Sprintf("| del($a.%s[%d], $b.%s[%d])\n", list, idx, list, idx))
}
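
// upgradeListEntryStringToMapField turns a string list entry into a map with the
// string as the value of the given field, e.g. "disk0" becomes {"name": "disk0"}.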
func (tmpl *Template) upgradeListEntryStringToMapField(list string, idx int, field string) {
tmpl.expr.WriteString(fmt.Sprintf("| ($a.%s[%d] | select(type == \"!!str\")) |= {\"%s\": .}\n", list, idx, field))
}
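
// combineListEntries merges additionalDisks, mounts, and networks entries that
// refer to the same disk, mount point, or network interface into a single entry.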
func (tmpl *Template) combineListEntries() error {
if err := tmpl.Unmarshal(); err != nil {
return err
}
tmpl.combineAdditionalDisks()
tmpl.combineMounts()
tmpl.combineNetworks()
return tmpl.evalExpr()
}
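
// combineAdditionalDisks merges additionalDisks entries with the same name into
// the first entry with that name. An entry named "*" is merged into all earlier
// entries instead, and is then removed.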
func (tmpl *Template) combineAdditionalDisks() {
const additionalDisks = "additionalDisks"
diskIdx := make(map[string]int, len(tmpl.Config.AdditionalDisks))
for src := 0; src < len(tmpl.Config.AdditionalDisks); {
disk := tmpl.Config.AdditionalDisks[src]
var from, to int
if disk.Name == "*" {
from = 0
to = src - 1
} else {
if i, ok := diskIdx[disk.Name]; ok {
from = i
to = i
} else {
if disk.Name != "" {
diskIdx[disk.Name] = src
}
src++
continue
}
}
for dst := from; dst <= to; dst++ {
upgradeDiskToMap := sync.OnceFunc(func() {
tmpl.upgradeListEntryStringToMapField(additionalDisks, dst, "name")
})
dest := &tmpl.Config.AdditionalDisks[dst]
if dest.Format == nil && disk.Format != nil {
upgradeDiskToMap()
tmpl.copyListEntryField(additionalDisks, dst, src, "format")
dest.Format = disk.Format
}
if dest.FSType == nil && disk.FSType != nil {
upgradeDiskToMap()
tmpl.copyListEntryField(additionalDisks, dst, src, "fsType")
dest.FSType = disk.FSType
}
if len(dest.FSArgs) == 0 && len(disk.FSArgs) != 0 {
upgradeDiskToMap()
tmpl.copyListEntryField(additionalDisks, dst, src, "fsArgs")
dest.FSArgs = disk.FSArgs
}
if disk.Name != "*" {
tmpl.copyListEntryComments(additionalDisks, dst, src, "")
}
}
tmpl.Config.AdditionalDisks = slices.Delete(tmpl.Config.AdditionalDisks, src, src+1)
tmpl.deleteListEntry(additionalDisks, src)
}
}
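
// combineMounts merges mounts entries with the same mount point (falling back to
// the location when no mountPoint is set). A "*" mount point is merged into all
// earlier entries.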
func (tmpl *Template) combineMounts() {
const mounts = "mounts"
mountPointIdx := make(map[string]int, len(tmpl.Config.Mounts))
for src := 0; src < len(tmpl.Config.Mounts); {
mount := tmpl.Config.Mounts[src]
mountPoint := mount.Location
if mount.MountPoint != nil {
mountPoint = *mount.MountPoint
}
var from, to int
if mountPoint == "*" {
from = 0
to = src - 1
} else {
if i, ok := mountPointIdx[mountPoint]; ok {
from = i
to = i
} else {
if mountPoint != "" {
mountPointIdx[mountPoint] = src
}
src++
continue
}
}
for dst := from; dst <= to; dst++ {
dest := &tmpl.Config.Mounts[dst]
if dest.MountPoint == nil && mount.MountPoint != nil {
tmpl.copyListEntryField(mounts, dst, src, "mountPoint")
dest.MountPoint = mount.MountPoint
}
if dest.Writable == nil && mount.Writable != nil {
tmpl.copyListEntryField(mounts, dst, src, "writable")
dest.Writable = mount.Writable
}
if dest.SSHFS.Cache == nil && mount.SSHFS.Cache != nil {
tmpl.copyListEntryField(mounts, dst, src, "sshfs.cache")
dest.SSHFS.Cache = mount.SSHFS.Cache
}
if dest.SSHFS.FollowSymlinks == nil && mount.SSHFS.FollowSymlinks != nil {
tmpl.copyListEntryField(mounts, dst, src, "sshfs.followSymlinks")
dest.SSHFS.FollowSymlinks = mount.SSHFS.FollowSymlinks
}
if dest.SSHFS.SFTPDriver == nil && mount.SSHFS.SFTPDriver != nil {
tmpl.copyListEntryField(mounts, dst, src, "sshfs.sftpDriver")
dest.SSHFS.SFTPDriver = mount.SSHFS.SFTPDriver
}
if dest.NineP.SecurityModel == nil && mount.NineP.SecurityModel != nil {
tmpl.copyListEntryField(mounts, dst, src, "9p.securityModel")
dest.NineP.SecurityModel = mount.NineP.SecurityModel
}
if dest.NineP.ProtocolVersion == nil && mount.NineP.ProtocolVersion != nil {
tmpl.copyListEntryField(mounts, dst, src, "9p.protocolVersion")
dest.NineP.ProtocolVersion = mount.NineP.ProtocolVersion
}
if dest.NineP.Msize == nil && mount.NineP.Msize != nil {
tmpl.copyListEntryField(mounts, dst, src, "9p.msize")
dest.NineP.Msize = mount.NineP.Msize
}
if dest.NineP.Cache == nil && mount.NineP.Cache != nil {
tmpl.copyListEntryField(mounts, dst, src, "9p.cache")
dest.NineP.Cache = mount.NineP.Cache
}
if dest.Virtiofs.QueueSize == nil && mount.Virtiofs.QueueSize != nil {
tmpl.copyListEntryField(mounts, dst, src, "virtiofs.queueSize")
dest.Virtiofs.QueueSize = mount.Virtiofs.QueueSize
}
if mountPoint != "*" {
tmpl.copyListEntryComments(mounts, dst, src, "")
tmpl.copyListEntryComments(mounts, dst, src, "sshfs")
tmpl.copyListEntryComments(mounts, dst, src, "9p")
tmpl.copyListEntryComments(mounts, dst, src, "virtiofs")
}
}
tmpl.Config.Mounts = slices.Delete(tmpl.Config.Mounts, src, src+1)
tmpl.deleteListEntry(mounts, src)
}
}
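
// combineNetworks merges networks entries with the same interface name.
// A "*" interface is merged into all earlier entries.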
func (tmpl *Template) combineNetworks() {
const networks = "networks"
interfaceIdx := make(map[string]int, len(tmpl.Config.Networks))
for src := 0; src < len(tmpl.Config.Networks); {
nw := tmpl.Config.Networks[src]
var from, to int
if nw.Interface == "*" {
from = 0
to = src - 1
} else {
if i, ok := interfaceIdx[nw.Interface]; ok {
from = i
to = i
} else {
if nw.Interface != "" {
interfaceIdx[nw.Interface] = src
}
src++
continue
}
}
for dst := from; dst <= to; dst++ {
dest := &tmpl.Config.Networks[dst]
if dest.Lima == "" && dest.Socket == "" {
if nw.Lima != "" {
tmpl.copyListEntryField(networks, dst, src, "lima")
dest.Lima = nw.Lima
}
if nw.Socket != "" {
tmpl.copyListEntryField(networks, dst, src, "socket")
dest.Socket = nw.Socket
}
}
if dest.MACAddress == "" && nw.MACAddress != "" {
tmpl.copyListEntryField(networks, dst, src, "macAddress")
dest.MACAddress = nw.MACAddress
}
if dest.VZNAT == nil && nw.VZNAT != nil {
tmpl.copyListEntryField(networks, dst, src, "vzNAT")
dest.VZNAT = nw.VZNAT
}
if dest.Metric == nil && nw.Metric != nil {
tmpl.copyListEntryField(networks, dst, src, "metric")
dest.Metric = nw.Metric
}
if nw.Interface != "*" {
tmpl.copyListEntryComments(networks, dst, src, "")
}
}
tmpl.Config.Networks = slices.Delete(tmpl.Config.Networks, src, src+1)
tmpl.deleteListEntry(networks, src)
}
}
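
// maxLineLength is the longest line allowed in a script that is embedded as
// plain text; scripts exceeding it are embedded as base64 instead.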
var maxLineLength = 65000
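
// encodeScriptReason returns a non-empty reason when the script needs to be
// base64-encoded: it contains an unprintable character, or a line exceeds
// maxLineLength.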
func encodeScriptReason(script string) string {
start := 0
line := 1
for i, r := range script {
if !(unicode.IsPrint(r) || r == '\n') {
return fmt.Sprintf("unprintable character %q at offset %d", r, i)
}
if i-start >= maxLineLength {
return fmt.Sprintf("line %d (offset %d) is longer than %d characters", line, start, maxLineLength)
}
if r == '\n' {
line++
start = i + 1
}
}
return ""
}
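
// base64ChunkLength is the line width used when wrapping base64-encoded scripts.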
const base64ChunkLength = 76
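
// binaryString base64-encodes s; output longer than base64ChunkLength is wrapped
// into lines of that length, each terminated by a newline.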
func binaryString(s string) string {
encoded := base64.StdEncoding.EncodeToString([]byte(s))
if len(encoded) <= base64ChunkLength {
return encoded
}
lineCount := (len(encoded) + base64ChunkLength - 1) / base64ChunkLength
builder := strings.Builder{}
builder.Grow(len(encoded) + lineCount)
for i := 0; i < len(encoded); i += base64ChunkLength {
end := min(i+base64ChunkLength, len(encoded))
builder.WriteString(encoded[i:end])
builder.WriteByte('\n')
}
return builder.String()
}
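
// updateScript replaces the "file" property of the provision or probe entry at
// idx with the embedded script, renaming the key to newName ("script", "content",
// or "expression") and tagging the value as !!binary when the script had to be
// base64-encoded.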
func (tmpl *Template) updateScript(field string, idx int, newName, script, file string) {
tag := ""
if reason := encodeScriptReason(script); reason != "" {
logrus.Infof("File %q is being base64 encoded: %s", file, reason)
script = binaryString(script)
tag = "!!binary"
}
entry := fmt.Sprintf("$a.%s[%d].file", field, idx)
tmpl.expr.WriteString(fmt.Sprintf("| (%s) = %q | (%s) tag = %q | (%s | key) = %q\n",
entry, script, entry, tag, entry, newName))
}
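
// embedAllScripts inlines the external files referenced by provision and probe
// entries. Files that look like template URLs are only embedded when embedAll
// is true.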
func (tmpl *Template) embedAllScripts(ctx context.Context, embedAll bool) error {
if err := tmpl.Unmarshal(); err != nil {
return err
}
for i, p := range tmpl.Config.Probes {
if p.File == nil {
continue
}
if p.Script == nil || *p.Script != "" {
continue
}
isTemplate, _ := SeemsTemplateURL(p.File.URL)
if embedAll || !isTemplate {
scriptTmpl, err := Read(ctx, "", p.File.URL)
if err != nil {
return err
}
tmpl.updateScript("probes", i, "script", string(scriptTmpl.Bytes), p.File.URL)
}
}
for i, p := range tmpl.Config.Provision {
if p.File == nil {
continue
}
newName := "script"
switch p.Mode {
case limatype.ProvisionModeData:
newName = "content"
if p.Content != nil {
continue
}
case limatype.ProvisionModeYQ:
newName = "expression"
if p.Expression != nil {
continue
}
default:
if p.Script != nil && *p.Script != "" {
continue
}
}
isTemplate, _ := SeemsTemplateURL(p.File.URL)
if embedAll || !isTemplate {
scriptTmpl, err := Read(ctx, "", p.File.URL)
if err != nil {
return err
}
tmpl.updateScript("provision", i, newName, string(scriptTmpl.Bytes), p.File.URL)
}
}
return tmpl.evalExpr()
}