package core
import (
"errors"
"fmt"
"github.com/kardolus/chatgpt-cli/agent/types"
"go.uber.org/zap"
"path/filepath"
"strings"
)
// Policy decides whether a step may be executed.
//
// Implementations return nil from AllowStep to permit step under cfg and a
// non-nil error to deny it (DefaultPolicy denies with a PolicyDeniedError).
type Policy interface {
	AllowStep(cfg types.Config, step types.Step) error
}

// Denial kinds reported in PolicyDeniedError.Kind, identifying which rule
// rejected a step.
const (
	// PolicyKindStepType: the step type is unsupported or not in the allow-list.
	PolicyKindStepType = "step_type"
	// PolicyKindShell: a shell step is malformed or runs a denied command.
	PolicyKindShell = "shell"
	// PolicyKindLLM: an LLM step is malformed (missing Prompt).
	PolicyKindLLM = "llm"
	// PolicyKindFiles: a file step is malformed or uses a disallowed op.
	PolicyKindFiles = "files"
	// PolicyKindPathEscape: a file path or shell argument resolves outside the workdir.
	PolicyKindPathEscape = "path_escape"
)
// DefaultPolicy is the built-in Policy implementation. It checks that each
// step is well-formed and enforces the restrictions configured in its
// PolicyLimits. A zero-value DefaultPolicy applies no limits beyond the
// per-step validation.
type DefaultPolicy struct {
	limits PolicyLimits
}

// Compile-time check that *DefaultPolicy satisfies Policy.
var _ Policy = (*DefaultPolicy)(nil)

// PolicyLimits configures the restrictions enforced by DefaultPolicy.
// An empty slice means "no restriction" for the corresponding check.
type PolicyLimits struct {
	// AllowedTools restricts which step types may run; empty allows every
	// supported type.
	AllowedTools []types.ToolKind
	// RestrictFilesToWorkDir, when set and the config has a non-empty WorkDir,
	// denies file paths and shell arguments that appear to resolve outside it.
	RestrictFilesToWorkDir bool
	// DeniedShellCommands lists commands that shell steps may not run.
	DeniedShellCommands []string
	// AllowedFileOps restricts which file ops may run; empty allows all.
	// Listing "write" also permits "patch" and "replace".
	AllowedFileOps []string
}

// NewDefaultPolicy returns a DefaultPolicy enforcing limits.
//
// The slices in limits are copied, so later mutation of the caller's
// PolicyLimits cannot silently change the behavior of the returned policy.
func NewDefaultPolicy(limits PolicyLimits) *DefaultPolicy {
	limits.AllowedTools = append([]types.ToolKind(nil), limits.AllowedTools...)
	limits.DeniedShellCommands = append([]string(nil), limits.DeniedShellCommands...)
	limits.AllowedFileOps = append([]string(nil), limits.AllowedFileOps...)
	return &DefaultPolicy{limits: limits}
}
// AllowStep reports whether step may run under this policy and cfg.
//
// It returns nil when the step is permitted, or a PolicyDeniedError whose Kind
// identifies the rule that rejected it:
//   - PolicyKindStepType: the step type is unsupported or not in AllowedTools.
//   - PolicyKindShell, PolicyKindLLM, PolicyKindFiles: required fields are
//     missing, the shell command is denied, or the file op is not allowed.
//   - PolicyKindPathEscape: RestrictFilesToWorkDir is set, cfg.WorkDir is
//     non-empty, and a file path or shell argument resolves outside it.
func (p *DefaultPolicy) AllowStep(cfg types.Config, step types.Step) error {
	switch step.Type {
	case types.ToolShell, types.ToolLLM, types.ToolFiles:
	default:
		return PolicyDeniedError{
			Kind:   PolicyKindStepType,
			Reason: fmt.Sprintf("unsupported step type: %s", step.Type),
		}
	}

	// An empty AllowedTools list means "no restriction".
	if len(p.limits.AllowedTools) > 0 && !containsTool(p.limits.AllowedTools, step.Type) {
		return PolicyDeniedError{
			Kind:   PolicyKindStepType,
			Reason: fmt.Sprintf("tool not allowed: %s", step.Type),
		}
	}

	switch step.Type {
	case types.ToolShell:
		return p.allowShellStep(cfg, step)
	case types.ToolLLM:
		if strings.TrimSpace(step.Prompt) == "" {
			return PolicyDeniedError{Kind: PolicyKindLLM, Reason: "llm step requires Prompt"}
		}
	case types.ToolFiles:
		return p.allowFileStep(cfg, step)
	}
	return nil
}

// allowShellStep validates a shell step: Command must be non-empty and not
// denied, and (when restricted) its Args must stay inside cfg.WorkDir.
func (p *DefaultPolicy) allowShellStep(cfg types.Config, step types.Step) error {
	cmd := strings.TrimSpace(step.Command)
	if cmd == "" {
		return PolicyDeniedError{Kind: PolicyKindShell, Reason: "shell step requires Command"}
	}
	if isDeniedShellCommand(p.limits.DeniedShellCommands, cmd) {
		return PolicyDeniedError{Kind: PolicyKindShell, Reason: fmt.Sprintf("shell command denied: %s", cmd)}
	}
	if p.limits.RestrictFilesToWorkDir && cfg.WorkDir != "" {
		return denyShellArgsOutsideWorkDir(cfg.WorkDir, step.Args)
	}
	return nil
}

// allowFileStep validates a file step: Op and Path are required, op-specific
// payloads must be present, the op must be allowed, and (when restricted)
// Path must stay inside cfg.WorkDir.
func (p *DefaultPolicy) allowFileStep(cfg types.Config, step types.Step) error {
	op := strings.ToLower(strings.TrimSpace(step.Op))
	if op == "" {
		return PolicyDeniedError{Kind: PolicyKindFiles, Reason: "file step requires Op"}
	}
	if strings.TrimSpace(step.Path) == "" {
		return PolicyDeniedError{Kind: PolicyKindFiles, Reason: "file step requires Path"}
	}

	switch op {
	case "patch":
		if strings.TrimSpace(step.Data) == "" {
			return PolicyDeniedError{
				Kind:   PolicyKindFiles,
				Reason: "patch requires Data (unified diff)",
			}
		}
	case "replace":
		if len(step.Old) == 0 {
			return PolicyDeniedError{
				Kind:   PolicyKindFiles,
				Reason: "replace requires Old pattern",
			}
		}
	}

	if !fileOpAllowed(p.limits.AllowedFileOps, op) {
		return PolicyDeniedError{Kind: PolicyKindFiles, Reason: fmt.Sprintf("file op not allowed: %s", op)}
	}

	if p.limits.RestrictFilesToWorkDir && cfg.WorkDir != "" && escapesWorkDir(cfg.WorkDir, step.Path) {
		return PolicyDeniedError{
			Kind:   PolicyKindPathEscape,
			Reason: fmt.Sprintf("path escapes workdir: workdir=%q path=%q", cfg.WorkDir, step.Path),
		}
	}
	return nil
}

// isDeniedShellCommand reports whether cmd (already trimmed) matches a denied
// entry. An entry matches the whole command, its first whitespace-separated
// word, or that word's base name. This means a "rm" entry also catches
// "/bin/rm" and "rm -rf x". Blank entries are ignored.
func isDeniedShellCommand(denied []string, cmd string) bool {
	name := cmd
	if fields := strings.Fields(cmd); len(fields) > 0 {
		name = fields[0]
	}
	base := filepath.Base(name)
	for _, d := range denied {
		d = strings.TrimSpace(d)
		if d == "" {
			continue
		}
		if d == cmd || d == name || d == base {
			return true
		}
	}
	return false
}
// denyShellArgsOutsideWorkDir returns a PolicyDeniedError with Kind
// PolicyKindPathEscape for the first shell argument that appears to reference
// a location outside workdir. It returns nil if no argument does.
//
// Detection is heuristic (see shellArgEscapesWorkDir). Besides the argument
// itself, the value after the first "=" is checked as well. This means
// "--config=/etc/passwd" or "OUT=../x" cannot hide a path behind a non-path
// prefix. Blank arguments are skipped.
func denyShellArgsOutsideWorkDir(workdir string, args []string) error {
	for _, raw := range args {
		s := strings.TrimSpace(raw)
		if s == "" {
			continue
		}

		escapes := shellArgEscapesWorkDir(workdir, s)
		if !escapes {
			if i := strings.Index(s, "="); i >= 0 {
				if value := strings.TrimSpace(s[i+1:]); value != "" {
					escapes = shellArgEscapesWorkDir(workdir, value)
				}
			}
		}

		if escapes {
			return PolicyDeniedError{
				Kind:   PolicyKindPathEscape,
				Reason: fmt.Sprintf("shell arg escapes workdir: workdir=%q arg=%q", workdir, s),
			}
		}
	}
	return nil
}

// shellArgEscapesWorkDir applies the path heuristic to a single token.
//
// A token starting with "~" is always treated as escaping, since home-relative
// expansion cannot be checked here. A token containing a path separator or
// "..", or that is absolute, is resolved with escapesWorkDir. Any other token
// is assumed not to be a path.
func shellArgEscapesWorkDir(workdir, s string) bool {
	if strings.HasPrefix(s, "~") {
		return true
	}
	if strings.ContainsAny(s, `/\`) || strings.Contains(s, "..") || filepath.IsAbs(s) {
		return escapesWorkDir(workdir, s)
	}
	return false
}
// PolicyDeniedError is the error DefaultPolicy returns when it rejects a step.
type PolicyDeniedError struct {
	// Kind names the rule that rejected the step (one of the PolicyKind* constants).
	Kind string
	// Reason is a human-readable explanation of the denial.
	Reason string
}

// Error renders the denial as "policy denied: kind=<Kind> reason=<Reason>".
func (e PolicyDeniedError) Error() string {
	const layout = "policy denied: kind=%s reason=%s"
	return fmt.Sprintf(layout, e.Kind, e.Reason)
}
// IsPolicyStop reports whether err, or any error it wraps, is a policy denial.
// A denial is a PolicyDeniedError value or a non-nil *PolicyDeniedError.
//
// When it is, a warning including the denial kind is logged to log. A nil log
// is tolerated and simply skips the warning instead of panicking.
func IsPolicyStop(err error, log *zap.SugaredLogger) bool {
	kind, ok := policyDenialKind(err)
	if !ok {
		return false
	}
	if log != nil {
		log.Warnf("Policy denied (kind=%s): %v", kind, err)
	}
	return true
}

// policyDenialKind returns the Kind of the first PolicyDeniedError found in
// err's chain. It checks the value form first, then the pointer form.
func policyDenialKind(err error) (string, bool) {
	var pe PolicyDeniedError
	if errors.As(err, &pe) {
		return pe.Kind, true
	}
	var pp *PolicyDeniedError
	if errors.As(err, &pp) && pp != nil {
		return pp.Kind, true
	}
	return "", false
}
// containsTool reports whether k is one of the tool kinds in xs.
func containsTool(xs []types.ToolKind, k types.ToolKind) bool {
	for i := range xs {
		if xs[i] != k {
			continue
		}
		return true
	}
	return false
}
// containsString reports whether s is an element of xs (exact, case-sensitive).
func containsString(xs []string, s string) bool {
	found := false
	for i := 0; i < len(xs) && !found; i++ {
		found = xs[i] == s
	}
	return found
}
func escapesWorkDir(workdir, path string) bool {
wd, err := filepath.Abs(workdir)
if err != nil {
return true
}
wd = filepath.Clean(wd)
if r, err := filepath.EvalSymlinks(wd); err == nil {
wd = r
}
var full string
if filepath.IsAbs(path) {
full = path
} else {
full = filepath.Join(wd, path)
}
full, err = filepath.Abs(full)
if err != nil {
return true
}
full = filepath.Clean(full)
if r, err := filepath.EvalSymlinks(full); err == nil {
full = r
}
if full == wd {
return false
}
prefix := wd + string(filepath.Separator)
return !strings.HasPrefix(full, prefix)
}
// fileOpAllowed reports whether the file op may run given the allow-list.
//
// An empty allow-list permits every op. Both op and each allow-list entry are
// compared after trimming whitespace and lowercasing. Configured entries such
// as "Write" or " read " therefore match the lowercased op that AllowStep
// derives from the step. Listing "write" also permits "patch" and "replace".
func fileOpAllowed(allowed []string, op string) bool {
	if len(allowed) == 0 {
		return true
	}
	op = strings.ToLower(strings.TrimSpace(op))
	// patch and replace are implied by a "write" grant.
	impliedByWrite := op == "patch" || op == "replace"
	for _, a := range allowed {
		a = strings.ToLower(strings.TrimSpace(a))
		if a == op || (impliedByWrite && a == "write") {
			return true
		}
	}
	return false
}