internal/git/conflict/resolve.go (177 lines of code) (raw):
package conflict
import (
"bufio"
"bytes"
"crypto/sha1"
"errors"
"fmt"
"io"
"strings"
"gitlab.com/gitlab-org/gitaly/v16/internal/git"
)
// section identifies which part of a conflicted blob a parsed line belongs to:
// outside of any conflict, one of the two sides of a conflict, or the section
// entered by a backslash-prefixed line.
type section uint
const (
	// head and origin are the values clients put into Resolution.Sections to
	// keep the "ours" or the "theirs" side of a conflict, respectively.
	head   = "head"
	origin = "origin"

	// fileLimit (200 KiB) bounds both the bufio.Scanner buffer and the total
	// number of bytes read from a conflicted blob; blobs hitting this limit
	// are rejected as unmergeable.
	fileLimit = 200 * 1024
)
// Sections a line of a conflicted blob can belong to.
//
// Note the naming: lines between the start marker ("<<<<<<< <ours>") and the
// separator ("=======") come from "ours" and are tagged sectionNew, while lines
// between the separator and the end marker (">>>>>>> <theirs>") come from
// "theirs" and are tagged sectionOld.
const (
	// sectionNone marks lines outside of any conflict.
	sectionNone = section(iota)
	// sectionOld marks lines of the "theirs" side of a conflict.
	sectionOld
	// sectionNew marks lines of the "ours" side of a conflict.
	sectionNew
	// sectionNoNewline is entered by a line starting with a backslash
	// (presumably meant for a diff-style "\ No newline at end of file" line).
	sectionNoNewline
)
// Errors that can be returned while parsing a conflicted blob.
var (
	// ErrUnmergeableFile is returned when the blob reaches fileLimit, when a
	// single line does not fit into the fileLimit-sized buffer, or when no data
	// could be read at all (typically a binary file).
	ErrUnmergeableFile = errors.New("merging is not supported for file")
	// ErrUnexpectedDelimiter is returned when a conflict marker shows up in a
	// position where the preceding markers do not allow it.
	ErrUnexpectedDelimiter = errors.New("unexpected conflict delimiter")
	// ErrMissingEndDelimiter is returned when the blob ends in the middle of a
	// conflict, i.e. before its end marker.
	ErrMissingEndDelimiter = errors.New("missing last delimiter")
)
// line describes a single content line of the conflicted blob together with
// its position in the conflicted blob and in the two sides of the conflict.
// Conflict marker lines are never stored as a line.
type line struct {
	// objIndex is the 0-based index of the line among the counted lines of the
	// conflicted blob (backslash-prefixed lines are not counted).
	objIndex uint
	// oldIndex is the 1-based line number in the "theirs" version; it is not
	// advanced by lines of the "ours" side (sectionNew).
	oldIndex uint
	// newIndex is the 1-based line number in the "ours" version; it is not
	// advanced by lines of the "theirs" side (sectionOld).
	newIndex uint
	// payload is the content of the line without its line break.
	payload string
	// crlf indicates that the line ended in a carriage return, i.e. used a
	// CRLF line break.
	crlf bool
	// section is the part of the blob the line belongs to.
	section section
}
// Resolution describes how a client wants the conflicts of a single file to be
// resolved.
type Resolution struct {
	// OldPath is the path of the file in the "ours" tree.
	OldPath string `json:"old_path"`
	// NewPath is the path of the file in the "theirs" tree.
	NewPath string `json:"new_path"`
	// Sections maps the ID of each conflict section to the side to keep:
	// "head" keeps the "ours" side and "origin" the "theirs" side. A section
	// ID has the form "<hex SHA-1 of path>_<oldIndex>_<newIndex>", taken from
	// the first line of the conflict.
	Sections map[string]string `json:"sections"`
	// Content is used verbatim as the resolved file when Sections is empty.
	Content string `json:"content"`
}
// Resolve resolves the conflicts of a single conflicted blob.
//
// src must yield the blob in the format written by git-merge-tree(1): every
// conflict is delimited by "<<<<<<< <ours>", "=======" and ">>>>>>> <theirs>"
// lines, where <ours> and <theirs> are the given OIDs. If resolution.OldPath
// and resolution.NewPath differ, the start and end markers additionally end in
// ":<OldPath>" and ":<NewPath>" respectively. path is the path of the blob; it
// is used in error messages and to derive the section IDs.
//
// The blob is always parsed and validated first. Then, if resolution.Sections
// is empty, resolution.Content is returned as-is. Otherwise each conflict is
// looked up by its section ID ("<hex SHA-1 of path>_<oldIndex>_<newIndex>" of
// its first line) and replaced by its "ours" side ("head") or its "theirs"
// side ("origin"). The resulting lines are joined with the line break used by
// the blob's first line ("\r\n" or "\n"); appendNewLine adds one more line
// break at the end.
//
// On error the returned reader is nil. Parsing errors carry the path and can
// be matched with errors.Is against ErrUnmergeableFile (empty input, input of
// fileLimit bytes or more, or a single line longer than the scanner buffer),
// ErrUnexpectedDelimiter and ErrMissingEndDelimiter.
func Resolve(src io.Reader, ours, theirs git.ObjectID, path string, resolution Resolution, appendNewLine bool) (io.Reader, error) {
	// Conflict markers; git-merge-tree(1) appends the tree OIDs to the start
	// and end markers.
	start := "<<<<<<< " + ours.String()
	middle := "======="
	end := ">>>>>>> " + theirs.String()

	// When the paths are different, the markers also contain the path names.
	if resolution.OldPath != resolution.NewPath {
		start = fmt.Sprintf("%s:%s", start, resolution.OldPath)
		end = fmt.Sprintf("%s:%s", end, resolution.NewPath)
	}

	// parseErr attaches the path to every parsing error in the same format.
	parseErr := func(err error) error {
		return fmt.Errorf("resolve: parse conflict for %q: %w", path, err)
	}

	var (
		objIndex, oldIndex, newIndex uint = 0, 1, 1
		currentSection               section
		bytesRead                    int
		lines                        []line
	)

	s := bufio.NewScanner(src)
	// Allow single lines up to fileLimit; longer ones fail with bufio.ErrTooLong.
	s.Buffer(make([]byte, 4096), fileLimit)
	s.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
		// Count every consumed byte so that blobs reaching fileLimit are
		// rejected even when each individual line fits into the buffer.
		defer func() { bytesRead += advance }()
		if bytesRead >= fileLimit {
			return 0, nil, ErrUnmergeableFile
		}

		// The remainder is a modified version of bufio.ScanLines that keeps
		// carriage returns, so that CRLF line breaks can be detected below.
		if atEOF && len(data) == 0 {
			return 0, nil, nil
		}
		if i := bytes.IndexByte(data, '\n'); i >= 0 {
			// A full newline-terminated line.
			return i + 1, data[:i], nil
		}
		if atEOF {
			return len(data), data, nil
		}
		// Request more data.
		return 0, nil, nil
	})

	for s.Scan() {
		// Strip a trailing carriage return but remember it, so that the line
		// break style can be reproduced in the output.
		l, crlf := strings.CutSuffix(s.Text(), "\r")

		switch l {
		case start:
			if currentSection != sectionNone {
				return nil, parseErr(ErrUnexpectedDelimiter)
			}
			currentSection = sectionNew
		case middle:
			if currentSection != sectionNew {
				return nil, parseErr(ErrUnexpectedDelimiter)
			}
			currentSection = sectionOld
		case end:
			if currentSection != sectionOld {
				return nil, parseErr(ErrUnexpectedDelimiter)
			}
			currentSection = sectionNone
		default:
			// A line starting with a backslash switches to sectionNoNewline,
			// which then sticks; the line itself is recorded but does not
			// advance any line counter.
			// NOTE(review): this also affects ordinary content lines that
			// merely begin with a backslash (e.g. LaTeX). Left unchanged
			// because clients presumably derive identical section IDs from
			// the same rule — confirm before changing.
			noNewline := strings.HasPrefix(l, `\`)
			if noNewline {
				currentSection = sectionNoNewline
			}

			lines = append(lines, line{
				objIndex: objIndex,
				oldIndex: oldIndex,
				newIndex: newIndex,
				payload:  l,
				crlf:     crlf,
				section:  currentSection,
			})

			if !noNewline {
				objIndex++
				// "ours" lines do not exist in "theirs" and vice versa.
				if currentSection != sectionNew {
					oldIndex++
				}
				if currentSection != sectionOld {
					newIndex++
				}
			}
		}
	}

	if err := s.Err(); err != nil {
		if errors.Is(err, bufio.ErrTooLong) {
			err = ErrUnmergeableFile
		}
		return nil, parseErr(err)
	}
	if currentSection == sectionOld || currentSection == sectionNew {
		return nil, parseErr(ErrMissingEndDelimiter)
	}
	if bytesRead == 0 {
		// No data at all; typically a binary file.
		return nil, parseErr(ErrUnmergeableFile)
	}

	var resolvedContent bytes.Buffer
	if len(resolution.Sections) == 0 {
		resolvedContent.WriteString(resolution.Content)
		return &resolvedContent, nil
	}

	resolvedLines := make([]string, 0, len(lines))
	var sectionID string
	for _, l := range lines {
		if l.section == sectionNone {
			// Outside of a conflict: keep the line and end the current section.
			sectionID = ""
			resolvedLines = append(resolvedLines, l.payload)
			continue
		}

		// All lines of a conflict share the ID derived from its first line.
		if sectionID == "" {
			sectionID = fmt.Sprintf("%x_%d_%d", sha1.Sum([]byte(path)), l.oldIndex, l.newIndex)
		}

		switch resolution.Sections[sectionID] {
		case head:
			if l.section != sectionNew {
				continue
			}
		case origin:
			if l.section != sectionOld {
				continue
			}
		default:
			// Covers both an absent entry and an unknown selection.
			return nil, fmt.Errorf("Missing resolution for section ID: %s", sectionID)
		}
		resolvedLines = append(resolvedLines, l.payload)
	}

	lineBreak := "\n"
	if len(lines) > 0 && lines[0].crlf {
		lineBreak = "\r\n"
	}

	resolvedContent.WriteString(strings.Join(resolvedLines, lineBreak))
	if appendNewLine {
		resolvedContent.WriteString(lineBreak)
	}

	return &resolvedContent, nil
}