Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
kardolus
GitHub Repository: kardolus/chatgpt-cli
Path: blob/main/agent/utils/unified_diff.go
3434 views
1
package utils
2
3
import (
4
"bufio"
5
"bytes"
6
"errors"
7
"fmt"
8
"regexp"
9
)
10
11
// fuzzyWindow is how many lines on either side of a hunk's preferred index
// are searched when looking for a position where the hunk applies.
const fuzzyWindow = 50

// hunkHeaderRe matches a unified-diff hunk header such as
// "@@ -12,5 +12,7 @@". Capture groups: 1 = old start, 2 = old count
// (optional), 3 = new start, 4 = new count (optional).
var hunkHeaderRe = regexp.MustCompile(`^@@ -(\d+)(?:,(\d+))? \+(\d+)(?:,(\d+))? @@`)
14
15
// diffOp is a single body line of a hunk.
type diffOp struct {
	// kind is the line's prefix character: ' ' (context), '+' (insertion)
	// or '-' (deletion).
	kind byte
	// line is the line content with the prefix removed. It keeps its
	// trailing newline when the diff line had one.
	line []byte
}
19
20
type diffHunk struct {
21
oldStart int // 1-based
22
ops []diffOp
23
}
24
25
func ApplyUnifiedDiff(orig, unified []byte) ([]byte, error) {
26
origLines := splitLinesKeepNL(orig)
27
28
hunks, err := ParseUnifiedDiff(unified)
29
if err != nil {
30
return nil, err
31
}
32
if len(hunks) == 0 {
33
return orig, nil
34
}
35
36
var out [][]byte
37
origIdx := 0 // 0-based into origLines
38
39
for _, hk := range hunks {
40
preferredIdx := hk.oldStart - 1
41
if preferredIdx < 0 {
42
return nil, fmt.Errorf("invalid hunk oldStart=%d", hk.oldStart)
43
}
44
if preferredIdx > len(origLines) {
45
return nil, fmt.Errorf("hunk starts past EOF: start=%d len=%d", hk.oldStart, len(origLines))
46
}
47
48
if preferredIdx < origIdx {
49
found, err2 := findFuzzyHunkStart(origLines, origIdx, preferredIdx, hk.ops)
50
if err2 != nil {
51
return nil, fmt.Errorf("overlapping or out-of-order hunks: oldStart=%d", hk.oldStart)
52
}
53
54
// Copy untouched lines up to chosen apply index.
55
out = append(out, origLines[origIdx:found]...)
56
origIdx = found
57
58
// Apply ops into output.
59
out, origIdx, err = applyHunkIntoOut(origLines, out, origIdx, hk.ops)
60
if err != nil {
61
return nil, err
62
}
63
continue
64
}
65
66
applyIdx := -1
67
68
// Try preferred only if it doesn’t violate ordering.
69
if preferredIdx >= origIdx {
70
_, _, errTry := applyHunkAt(origLines, preferredIdx, hk.ops)
71
if errTry == nil {
72
applyIdx = preferredIdx
73
}
74
}
75
76
if applyIdx == -1 {
77
found, err2 := findFuzzyHunkStart(origLines, origIdx, preferredIdx, hk.ops)
78
if err2 != nil {
79
if preferredIdx >= origIdx {
80
_, _, errTry := applyHunkAt(origLines, preferredIdx, hk.ops)
81
if errTry != nil {
82
return nil, fmt.Errorf("%v (and fuzzy placement failed: %v)", errTry, err2)
83
}
84
}
85
return nil, fmt.Errorf("fuzzy placement failed: %v", err2)
86
}
87
applyIdx = found
88
}
89
90
// Copy untouched lines up to chosen apply index.
91
if applyIdx < origIdx {
92
return nil, fmt.Errorf("fuzzy placement violated ordering: applyIdx=%d origIdx=%d", applyIdx, origIdx)
93
}
94
out = append(out, origLines[origIdx:applyIdx]...)
95
origIdx = applyIdx
96
97
// Apply ops into output.
98
out, origIdx, err = applyHunkIntoOut(origLines, out, origIdx, hk.ops)
99
if err != nil {
100
return nil, err
101
}
102
}
103
104
out = append(out, origLines[origIdx:]...)
105
return bytes.Join(out, nil), nil
106
}
107
108
func ParseUnifiedDiff(unified []byte) ([]diffHunk, error) {
109
sc := bufio.NewScanner(bytes.NewReader(unified))
110
111
const max = 10 * 1024 * 1024
112
sc.Buffer(make([]byte, 0, 64*1024), max)
113
114
var (
115
hunks []diffHunk
116
cur *diffHunk
117
lastOp *diffOp // Track last op so we can apply "\ No newline at end of file"
118
)
119
120
lineNo := 0
121
122
for sc.Scan() {
123
lineNo++
124
line := sc.Bytes() // no trailing '\n'
125
126
// Ignore common non-hunk headers.
127
if bytes.HasPrefix(line, []byte("diff ")) ||
128
bytes.HasPrefix(line, []byte("index ")) ||
129
bytes.HasPrefix(line, []byte("--- ")) ||
130
bytes.HasPrefix(line, []byte("+++ ")) {
131
continue
132
}
133
134
// New hunk header.
135
if m := hunkHeaderRe.FindSubmatch(line); m != nil {
136
oldStart, err := atoiBytes(m[1])
137
if err != nil {
138
return nil, fmt.Errorf("invalid hunk header at diff line %d: %w", lineNo, err)
139
}
140
hunks = append(hunks, diffHunk{oldStart: oldStart})
141
cur = &hunks[len(hunks)-1]
142
lastOp = nil
143
continue
144
}
145
146
// Allow whitespace/noise before first hunk only.
147
if cur == nil {
148
if len(bytes.TrimSpace(line)) == 0 {
149
continue
150
}
151
return nil, fmt.Errorf("invalid unified diff: missing hunk header (diff line %d: %q)", lineNo, string(line))
152
}
153
154
// IMPORTANT: this marker refers to the *previous* diff line
155
if bytes.HasPrefix(line, []byte(`\ No newline at end of file`)) {
156
if lastOp != nil && len(lastOp.line) > 0 && lastOp.line[len(lastOp.line)-1] == '\n' {
157
lastOp.line = lastOp.line[:len(lastOp.line)-1]
158
}
159
continue
160
}
161
162
// Within a hunk, every line must have a prefix (' ', '+', '-').
163
if len(line) == 0 {
164
return nil, fmt.Errorf("invalid unified diff: empty line without prefix (diff line %d)", lineNo)
165
}
166
167
prefix := line[0]
168
if prefix != ' ' && prefix != '+' && prefix != '-' {
169
return nil, fmt.Errorf("invalid diff line prefix %q at diff line %d: %q", prefix, lineNo, string(line))
170
}
171
172
// Add newline back (scanner strips it)
173
ln := append([]byte(nil), line...)
174
ln = append(ln, '\n')
175
176
// WITHOUT prefix, WITH newline
177
content := append([]byte(nil), ln[1:]...)
178
cur.ops = append(cur.ops, diffOp{kind: prefix, line: content})
179
lastOp = &cur.ops[len(cur.ops)-1]
180
}
181
182
if err := sc.Err(); err != nil {
183
return nil, err
184
}
185
186
return hunks, nil
187
}
188
189
// splitLinesKeepNL cuts b into lines, each keeping its terminating '\n'.
// A final fragment without a newline becomes the last element. The returned
// slices alias b. Empty input yields nil.
func splitLinesKeepNL(b []byte) [][]byte {
	var lines [][]byte

	begin := 0
	for pos, c := range b {
		if c != '\n' {
			continue
		}
		lines = append(lines, b[begin:pos+1])
		begin = pos + 1
	}

	// Trailing text that is not newline-terminated.
	if begin < len(b) {
		lines = append(lines, b[begin:])
	}
	return lines
}
207
208
// atoiBytes parses b as a non-negative decimal integer.
//
// It returns an error when b is empty, contains anything other than ASCII
// digits (so signs and whitespace are rejected), or denotes a value that does
// not fit in an int.
func atoiBytes(b []byte) (int, error) {
	const maxInt = int(^uint(0) >> 1)

	if len(b) == 0 {
		return 0, errors.New("empty number")
	}

	n := 0
	for _, c := range b {
		if c < '0' || c > '9' {
			return 0, errors.New("non-digit")
		}
		d := int(c - '0')
		// n*10 + d would exceed maxInt; refuse instead of wrapping around.
		if n > (maxInt-d)/10 {
			return 0, errors.New("number out of range")
		}
		n = n*10 + d
	}
	return n, nil
}
218
219
// equalLine reports whether diffLine matches origLine. The comparison is
// byte-exact, with one exception: when isEOFLine is true and origLine lacks a
// trailing '\n', a diffLine that is origLine plus a trailing '\n' also matches.
func equalLine(origLine, diffLine []byte, isEOFLine bool) bool {
	if bytes.Equal(origLine, diffLine) {
		return true
	}
	if !isEOFLine || len(origLine) == 0 || len(diffLine) == 0 {
		return false
	}

	origTerminated := origLine[len(origLine)-1] == '\n'
	diffTerminated := diffLine[len(diffLine)-1] == '\n'
	if origTerminated || !diffTerminated {
		return false
	}

	// Original last line has no newline: compare against the diff line
	// with its newline removed.
	return bytes.Equal(origLine, diffLine[:len(diffLine)-1])
}
233
234
func equalLineContext(origLine, diffLine []byte) bool {
235
if bytes.Equal(origLine, diffLine) {
236
return true
237
}
238
return bytes.Equal(normalizeForContextCompare(origLine), normalizeForContextCompare(diffLine))
239
}
240
241
// normalizeForContextCompare returns line with, in this order, one trailing
// '\n', one trailing '\r', and then all trailing spaces and tabs removed.
// The result is a prefix of line sharing its backing array; nothing is copied.
func normalizeForContextCompare(line []byte) []byte {
	end := len(line)

	// Line terminator: "\n" or "\r\n".
	if end > 0 && line[end-1] == '\n' {
		end--
	}
	if end > 0 && line[end-1] == '\r' {
		end--
	}

	// Trailing horizontal whitespace.
	for end > 0 && (line[end-1] == ' ' || line[end-1] == '\t') {
		end--
	}

	return line[:end]
}
268
269
func applyHunkAt(origLines [][]byte, startIdx int, ops []diffOp) (int, int, error) {
270
origIdx := startIdx
271
272
for _, o := range ops {
273
switch o.kind {
274
case ' ':
275
if origIdx >= len(origLines) {
276
return startIdx, origIdx, errors.New("patch context extends past EOF")
277
}
278
if !equalLineContext(origLines[origIdx], o.line) {
279
return startIdx, origIdx, fmt.Errorf("patch context mismatch at line %d", origIdx+1)
280
}
281
origIdx++
282
283
case '-':
284
if origIdx >= len(origLines) {
285
return startIdx, origIdx, errors.New("patch deletion extends past EOF")
286
}
287
isEOF := origIdx == len(origLines)-1
288
if !equalLine(origLines[origIdx], o.line, isEOF) {
289
return startIdx, origIdx, fmt.Errorf("patch deletion mismatch at line %d", origIdx+1)
290
}
291
origIdx++
292
293
case '+':
294
// insertion does not consume orig
295
default:
296
return startIdx, origIdx, fmt.Errorf("unknown diff op %q", o.kind)
297
}
298
}
299
300
return startIdx, origIdx, nil
301
}
302
303
func applyHunkIntoOut(origLines [][]byte, out [][]byte, origIdx int, ops []diffOp) ([][]byte, int, error) {
304
for _, o := range ops {
305
switch o.kind {
306
case ' ':
307
if origIdx >= len(origLines) {
308
return nil, 0, errors.New("patch context extends past EOF")
309
}
310
if !equalLineContext(origLines[origIdx], o.line) {
311
return nil, 0, fmt.Errorf("patch context mismatch at line %d", origIdx+1)
312
}
313
out = append(out, origLines[origIdx])
314
origIdx++
315
316
case '-':
317
if origIdx >= len(origLines) {
318
return nil, 0, errors.New("patch deletion extends past EOF")
319
}
320
isEOF := origIdx == len(origLines)-1
321
if !equalLine(origLines[origIdx], o.line, isEOF) {
322
return nil, 0, fmt.Errorf("patch deletion mismatch at line %d", origIdx+1)
323
}
324
origIdx++
325
326
case '+':
327
out = append(out, o.line)
328
329
default:
330
return nil, 0, fmt.Errorf("unknown diff op %q", o.kind)
331
}
332
}
333
return out, origIdx, nil
334
}
335
336
// findFuzzyHunkStart searches for a start index at which ops apply cleanly
// (per applyHunkAt) and returns the candidate closest to preferredIdx; on a
// tie the lower index wins.
//
// The search covers [preferredIdx-fuzzyWindow, preferredIdx+fuzzyWindow),
// clamped so it never starts below minIdx or 0 and never runs past the end of
// origLines. Files of at most 2*fuzzyWindow lines are scanned from minIdx to
// the end instead. If no candidate applies, -1 and an error are returned.
func findFuzzyHunkStart(origLines [][]byte, minIdx, preferredIdx int, ops []diffOp) (int, error) {
	total := len(origLines)

	from, to := preferredIdx-fuzzyWindow, preferredIdx+fuzzyWindow
	if from < minIdx {
		from = minIdx
	}
	if to > total {
		to = total
	}
	if from < 0 {
		from = 0
	}

	// Small file: look at every position from minIdx onwards.
	if total <= 2*fuzzyWindow {
		from, to = minIdx, total
	}

	best, bestGap := -1, 0
	for cand := from; cand < to; cand++ {
		if _, _, err := applyHunkAt(origLines, cand, ops); err != nil {
			continue
		}

		gap := abs(cand - preferredIdx)
		switch {
		case best == -1, gap < bestGap, gap == bestGap && cand < best:
			best, bestGap = cand, gap
		}
	}

	if best == -1 {
		return -1, fmt.Errorf("no applicable hunk location found in [%d,%d) around preferred=%d", from, to, preferredIdx)
	}
	return best, nil
}
377
378
// abs returns the absolute value of x.
func abs(x int) int {
	if x >= 0 {
		return x
	}
	return -x
}
384
385