Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
kardolus
GitHub Repository: kardolus/chatgpt-cli
Path: blob/main/agent/core/policy.go
3436 views
1
package core
2
3
import (
4
"errors"
5
"fmt"
6
"github.com/kardolus/chatgpt-cli/agent/types"
7
"go.uber.org/zap"
8
"path/filepath"
9
"strings"
10
)
11
12
type Policy interface {
13
AllowStep(cfg types.Config, step types.Step) error
14
}
15
16
// Kind values carried in PolicyDeniedError.Kind, identifying which check
// rejected a step.
const (
	// PolicyKindStepType marks a step whose type is unsupported or not listed
	// in PolicyLimits.AllowedTools.
	PolicyKindStepType = "step_type"

	// PolicyKindShell marks a shell step that is malformed or whose command is
	// denied.
	PolicyKindShell = "shell"

	// PolicyKindLLM marks an llm step that is malformed (for example, a blank
	// Prompt).
	PolicyKindLLM = "llm"

	// PolicyKindFiles marks a file step that is malformed or whose op is not
	// allowed.
	PolicyKindFiles = "files"

	// PolicyKindPathEscape marks a file path or shell argument that resolves
	// outside the configured working directory.
	PolicyKindPathEscape = "path_escape"
)
23
24
type DefaultPolicy struct {
25
limits PolicyLimits
26
}
27
28
type PolicyLimits struct {
29
AllowedTools []types.ToolKind
30
RestrictFilesToWorkDir bool
31
DeniedShellCommands []string
32
AllowedFileOps []string
33
}
34
35
// NewDefaultPolicy returns a DefaultPolicy that enforces limits.
//
// limits is stored as given. Its slices are not copied, so later changes the
// caller makes to their elements are visible to the returned policy.
func NewDefaultPolicy(limits PolicyLimits) *DefaultPolicy {
	policy := new(DefaultPolicy)
	policy.limits = limits
	return policy
}
38
39
// AllowStep reports whether step may run under cfg. It returns nil when every
// check passes, or a PolicyDeniedError describing the first check that failed.
//
// Checks, in order:
//   - The step type must be types.ToolShell, types.ToolLLM, or types.ToolFiles,
//     and, when AllowedTools is non-empty, must be listed there
//     (Kind PolicyKindStepType).
//   - Shell: Command must be non-blank and not denied (Kind PolicyKindShell).
//     When RestrictFilesToWorkDir is set and cfg.WorkDir is non-empty, Args are
//     screened by denyShellArgsOutsideWorkDir.
//   - LLM: Prompt must be non-blank (Kind PolicyKindLLM).
//   - Files: Op (trimmed, lower-cased) and Path must be non-blank, "patch"
//     requires Data, "replace" requires Old, and the op must pass
//     fileOpAllowed (Kind PolicyKindFiles). When RestrictFilesToWorkDir is set
//     and cfg.WorkDir is non-empty, Path must not escape the workdir
//     (Kind PolicyKindPathEscape).
//
// A shell Command is denied when DeniedShellCommands contains the trimmed
// command, its first whitespace-separated word, or that word's base name.
// Both "/bin/rm" and "rm -rf x" are therefore caught by a denied "rm".
func (p *DefaultPolicy) AllowStep(cfg types.Config, step types.Step) error {
	switch step.Type {
	case types.ToolShell, types.ToolLLM, types.ToolFiles:
		// supported
	default:
		return PolicyDeniedError{
			Kind:   PolicyKindStepType,
			Reason: fmt.Sprintf("unsupported step type: %s", step.Type),
		}
	}

	if len(p.limits.AllowedTools) > 0 && !containsTool(p.limits.AllowedTools, step.Type) {
		return PolicyDeniedError{
			Kind:   PolicyKindStepType,
			Reason: fmt.Sprintf("tool not allowed: %s", step.Type),
		}
	}

	switch step.Type {
	case types.ToolShell:
		cmd := strings.TrimSpace(step.Command)
		if cmd == "" {
			return PolicyDeniedError{Kind: PolicyKindShell, Reason: "shell step requires Command"}
		}
		if len(p.limits.DeniedShellCommands) > 0 {
			// Exact matching alone is bypassed by an absolute path ("/bin/rm") or
			// by a Command that carries its arguments inline ("rm -rf x").
			// cmd is non-blank here, so Fields yields at least one word.
			exe := strings.Fields(cmd)[0]
			for _, name := range []string{cmd, exe, filepath.Base(exe)} {
				if containsString(p.limits.DeniedShellCommands, name) {
					return PolicyDeniedError{Kind: PolicyKindShell, Reason: fmt.Sprintf("shell command denied: %s", cmd)}
				}
			}
		}

		if p.limits.RestrictFilesToWorkDir && cfg.WorkDir != "" {
			if err := denyShellArgsOutsideWorkDir(cfg.WorkDir, step.Args); err != nil {
				return err
			}
		}

	case types.ToolLLM:
		if strings.TrimSpace(step.Prompt) == "" {
			return PolicyDeniedError{Kind: PolicyKindLLM, Reason: "llm step requires Prompt"}
		}

	case types.ToolFiles:
		op := strings.ToLower(strings.TrimSpace(step.Op))
		if op == "" {
			return PolicyDeniedError{Kind: PolicyKindFiles, Reason: "file step requires Op"}
		}
		if strings.TrimSpace(step.Path) == "" {
			return PolicyDeniedError{Kind: PolicyKindFiles, Reason: "file step requires Path"}
		}

		switch op {
		case "patch":
			if strings.TrimSpace(step.Data) == "" {
				return PolicyDeniedError{
					Kind:   PolicyKindFiles,
					Reason: "patch requires Data (unified diff)",
				}
			}
		case "replace":
			if len(step.Old) == 0 {
				return PolicyDeniedError{
					Kind:   PolicyKindFiles,
					Reason: "replace requires Old pattern",
				}
			}
		}

		if !fileOpAllowed(p.limits.AllowedFileOps, op) {
			return PolicyDeniedError{Kind: PolicyKindFiles, Reason: fmt.Sprintf("file op not allowed: %s", op)}
		}

		if p.limits.RestrictFilesToWorkDir && cfg.WorkDir != "" {
			if escapesWorkDir(cfg.WorkDir, step.Path) {
				return PolicyDeniedError{
					Kind:   PolicyKindPathEscape,
					Reason: fmt.Sprintf("path escapes workdir: workdir=%q path=%q", cfg.WorkDir, step.Path),
				}
			}
		}
	}

	return nil
}
120
121
// denyShellArgsOutsideWorkDir returns a PolicyDeniedError (Kind
// PolicyKindPathEscape) for the first shell argument that appears to reference
// a location outside workdir. It returns nil when no argument does.
//
// Each trimmed, non-empty argument yields several candidates to inspect:
//   - the whole argument;
//   - every whitespace-separated word in it, because one argument (such as the
//     script passed to `sh -c`) can embed several paths;
//   - the value after the first "=" in each word (e.g. `--output=/etc/x`).
//
// A candidate is denied when it starts with "~", or when it looks like a path
// (contains "/", `\` or "..", or is absolute) and escapesWorkDir reports that
// it resolves outside workdir. Relative paths that stay inside workdir,
// including ones using "..", are allowed.
//
// This is intentionally conservative: a false denial is preferable to letting
// something like `rm /tmp` through.
func denyShellArgsOutsideWorkDir(workdir string, args []string) error {
	looksLikePath := func(c string) bool {
		return strings.Contains(c, "/") || strings.Contains(c, `\`) ||
			strings.Contains(c, "..") || filepath.IsAbs(c)
	}

	for _, raw := range args {
		s := strings.TrimSpace(raw)
		if s == "" {
			continue
		}

		candidates := []string{s}
		for _, word := range strings.Fields(s) {
			candidates = append(candidates, word)
			if i := strings.IndexByte(word, '='); i >= 0 && i+1 < len(word) {
				candidates = append(candidates, word[i+1:])
			}
		}

		for _, c := range candidates {
			if strings.HasPrefix(c, "~") || (looksLikePath(c) && escapesWorkDir(workdir, c)) {
				return PolicyDeniedError{
					Kind:   PolicyKindPathEscape,
					Reason: fmt.Sprintf("shell arg escapes workdir: workdir=%q arg=%q", workdir, s),
				}
			}
		}
	}
	return nil
}
150
151
// PolicyDeniedError is a typed error so Agent/Planner can branch on it.
152
type PolicyDeniedError struct {
153
Kind string
154
Reason string
155
}
156
157
func (e PolicyDeniedError) Error() string {
158
return fmt.Sprintf("policy denied: kind=%s reason=%s", e.Kind, e.Reason)
159
}
160
161
// IsPolicyStop reports whether err is, or wraps, a PolicyDeniedError, meaning a
// step was rejected by policy. When it is, a warning naming the denial kind is
// logged through log. A nil log is tolerated and simply skips the warning.
//
// Both the value form (PolicyDeniedError, which this file constructs) and the
// pointer form (*PolicyDeniedError) are recognized, so a caller that wraps the
// error by pointer is still detected. A nil *PolicyDeniedError is not treated
// as a denial.
func IsPolicyStop(err error, log *zap.SugaredLogger) bool {
	var (
		kind string
		val  PolicyDeniedError
		ptr  *PolicyDeniedError
	)
	switch {
	case errors.As(err, &val):
		kind = val.Kind
	case errors.As(err, &ptr) && ptr != nil:
		kind = ptr.Kind
	default:
		return false
	}

	if log != nil {
		log.Warnf("Policy denied (kind=%s): %v", kind, err)
	}
	return true
}
169
170
// containsTool reports whether k appears in xs. A nil or empty xs never
// contains anything.
func containsTool(xs []types.ToolKind, k types.ToolKind) bool {
	found := false
	for i := 0; i < len(xs) && !found; i++ {
		found = xs[i] == k
	}
	return found
}
178
179
// containsString reports whether s appears in xs using exact, case-sensitive
// comparison. A nil or empty xs never contains anything.
func containsString(xs []string, s string) bool {
	found := false
	for i := 0; i < len(xs) && !found; i++ {
		found = xs[i] == s
	}
	return found
}
187
188
// escapesWorkDir reports whether path resolves to a location outside workdir.
// A relative path is interpreted relative to workdir; an absolute path is used
// as-is. Reaching workdir itself is not an escape. Failing to make either path
// absolute is treated as an escape.
//
// Both workdir and the target are made absolute and cleaned. Symlinks are then
// resolved on their deepest existing ancestor, and any not-yet-existing tail
// components are re-appended. Resolving the existing prefix, rather than only a
// fully existing path, matters for targets that do not exist yet, such as a
// file about to be created:
//   - a symlinked directory inside workdir that points elsewhere is still
//     followed, so "link/new.txt" is correctly reported as escaping;
//   - a workdir that sits behind a symlink (e.g. /tmp -> /private/tmp) is
//     compared against a target resolved the same way, avoiding false escapes.
//
// Containment is decided with filepath.Rel, which also handles workdir "/"
// correctly.
//
// NOTE(review): a dangling symlink is not followed (EvalSymlinks fails on it),
// so a path naming a dangling link that points outside workdir is judged
// lexically.
func escapesWorkDir(workdir, path string) bool {
	// resolve cleans p and follows symlinks on the deepest ancestor of p that
	// exists, re-appending the remaining components unchanged.
	resolve := func(p string) string {
		p = filepath.Clean(p)
		var missing []string // components that could not be resolved, deepest first
		cur := p
		for {
			if resolved, err := filepath.EvalSymlinks(cur); err == nil {
				for i := len(missing) - 1; i >= 0; i-- {
					resolved = filepath.Join(resolved, missing[i])
				}
				return resolved
			}
			parent := filepath.Dir(cur)
			if parent == cur {
				// Nothing along the path resolved; fall back to the lexical form.
				return p
			}
			missing = append(missing, filepath.Base(cur))
			cur = parent
		}
	}

	wd, err := filepath.Abs(workdir)
	if err != nil {
		return true
	}
	wd = resolve(wd)

	full := path
	if !filepath.IsAbs(full) {
		full = filepath.Join(wd, full)
	}
	full, err = filepath.Abs(full)
	if err != nil {
		return true
	}
	full = resolve(full)

	rel, err := filepath.Rel(wd, full)
	if err != nil {
		// e.g. different volumes on Windows: cannot be inside workdir.
		return true
	}
	return rel == ".." || strings.HasPrefix(rel, ".."+string(filepath.Separator))
}
222
223
// fileOpAllowed reports whether the file op is permitted by the allow-list.
//
// An empty allow-list permits every op. Both op and the list entries are
// trimmed and lower-cased before comparison, matching the TrimSpace+ToLower
// normalization AllowStep applies to a step's Op. Without this, entries such as
// "Write" or " read " in configuration would never match.
//
// "write" also permits "patch" and "replace": anything that can write arbitrary
// bytes can produce the result of a patch or replacement.
func fileOpAllowed(allowed []string, op string) bool {
	if len(allowed) == 0 {
		return true
	}
	op = strings.ToLower(strings.TrimSpace(op))
	impliedByWrite := op == "patch" || op == "replace"
	for _, entry := range allowed {
		entry = strings.ToLower(strings.TrimSpace(entry))
		if entry == op || (impliedByWrite && entry == "write") {
			return true
		}
	}
	return false
}
236
237