package utils
import (
"bufio"
"bytes"
"errors"
"fmt"
"regexp"
)
// fuzzyWindow is the search radius, in lines, used by findFuzzyHunkStart when
// a hunk does not apply at its stated position: candidates are taken from
// preferredIdx-fuzzyWindow up to (but excluding) preferredIdx+fuzzyWindow.
// Files with at most 2*fuzzyWindow lines are searched in full.
const fuzzyWindow = 50
// hunkHeaderRe matches a unified-diff hunk header at the start of a line, e.g.
// "@@ -12,5 +12,7 @@". Capture groups: 1 = old start, 2 = old count (optional),
// 3 = new start, 4 = new count (optional). Text after the closing "@@" is
// not constrained by the pattern.
var hunkHeaderRe = regexp.MustCompile(`^@@ -(\d+)(?:,(\d+))? \+(\d+)(?:,(\d+))? @@`)
// diffOp is a single body line of a hunk.
type diffOp struct {
// kind is the diff line's prefix byte: ' ' (context), '-' (deletion) or
// '+' (addition). The parser rejects any other prefix.
kind byte
// line is the line's content without the prefix byte. The parser appends a
// trailing '\n' and removes it again when the line is followed by a
// "\ No newline at end of file" marker.
line []byte
}
// diffHunk is one "@@ ... @@" section of a unified diff.
type diffHunk struct {
// oldStart is the start line number taken from the old-file range of the
// hunk header (the number following "@@ -").
oldStart int
// ops holds the hunk's body lines in the order they appear in the diff.
ops []diffOp
}
func ApplyUnifiedDiff(orig, unified []byte) ([]byte, error) {
origLines := splitLinesKeepNL(orig)
hunks, err := ParseUnifiedDiff(unified)
if err != nil {
return nil, err
}
if len(hunks) == 0 {
return orig, nil
}
var out [][]byte
origIdx := 0
for _, hk := range hunks {
preferredIdx := hk.oldStart - 1
if preferredIdx < 0 {
return nil, fmt.Errorf("invalid hunk oldStart=%d", hk.oldStart)
}
if preferredIdx > len(origLines) {
return nil, fmt.Errorf("hunk starts past EOF: start=%d len=%d", hk.oldStart, len(origLines))
}
if preferredIdx < origIdx {
found, err2 := findFuzzyHunkStart(origLines, origIdx, preferredIdx, hk.ops)
if err2 != nil {
return nil, fmt.Errorf("overlapping or out-of-order hunks: oldStart=%d", hk.oldStart)
}
out = append(out, origLines[origIdx:found]...)
origIdx = found
out, origIdx, err = applyHunkIntoOut(origLines, out, origIdx, hk.ops)
if err != nil {
return nil, err
}
continue
}
applyIdx := -1
if preferredIdx >= origIdx {
_, _, errTry := applyHunkAt(origLines, preferredIdx, hk.ops)
if errTry == nil {
applyIdx = preferredIdx
}
}
if applyIdx == -1 {
found, err2 := findFuzzyHunkStart(origLines, origIdx, preferredIdx, hk.ops)
if err2 != nil {
if preferredIdx >= origIdx {
_, _, errTry := applyHunkAt(origLines, preferredIdx, hk.ops)
if errTry != nil {
return nil, fmt.Errorf("%v (and fuzzy placement failed: %v)", errTry, err2)
}
}
return nil, fmt.Errorf("fuzzy placement failed: %v", err2)
}
applyIdx = found
}
if applyIdx < origIdx {
return nil, fmt.Errorf("fuzzy placement violated ordering: applyIdx=%d origIdx=%d", applyIdx, origIdx)
}
out = append(out, origLines[origIdx:applyIdx]...)
origIdx = applyIdx
out, origIdx, err = applyHunkIntoOut(origLines, out, origIdx, hk.ops)
if err != nil {
return nil, err
}
}
out = append(out, origLines[origIdx:]...)
return bytes.Join(out, nil), nil
}
// ParseUnifiedDiff parses the hunks of a unified diff.
//
// File-level header lines ("diff ", "index ", "--- ", "+++ ") are skipped, as
// are blank lines that appear before the first hunk. Every hunk body line must
// start with ' ', '+' or '-'; the remainder of the line, with a trailing '\n'
// appended, becomes the op's content. A "\ No newline at end of file" marker
// removes that trailing '\n' from the op that precedes it.
//
// Inside a hunk, a line beginning with "--- " or "+++ " is ambiguous: it may
// be a file header or the deletion/addition of a line whose text starts with
// "-- " / "++ ". The counts declared in the hunk header decide: while the
// hunk still expects old (respectively new) lines, such a line is treated as
// a regular '-' (respectively '+') op; otherwise it is skipped as a header.
//
// It returns an error for a malformed hunk header, for content found before
// the first hunk header, for an empty or wrongly prefixed line inside a hunk,
// and for scanner failures (for example a line longer than 10 MiB).
func ParseUnifiedDiff(unified []byte) ([]diffHunk, error) {
	sc := bufio.NewScanner(bytes.NewReader(unified))
	// Raise the scanner's per-line limit well above its small default.
	const maxLineBytes = 10 * 1024 * 1024
	sc.Buffer(make([]byte, 0, 64*1024), maxLineBytes)

	var (
		hunks  []diffHunk
		cur    *diffHunk // hunk currently being filled; nil before the first header
		lastOp *diffOp   // most recent op of cur, target of a "\ No newline" marker
		// Lines the current hunk header still promises for the old and new
		// side. Both are zero outside a hunk.
		oldLeft, newLeft int
	)
	lineNo := 0
	for sc.Scan() {
		lineNo++
		line := sc.Bytes()

		switch {
		case bytes.HasPrefix(line, []byte("diff ")), bytes.HasPrefix(line, []byte("index ")):
			continue
		case bytes.HasPrefix(line, []byte("--- ")) && oldLeft <= 0:
			continue
		case bytes.HasPrefix(line, []byte("+++ ")) && newLeft <= 0:
			continue
		}

		if m := hunkHeaderRe.FindSubmatch(line); m != nil {
			oldStart, err := atoiBytes(m[1])
			if err != nil {
				return nil, fmt.Errorf("invalid hunk header at diff line %d: %w", lineNo, err)
			}
			oldCount, err := hunkRangeCount(m[2])
			if err != nil {
				return nil, fmt.Errorf("invalid hunk header at diff line %d: %w", lineNo, err)
			}
			newCount, err := hunkRangeCount(m[4])
			if err != nil {
				return nil, fmt.Errorf("invalid hunk header at diff line %d: %w", lineNo, err)
			}
			hunks = append(hunks, diffHunk{oldStart: oldStart})
			// Re-taken after every append, so the pointer never refers to a
			// stale backing array.
			cur = &hunks[len(hunks)-1]
			lastOp = nil
			oldLeft, newLeft = oldCount, newCount
			continue
		}

		if cur == nil {
			if len(bytes.TrimSpace(line)) == 0 {
				continue
			}
			return nil, fmt.Errorf("invalid unified diff: missing hunk header (diff line %d: %q)", lineNo, string(line))
		}

		if bytes.HasPrefix(line, []byte(`\ No newline at end of file`)) {
			if lastOp != nil && len(lastOp.line) > 0 && lastOp.line[len(lastOp.line)-1] == '\n' {
				lastOp.line = lastOp.line[:len(lastOp.line)-1]
			}
			continue
		}

		if len(line) == 0 {
			return nil, fmt.Errorf("invalid unified diff: empty line without prefix (diff line %d)", lineNo)
		}
		prefix := line[0]
		switch prefix {
		case ' ':
			oldLeft--
			newLeft--
		case '-':
			oldLeft--
		case '+':
			newLeft--
		default:
			return nil, fmt.Errorf("invalid diff line prefix %q at diff line %d: %q", prefix, lineNo, string(line))
		}

		// sc.Bytes() is only valid until the next Scan, so the content is
		// copied; the scanner stripped the line terminator, so '\n' is
		// re-appended.
		content := make([]byte, 0, len(line))
		content = append(content, line[1:]...)
		content = append(content, '\n')
		cur.ops = append(cur.ops, diffOp{kind: prefix, line: content})
		lastOp = &cur.ops[len(cur.ops)-1]
	}
	if err := sc.Err(); err != nil {
		return nil, err
	}
	return hunks, nil
}

// hunkRangeCount parses the optional line count of a hunk-header range. An
// omitted count means the range covers exactly one line.
func hunkRangeCount(b []byte) (int, error) {
	if len(b) == 0 {
		return 1, nil
	}
	return atoiBytes(b)
}
// splitLinesKeepNL cuts b into lines, leaving each line's terminating '\n'
// attached. A final fragment without a terminating '\n' is returned as its
// own line. The returned slices alias b; empty input yields nil.
func splitLinesKeepNL(b []byte) [][]byte {
	var lines [][]byte
	begin := 0
	for pos, c := range b {
		if c != '\n' {
			continue
		}
		end := pos + 1
		lines = append(lines, b[begin:end])
		begin = end
	}
	// Whatever follows the last '\n' is an unterminated final line.
	if tail := b[begin:]; len(tail) > 0 {
		lines = append(lines, tail)
	}
	return lines
}
// atoiBytes parses b as an unsigned decimal integer.
//
// It returns an error when b is empty, contains a byte other than '0'-'9',
// or encodes a value that does not fit in an int. On error the returned
// integer is 0.
func atoiBytes(b []byte) (int, error) {
	if len(b) == 0 {
		return 0, errors.New("empty number")
	}
	const maxInt = int(^uint(0) >> 1)
	n := 0
	for _, c := range b {
		if c < '0' || c > '9' {
			return 0, errors.New("non-digit")
		}
		d := int(c - '0')
		// n*10+d would exceed maxInt; Go integers wrap silently, so check
		// before multiplying.
		if n > (maxInt-d)/10 {
			return 0, errors.New("number out of range")
		}
		n = n*10 + d
	}
	return n, nil
}
// equalLine reports whether origLine and diffLine are byte-for-byte equal.
//
// When isEOFLine is true one relaxation applies: an origLine that lacks a
// trailing '\n' also matches a diffLine that is identical apart from carrying
// one. The reverse (origLine terminated, diffLine not) never matches.
func equalLine(origLine, diffLine []byte, isEOFLine bool) bool {
	switch {
	case bytes.Equal(origLine, diffLine):
		return true
	case !isEOFLine, len(origLine) == 0, len(diffLine) == 0:
		return false
	}
	if origLine[len(origLine)-1] == '\n' {
		return false
	}
	last := len(diffLine) - 1
	if diffLine[last] != '\n' {
		return false
	}
	return bytes.Equal(origLine, diffLine[:last])
}
// equalLineContext reports whether a context line of the diff matches a line
// of the original. Lines match when they are identical, or when they are
// identical after normalizeForContextCompare has been applied to both.
func equalLineContext(origLine, diffLine []byte) bool {
	return bytes.Equal(origLine, diffLine) ||
		bytes.Equal(normalizeForContextCompare(origLine), normalizeForContextCompare(diffLine))
}
// normalizeForContextCompare returns line with its line ending and trailing
// blanks removed: at most one trailing '\n', then at most one trailing '\r',
// then every trailing space or tab. The result is a prefix of line sharing
// its storage; line itself is not modified.
func normalizeForContextCompare(line []byte) []byte {
	end := len(line)
	if end > 0 && line[end-1] == '\n' {
		end--
	}
	if end > 0 && line[end-1] == '\r' {
		end--
	}
	for ; end > 0; end-- {
		if c := line[end-1]; c != ' ' && c != '\t' {
			break
		}
	}
	return line[:end]
}
// applyHunkAt checks, without producing output, whether ops can be applied to
// origLines beginning at index startIdx.
//
// Context ops are compared with equalLineContext and deletion ops with
// equalLine (the EOF relaxation applies only on the last original line);
// addition ops consume nothing. It returns startIdx unchanged, the index just
// past the last original line examined, and a non-nil error describing the
// first op that does not fit.
func applyHunkAt(origLines [][]byte, startIdx int, ops []diffOp) (int, int, error) {
	pos := startIdx
	last := len(origLines) - 1
	for _, op := range ops {
		if op.kind == '+' {
			continue
		}
		if op.kind != ' ' && op.kind != '-' {
			return startIdx, pos, fmt.Errorf("unknown diff op %q", op.kind)
		}
		isContext := op.kind == ' '

		if pos > last {
			if isContext {
				return startIdx, pos, errors.New("patch context extends past EOF")
			}
			return startIdx, pos, errors.New("patch deletion extends past EOF")
		}

		if isContext {
			if !equalLineContext(origLines[pos], op.line) {
				return startIdx, pos, fmt.Errorf("patch context mismatch at line %d", pos+1)
			}
		} else if !equalLine(origLines[pos], op.line, pos == last) {
			return startIdx, pos, fmt.Errorf("patch deletion mismatch at line %d", pos+1)
		}
		pos++
	}
	return startIdx, pos, nil
}
// applyHunkIntoOut applies ops to origLines starting at origIdx and appends
// the resulting lines to out.
//
// Context ops copy the original line (not the diff's text) to out, deletion
// ops skip the original line, and addition ops append the diff's line. It
// returns the extended out and the index of the first original line after the
// hunk. On any mismatch it returns nil, 0 and an error.
func applyHunkIntoOut(origLines [][]byte, out [][]byte, origIdx int, ops []diffOp) ([][]byte, int, error) {
	fail := func(err error) ([][]byte, int, error) { return nil, 0, err }
	n := len(origLines)

	for i := range ops {
		kind, text := ops[i].kind, ops[i].line
		switch kind {
		case '+':
			out = append(out, text)
			continue
		case ' ', '-':
			// handled below: both consume one original line
		default:
			return fail(fmt.Errorf("unknown diff op %q", kind))
		}

		keep := kind == ' '
		switch {
		case origIdx >= n && keep:
			return fail(errors.New("patch context extends past EOF"))
		case origIdx >= n:
			return fail(errors.New("patch deletion extends past EOF"))
		case keep && !equalLineContext(origLines[origIdx], text):
			return fail(fmt.Errorf("patch context mismatch at line %d", origIdx+1))
		case !keep && !equalLine(origLines[origIdx], text, origIdx == n-1):
			return fail(fmt.Errorf("patch deletion mismatch at line %d", origIdx+1))
		}

		if keep {
			out = append(out, origLines[origIdx])
		}
		origIdx++
	}
	return out, origIdx, nil
}
// findFuzzyHunkStart searches for the index closest to preferredIdx at which
// ops apply cleanly according to applyHunkAt.
//
// Candidates lie in the half-open range [lo, hi). For files of at most
// 2*fuzzyWindow lines that is [minIdx, len(origLines)); otherwise it is
// preferredIdx±fuzzyWindow clipped to [minIdx, len(origLines)). When two
// candidates are equally close, the lower index wins. It returns -1 and an
// error when no candidate applies.
func findFuzzyHunkStart(origLines [][]byte, minIdx, preferredIdx int, ops []diffOp) (int, error) {
	total := len(origLines)

	var lo, hi int
	if total <= 2*fuzzyWindow {
		// Small file: scan everything still available.
		lo, hi = minIdx, total
	} else {
		lo = preferredIdx - fuzzyWindow
		if lo < minIdx {
			lo = minIdx
		}
		if lo < 0 {
			lo = 0
		}
		hi = preferredIdx + fuzzyWindow
		if hi > total {
			hi = total
		}
	}

	best, bestDist := -1, 0
	for cand := lo; cand < hi; cand++ {
		if _, _, err := applyHunkAt(origLines, cand, ops); err != nil {
			continue
		}
		d := abs(cand - preferredIdx)
		// Candidates are visited in ascending order, so on equal distance
		// the earlier (lower) index is kept.
		if best != -1 && d >= bestDist {
			continue
		}
		best, bestDist = cand, d
	}

	if best == -1 {
		return -1, fmt.Errorf("no applicable hunk location found in [%d,%d) around preferred=%d", lo, hi, preferredIdx)
	}
	return best, nil
}
// abs returns the absolute value of x.
func abs(x int) int {
	if x >= 0 {
		return x
	}
	return -x
}