package subset
import (
"fmt"
"reflect"
"gopkg.in/yaml.v2"
)
// Assert returns a non-nil error describing the first place where source is
// not a subset of target, or nil when it is.
//
// Pointers and interfaces are followed on both sides before comparing.
// Maps in source only need each of their keys to be present (with a matching
// value) in target; extra keys in target are ignored. Slices and arrays must
// match in length and element by element. Everything else is compared with
// reflect.DeepEqual.
func Assert(source, target interface{}) error {
	sourceValue := reflect.ValueOf(source)
	targetValue := reflect.ValueOf(target)
	return assert(sourceValue, targetValue)
}
// assert walks source and target in parallel and returns an *Error describing
// the first place where source is not a subset of target, or nil.
//
// Interfaces and pointers on both sides are dereferenced before comparing.
// If both sides end up nil (a nil interface or nil pointer), they are equal;
// a nil on only one side is reported as a type mismatch.
// Maps are compared as subsets: every key of source must be present in target
// with a matching value, while extra keys in target are ignored.
// Slices and arrays must have the same length and match element by element.
// Any other kind is compared with reflect.DeepEqual.
//
// Nested mismatches are wrapped in *Error values whose Message names the
// element index or map key, so the final Error() string reads as a path.
func assert(source, target reflect.Value) error {
	for canElem(source) {
		source = source.Elem()
	}
	for canElem(target) {
		target = target.Elem()
	}

	// typeName renders a value's type the same way fmt's %T does, including
	// "<nil>" for an absent value, without needing to call Interface.
	typeName := func(v reflect.Value) string {
		if !v.IsValid() {
			return "<nil>"
		}
		return v.Type().String()
	}

	// Elem on a nil pointer or nil interface (e.g. a YAML `~` value inside a
	// map or list) yields the zero Value, on which Type would panic.
	if !source.IsValid() || !target.IsValid() {
		if source.IsValid() == target.IsValid() {
			return nil
		}
		return &Error{Message: fmt.Sprintf("type mismatch: %s != %s", typeName(source), typeName(target))}
	}
	if source.Type() != target.Type() {
		return &Error{Message: fmt.Sprintf("type mismatch: %s != %s", typeName(source), typeName(target))}
	}

	switch source.Kind() {
	case reflect.Slice, reflect.Array:
		if source.Len() != target.Len() {
			return &Error{Message: fmt.Sprintf("length mismatch: %d != %d", source.Len(), target.Len())}
		}
		for i := 0; i < source.Len(); i++ {
			if err := assert(source.Index(i), target.Index(i)); err != nil {
				return &Error{
					Message: fmt.Sprintf("element %d", i),
					Inner:   err,
				}
			}
		}
		return nil
	case reflect.Map:
		iter := source.MapRange()
		for iter.Next() {
			key := iter.Key()
			targetElement := target.MapIndex(key)
			// MapIndex returns the zero Value when the key is absent
			// (including when target is a nil map).
			if !targetElement.IsValid() {
				return &Error{Message: fmt.Sprintf("missing key %v", key.Interface())}
			}
			if err := assert(iter.Value(), targetElement); err != nil {
				return &Error{
					Message: fmt.Sprintf("%v", key.Interface()),
					Inner:   err,
				}
			}
		}
		return nil
	default:
		if !reflect.DeepEqual(source.Interface(), target.Interface()) {
			return &Error{Message: fmt.Sprintf("%v != %v", source, target)}
		}
		return nil
	}
}
// canElem reports whether v is an interface or a pointer, i.e. a kind on which
// v.Elem can be called without panicking. The zero Value reports false.
func canElem(v reflect.Value) bool {
	switch v.Kind() {
	case reflect.Interface, reflect.Ptr:
		return true
	default:
		return false
	}
}
// Error describes a mismatch found by Assert. Mismatches inside slices, arrays
// and maps are chained through Inner, with each wrapping level's Message
// naming the element index or map key that leads to the inner mismatch.
type Error struct {
	// Message describes this level of the mismatch.
	Message string
	// Inner is the mismatch found deeper in the structure, or nil at the leaf.
	Inner error
}

// Error renders the mismatch as "Message: Inner" at each level of nesting,
// or just Message when there is no inner error.
func (e *Error) Error() string {
	if e.Inner != nil {
		return fmt.Sprintf("%s: %s", e.Message, e.Inner)
	}
	return e.Message
}

// Unwrap returns the nested mismatch so errors.Is and errors.As can walk the
// chain.
func (e *Error) Unwrap() error {
	return e.Inner
}
// YAMLAssert decodes source and target as YAML documents into generic values
// and checks that source is a subset of target using Assert.
//
// The source document is decoded first; a decoding error from either document
// is returned unchanged (not wrapped), so callers see the yaml package's own
// error values.
func YAMLAssert(source, target []byte) error {
	var decodedSource, decodedTarget interface{}
	if err := yaml.Unmarshal(source, &decodedSource); err != nil {
		return err
	}
	if err := yaml.Unmarshal(target, &decodedTarget); err != nil {
		return err
	}
	return Assert(decodedSource, decodedTarget)
}