package vm
import (
"fmt"
"reflect"
"strings"
"github.com/grafana/agent/pkg/river/ast"
"github.com/grafana/agent/pkg/river/diag"
"github.com/grafana/agent/pkg/river/internal/reflectutil"
"github.com/grafana/agent/pkg/river/internal/rivertags"
"github.com/grafana/agent/pkg/river/internal/value"
)
// structDecoder decodes a body of AST statements (attributes and blocks) into
// a Go struct value using reflection.
type structDecoder struct {
// VM evaluates attribute expressions and block bodies during decoding.
VM *Evaluator
// Scope is passed through unchanged to every evaluation performed by VM.
Scope *Scope
// Assoc is passed through unchanged to every evaluation performed by VM.
// NOTE(review): presumably associates evaluated values with the AST node
// that produced them; confirm against Evaluator.
Assoc map[value.Value]ast.Node
// TagInfo supplies the tag lookups (TagLookup, EnumLookup) and the full
// tag list (Tags) for the struct being decoded into.
TagInfo *tagInfo
}
// Decode decodes the statements in stmts into the struct value rv.
//
// The work is done in three passes:
//
//  1. Every block statement is counted and given an index. Blocks found in
//     the enum lookup are counted per enum (keyed by the enum field's name);
//     all other blocks are counted per full block name.
//  2. Every statement is decoded in the order it appears. The first error
//     aborts decoding and is returned as-is.
//  3. Every non-optional tag is checked against the set of attributes and
//     blocks recorded as seen; the first one missing produces an error.
//
// Decode panics if rv is not a struct, or if stmts holds a node that is
// neither an attribute statement nor a block statement.
func (st *structDecoder) Decode(stmts ast.Body, rv reflect.Value) error {
	if kind := rv.Kind(); kind != reflect.Struct {
		panic(fmt.Sprintf("river/vm: structDecoder expects struct, got %s", kind))
	}

	state := decodeOptions{
		Tags:       st.TagInfo.TagLookup,
		EnumBlocks: st.TagInfo.EnumLookup,

		SeenAttrs:  map[string]struct{}{},
		SeenBlocks: map[string]struct{}{},
		SeenEnums:  map[string]struct{}{},

		BlockCount: map[string]int{},
		BlockIndex: map[*ast.BlockStmt]int{},

		EnumCount: map[string]int{},
		EnumIndex: map[*ast.BlockStmt]int{},
	}

	// Pass 1: the totals must be known before any block is decoded, because
	// slice-backed fields are sized once up front and then filled by index.
	for _, stmt := range stmts {
		block, isBlock := stmt.(*ast.BlockStmt)
		if !isBlock {
			continue
		}
		blockName := strings.Join(block.Name, ".")

		enumTf, isEnum := st.TagInfo.EnumLookup[blockName]
		if !isEnum {
			state.BlockIndex[block] = state.BlockCount[blockName]
			state.BlockCount[blockName]++
			continue
		}

		// Blocks with different names share one counter when they belong to
		// the same enum field.
		enumName := strings.Join(enumTf.EnumField.Name, ".")
		state.EnumIndex[block] = state.EnumCount[enumName]
		state.EnumCount[enumName]++
	}

	// Pass 2: decode each statement in source order.
	for _, stmt := range stmts {
		var err error
		switch node := stmt.(type) {
		case *ast.AttributeStmt:
			err = st.decodeAttr(node, rv, &state)
		case *ast.BlockStmt:
			err = st.decodeBlock(node, rv, &state)
		default:
			panic(fmt.Sprintf("river/vm: unrecognized node type %T", node))
		}
		if err != nil {
			return err
		}
	}

	// Pass 3: make sure nothing required was left out.
	for _, tf := range st.TagInfo.Tags {
		if tf.IsOptional() {
			continue
		}
		required := strings.Join(tf.Name, ".")

		if tf.IsAttr() {
			if _, found := state.SeenAttrs[required]; !found {
				return fmt.Errorf("missing required attribute %q", required)
			}
		} else if tf.IsBlock() {
			if _, found := state.SeenBlocks[required]; !found {
				return fmt.Errorf("missing required block %q", required)
			}
		}
	}

	return nil
}
// decodeOptions carries the lookup tables and bookkeeping shared by the
// decode helpers during a single structDecoder.Decode call.
type decodeOptions struct {
	// Tags maps a full attribute or block name to its struct field.
	Tags map[string]rivertags.Field
	// EnumBlocks maps a full block name to the enum it belongs to.
	EnumBlocks map[string]enumBlock

	// SeenAttrs holds the name of every attribute statement encountered.
	SeenAttrs map[string]struct{}
	// SeenBlocks holds the name of every non-enum block decoded without error.
	SeenBlocks map[string]struct{}
	// SeenEnums holds the name of every enum whose slice has been allocated.
	SeenEnums map[string]struct{}

	// BlockCount is the number of non-enum blocks found per full block name.
	BlockCount map[string]int
	// BlockIndex is a block's position among the blocks sharing its name.
	BlockIndex map[*ast.BlockStmt]int

	// EnumCount is the number of blocks found per enum name.
	EnumCount map[string]int
	// EnumIndex is a block's position among the blocks of the same enum.
	EnumIndex map[*ast.BlockStmt]int
}
// decodeAttr evaluates a single attribute statement and stores the result in
// the matching field of rv.
//
// An error diagnostic spanning the attribute is returned when the attribute
// has already been provided, when its name has no matching tag, or when the
// name refers to a block. Errors from evaluating the expression or decoding
// the resulting value are returned unchanged.
func (st *structDecoder) decodeAttr(attr *ast.AttributeStmt, rv reflect.Value, state *decodeOptions) error {
	name := attr.Name.Name

	// attrErr wraps msg in an error diagnostic covering the whole attribute.
	attrErr := func(msg string) error {
		return diag.Diagnostics{{
			Severity: diag.SeverityLevelError,
			StartPos: ast.StartPos(attr).Position(),
			EndPos:   ast.EndPos(attr).Position(),
			Message:  msg,
		}}
	}

	if _, duplicate := state.SeenAttrs[name]; duplicate {
		return attrErr(fmt.Sprintf("attribute %q may only be provided once", name))
	}
	// The name is recorded before it is validated, so even an attribute that
	// fails the checks below counts as seen.
	state.SeenAttrs[name] = struct{}{}

	tf, known := state.Tags[name]
	switch {
	case !known:
		return attrErr(fmt.Sprintf("unrecognized attribute name %q", name))
	case tf.IsBlock():
		return attrErr(fmt.Sprintf("%q must be a block, but is used as an attribute", name))
	}

	val, err := st.VM.evaluateExpr(st.Scope, st.Assoc, attr.Value)
	if err != nil {
		return err
	}

	// value.Decode receives its target as an interface value, so the field is
	// handed over by address.
	target := reflectutil.GetOrAlloc(rv, tf)
	return value.Decode(val, target.Addr().Interface())
}
// decodeBlock dispatches block to the enum decoder when its name appears in
// the enum lookup table, and to the normal block decoder otherwise.
func (st *structDecoder) decodeBlock(block *ast.BlockStmt, rv reflect.Value, state *decodeOptions) error {
	name := block.GetBlockName()

	_, belongsToEnum := state.EnumBlocks[name]
	if !belongsToEnum {
		return st.decodeNormalBlock(name, block, rv, state)
	}
	return st.decodeEnumBlock(name, block, rv, state)
}
// decodeNormalBlock decodes a non-enum block into the field of rv tagged with
// fullName.
//
// How the block is stored depends on the kind of the destination field:
//
//   - Slice: the slice is replaced by a new one sized to the total number of
//     blocks with this name the first time the name is decoded; the block is
//     then decoded into the element at its index.
//   - Array: the array length must equal the number of blocks with this
//     name; the block is decoded into the element at its index.
//   - Anything else: the block may appear at most once and is decoded
//     straight into the field.
//
// An error diagnostic spanning the block is returned when the name has no
// matching tag, refers to an attribute, or violates the count rules above.
// The function panics if block is missing from the index lookup table. The
// name is marked as seen only after the block decodes without error.
func (st *structDecoder) decodeNormalBlock(fullName string, block *ast.BlockStmt, rv reflect.Value, state *decodeOptions) error {
	// blockErr wraps msg in an error diagnostic covering the whole block.
	blockErr := func(msg string) error {
		return diag.Diagnostics{{
			Severity: diag.SeverityLevelError,
			StartPos: ast.StartPos(block).Position(),
			EndPos:   ast.EndPos(block).Position(),
			Message:  msg,
		}}
	}

	tf, known := state.Tags[fullName]
	switch {
	case !known:
		return blockErr(fmt.Sprintf("unrecognized block name %q", fullName))
	case tf.IsAttr():
		return blockErr(fmt.Sprintf("%q must be an attribute, but is used as a block", fullName))
	}

	decodeField := prepareDecodeValue(reflectutil.GetOrAlloc(rv, tf))
	total := state.BlockCount[fullName]

	// elementForBlock returns the slot inside decodeField that belongs to this
	// block, based on the index assigned to it before decoding started.
	elementForBlock := func() reflect.Value {
		idx, indexed := state.BlockIndex[block]
		if !indexed {
			panic("river/vm: block not found in index lookup table")
		}
		return prepareDecodeValue(decodeField.Index(idx))
	}

	var target reflect.Value
	switch decodeField.Kind() {
	case reflect.Slice:
		// Size the slice once, on the first block with this name; later
		// blocks fill in their own slots.
		if _, allocated := state.SeenBlocks[fullName]; !allocated {
			decodeField.Set(reflect.MakeSlice(decodeField.Type(), total, total))
		}
		target = elementForBlock()

	case reflect.Array:
		if want := decodeField.Len(); want != total {
			return blockErr(fmt.Sprintf(
				"block %q must be specified exactly %d times, but was specified %d times",
				fullName,
				want,
				total,
			))
		}
		target = elementForBlock()

	default:
		if total > 1 {
			return blockErr(fmt.Sprintf("block %q may only be specified once", fullName))
		}
		target = decodeField
	}

	if err := st.VM.evaluateBlockOrBody(st.Scope, st.Assoc, block, target); err != nil {
		return err
	}

	state.SeenBlocks[fullName] = struct{}{}
	return nil
}
// decodeEnumBlock decodes a block that belongs to an enum field of rv.
//
// The enum field must be a slice. The first time a given enum is decoded, its
// slice is replaced by a new one sized to the total number of blocks counted
// for that enum. The block is then decoded into the block field of the slice
// element at the block's enum index.
//
// The function panics if fullName is not a known enum block, if the enum
// field is not a slice, or if block is missing from the enum index lookup
// table. Errors from evaluating the block are returned unchanged.
func (st *structDecoder) decodeEnumBlock(fullName string, block *ast.BlockStmt, rv reflect.Value, state *decodeOptions) error {
	enumTf, isEnum := state.EnumBlocks[fullName]
	if !isEnum {
		panic("decodeEnumBlock called with a non-enum block")
	}
	enumName := strings.Join(enumTf.EnumField.Name, ".")

	enumSlice := prepareDecodeValue(reflectutil.GetOrAlloc(rv, enumTf.EnumField))
	if kind := enumSlice.Kind(); kind != reflect.Slice {
		panic("river/vm: enum field must be a slice kind, got " + kind.String())
	}

	// Size the slice once per enum; every block of the enum then writes to
	// its own slot.
	if _, allocated := state.SeenEnums[enumName]; !allocated {
		total := state.EnumCount[enumName]
		enumSlice.Set(reflect.MakeSlice(enumSlice.Type(), total, total))
	}
	state.SeenEnums[enumName] = struct{}{}

	idx, indexed := state.EnumIndex[block]
	if !indexed {
		panic("river/vm: enum block not found in index lookup table")
	}

	// Within the enum element, decode into the field that corresponds to this
	// block's name.
	element := prepareDecodeValue(enumSlice.Index(idx))
	target := prepareDecodeValue(reflectutil.GetOrAlloc(element, enumTf.BlockField))

	err := st.VM.evaluateBlockOrBody(st.Scope, st.Assoc, block, target)
	return err
}