pkg/git/stack_struct.go (293 lines of code) (raw):
package git
import (
"encoding/json"
"errors"
"fmt"
"io/fs"
"iter"
"os"
"path/filepath"
"strings"
"gitlab.com/gitlab-org/cli/internal/config"
)
// StackRef is a single entry of a stacked diff. It is decoded from the
// JSON files that GatherStackRefs reads from the stack's directory.
type StackRef struct {
// Prev is the SHA of the preceding ref; it is empty for the first ref.
Prev string `json:"prev"`
// Branch is the name of the local git branch belonging to this ref.
Branch string `json:"branch"`
// SHA identifies the ref; it is the key the ref is stored under in Stack.Refs.
SHA string `json:"sha"`
// Next is the SHA of the following ref; it is empty for the last ref.
Next string `json:"next"`
// MR presumably references the merge request created for this ref;
// its exact format is not visible in this file - TODO confirm.
MR string `json:"mr"`
// Description is the text describing the ref; Subject derives a
// one-line summary from it.
Description string `json:"description"`
}
// Stack represents a stacked diff data structure.
// Refs are structured as a doubly linked list where
// the links are identified with the StackRef.Prev
// and StackRef.Next fields.
// The StackRef.SHA is the id that the former two
// fields can point to.
// All stacks must be created with GatherStackRefs
// which validates the stack for consistency
// (exactly one first ref and exactly one last ref).
type Stack struct {
// Title names the stack; it is passed to StackRootDir to locate the
// stack's files.
Title string
// Refs holds every ref of the stack keyed by its StackRef.SHA.
Refs map[string]StackRef
}
// Empty reports whether the stack holds no refs at all.
func (s Stack) Empty() bool {
	count := len(s.Refs)
	return count == 0
}
// RemoveRef removes ref from the stack.
//
// When ref is the only entry (it is both first and last) only its
// reference file and its in-memory entry are removed; its branch is left
// untouched. Otherwise the neighbouring refs are re-linked around ref,
// the reference file is deleted, the ref's local branch is removed via
// RemoveBranch, and finally the in-memory entry is dropped.
//
// The first failing step aborts the removal and its error is returned
// wrapped, so callers can still inspect the cause with errors.Is/As.
func (s *Stack) RemoveRef(ref StackRef, gr GitRunner) error {
	if ref.IsFirst() && ref.IsLast() {
		// This is the only ref, so there are no neighbours to re-link.
		// The in-memory entry is dropped even when deleting the file
		// fails; that ordering is kept as it was.
		err := DeleteStackRefFile(s.Title, ref)
		delete(s.Refs, ref.SHA)
		if err != nil {
			return fmt.Errorf("could not delete reference file: %w", err)
		}
		return nil
	}

	if err := s.adjustAdjacentRefs(ref); err != nil {
		return fmt.Errorf("error adjusting next reference: %w", err)
	}

	if err := DeleteStackRefFile(s.Title, ref); err != nil {
		return fmt.Errorf("could not delete reference file: %w", err)
	}

	if err := s.RemoveBranch(ref, gr); err != nil {
		return fmt.Errorf("could not remove branch: %w", err)
	}

	// Only forget the ref once every on-disk and git step has succeeded.
	delete(s.Refs, ref.SHA)
	return nil
}
// RemoveBranch moves off ref's branch and then deletes it locally.
// The branch that gets checked out first is the stack's base branch when
// ref is the first ref, otherwise the branch of the ref that precedes it.
// Any error from resolving the base branch, from CheckoutBranch or from
// DeleteLocalBranch is returned unchanged.
func (s *Stack) RemoveBranch(ref StackRef, gr GitRunner) error {
	var target string
	switch {
	case ref.IsFirst():
		base, err := s.BaseBranch(gr)
		if err != nil {
			return err
		}
		target = base
	default:
		target = s.Refs[ref.Prev].Branch
	}

	// The branch cannot be deleted while it is checked out, so switch first.
	if err := CheckoutBranch(target, gr); err != nil {
		return err
	}
	return DeleteLocalBranch(ref.Branch, gr)
}
// adjustAdjacentRefs unlinks ref from the chain by pointing its
// neighbours at each other: the previous ref's Next becomes ref.Next and
// the next ref's Prev becomes ref.Prev. Each modified neighbour is stored
// back into s.Refs and written out with UpdateStackRefFile.
// ref itself is neither modified nor removed here.
func (s *Stack) adjustAdjacentRefs(ref StackRef) error {
	if ref.Prev != "" {
		// Map values are copies, so modify the copy and store it back.
		prev := s.Refs[ref.Prev]
		prev.Next = ref.Next
		s.Refs[ref.Prev] = prev

		if err := UpdateStackRefFile(s.Title, prev); err != nil {
			return fmt.Errorf("could not update reference file: %w", err)
		}
	}

	if ref.Next != "" {
		next := s.Refs[ref.Next]
		next.Prev = ref.Prev
		s.Refs[ref.Next] = next

		if err := UpdateStackRefFile(s.Title, next); err != nil {
			return fmt.Errorf("could not update reference file: %w", err)
		}
	}

	return nil
}
// IndexAt returns the zero-based position of ref when walking the stack
// from its first ref, or -1 when no ref in the chain equals ref.
// Refs are compared field by field.
func (s *Stack) IndexAt(ref StackRef) int {
	position := -1
	for i, candidate := range s.Iter2() {
		if candidate == ref {
			position = i
			break
		}
	}
	return position
}
// Last returns the ref that has no successor. An empty stack yields the
// zero StackRef. It panics when a non-empty stack contains no such ref,
// which cannot happen for stacks built by GatherStackRefs because that
// function validates the chain.
func (s *Stack) Last() StackRef {
	var none StackRef
	if s.Empty() {
		return none
	}

	for sha := range s.Refs {
		if candidate := s.Refs[sha]; candidate.IsLast() {
			return candidate
		}
	}

	// All Stacks should be created with GatherStackRefs which validates the Stack consistency.
	panic(errors.New("can't find the last ref in the chain. Data might be corrupted."))
}
// First returns the ref that has no predecessor. An empty stack yields
// the zero StackRef. It panics when a non-empty stack contains no such
// ref, which cannot happen for stacks built by GatherStackRefs because
// that function validates the chain.
func (s *Stack) First() StackRef {
	var none StackRef
	if s.Empty() {
		return none
	}

	for sha := range s.Refs {
		if candidate := s.Refs[sha]; candidate.IsFirst() {
			return candidate
		}
	}

	// All Stacks should be created with GatherStackRefs which validates the Stack consistency.
	panic(errors.New("can't find the first ref in the chain. Data might be corrupted."))
}
// Iter returns an iterator that walks the stack in chain order, starting
// at the first ref and following the Next links. The walk ends when a
// link resolves to an empty ref or when the consumer stops early.
func (s *Stack) Iter() iter.Seq[StackRef] {
	return func(yield func(StackRef) bool) {
		for cur := s.First(); !cur.Empty(); cur = s.Refs[cur.Next] {
			if !yield(cur) {
				return
			}
		}
	}
}
// Branches returns the branch names of all refs in chain order (first
// ref to last ref). The result is nil when the stack has no refs.
func (s *Stack) Branches() (branches []string) {
	collect := func(ref StackRef) bool {
		branches = append(branches, ref.Branch)
		return true // never stop early; every ref is wanted
	}
	s.Iter()(collect)
	return branches
}
// Iter2 walks the stack exactly like Iter but additionally yields the
// zero-based position of each ref within the chain.
func (s *Stack) Iter2() iter.Seq2[int, StackRef] {
	return func(yield func(int, StackRef) bool) {
		idx := 0
		for cur := s.First(); !cur.Empty(); cur = s.Refs[cur.Next] {
			if !yield(idx, cur) {
				return
			}
			idx++
		}
	}
}
// BaseBranch returns the branch the stack is based on.
//
// The value stored in the stack's BaseBranchFile is used when that file
// exists and is a regular entry (not a directory). When the file is
// missing, or the path is a directory, the default branch is derived
// from the output of `git remote show` for DefaultRemote instead.
// Failing to stat the file for any reason other than "does not exist"
// is reported as an error.
func (s *Stack) BaseBranch(gr GitRunner) (branch string, err error) {
	root, err := StackRootDir(s.Title)
	if err != nil {
		return "", fmt.Errorf("could not determine stack root: %w", err)
	}

	baseFile := filepath.Join(root, BaseBranchFile)
	info, statErr := os.Stat(baseFile)

	switch {
	case statErr == nil && !info.IsDir():
		// we do have a base branch in the metadata
		contents, readErr := config.TrimmedFileContents(baseFile)
		if readErr != nil {
			return "", fmt.Errorf("could not read base branch file: %w", readErr)
		}
		return contents, nil

	case statErr != nil && !errors.Is(statErr, os.ErrNotExist):
		// A missing file is fine, any other stat failure is not.
		return "", fmt.Errorf("could not access base branch file: %w", statErr)
	}

	// no metadata file - fall back to asking git about the remote
	remoteOutput, err := gr.Git("remote", "show", DefaultRemote)
	if err != nil {
		return "", fmt.Errorf("could not get remote data: %w", err)
	}

	branch, err = ParseDefaultBranch([]byte(remoteOutput))
	if err != nil {
		return "", fmt.Errorf("could not parse default branch from remote data: %w", err)
	}
	return branch, nil
}
// AddStackBaseBranch records branch as the base branch of the stack
// named title by writing it to BaseBranchFile inside the stack's root
// directory. An existing file is overwritten.
//
// Errors from locating the stack root or from writing the file are
// returned wrapped so the underlying cause stays inspectable.
func AddStackBaseBranch(title string, branch string) error {
	root, err := StackRootDir(title)
	if err != nil {
		return fmt.Errorf("could not determine stack root: %w", err)
	}

	filename := filepath.Join(root, BaseBranchFile)

	// os.WriteFile creates the file when it is missing and truncates it
	// otherwise, so no separate os.Create is needed. The previous
	// os.Create call also leaked the file handle it returned.
	if err := os.WriteFile(filename, []byte(branch), 0o644); err != nil {
		return fmt.Errorf("error adding branch metadata file %v: %w", filename, err)
	}

	return nil
}
// GatherStackRefs loads the stack named title from disk.
//
// Every file with a ".json" extension below the stack's root directory
// is decoded into a StackRef and stored under its SHA. A root directory
// that does not exist yet is not an error: the returned stack is simply
// empty. After loading, the stack is checked with validateStackRefs; on
// a validation failure a zero Stack is returned together with the error.
func GatherStackRefs(title string) (Stack, error) {
	stack := Stack{
		Title: title,
		Refs:  make(map[string]StackRef),
	}

	root, err := StackRootDir(title)
	if err != nil {
		return stack, err
	}

	err = filepath.WalkDir(root, func(path string, d fs.DirEntry, err error) error {
		if err != nil {
			return err
		}

		// read files in the stacked ref directory
		// TODO: this may be quicker if we introduce a package
		// https://github.com/bmatcuk/doublestar
		if d.IsDir() || filepath.Ext(d.Name()) != ".json" {
			return nil
		}

		// os.ReadFile errors already name the offending path.
		data, err := os.ReadFile(path)
		if err != nil {
			return err
		}

		// Decode into our StackRef type; name the file on failure so a
		// corrupted ref can be located.
		var ref StackRef
		if err := json.Unmarshal(data, &ref); err != nil {
			return fmt.Errorf("could not parse stack ref file %q: %w", path, err)
		}

		stack.Refs[ref.SHA] = ref
		return nil
	})
	if err != nil {
		// there might not be any refs yet, this is ok.
		if errors.Is(err, fs.ErrNotExist) {
			return stack, nil
		}
		return stack, err
	}

	if err := validateStackRefs(stack); err != nil {
		return Stack{}, err
	}

	return stack, nil
}
// validateStackRefs checks that a non-empty stack has exactly one ref
// without a predecessor and exactly one ref without a successor.
// Empty stacks are always valid. It does not verify that the Prev/Next
// links point at existing refs.
func validateStackRefs(s Stack) error {
	if s.Empty() {
		// empty stacks are okay.
		return nil
	}

	var starts, ends int
	for _, ref := range s.Refs {
		if ref.IsFirst() {
			starts++
		}
		if ref.IsLast() {
			ends++
		}
	}

	switch {
	case starts > 1 || ends > 1:
		return errors.New("More than one end or start ref detected. Data might be corrupted.")
	case starts != 1:
		return errors.New("expected exactly one start ref. Data might be corrupted.")
	case ends != 1:
		return errors.New("expected exactly one end ref. Data might be corrupted.")
	}
	return nil
}
// CurrentStackRefFromCurrentBranch loads the stack named title and
// returns the ref whose branch matches the branch reported by
// CurrentBranch. Errors from loading the stack, from determining the
// current branch, or from the lookup are returned unchanged together
// with a zero StackRef.
func CurrentStackRefFromCurrentBranch(title string) (StackRef, error) {
	var none StackRef

	stack, err := GatherStackRefs(title)
	if err != nil {
		return none, err
	}

	current, err := CurrentBranch()
	if err != nil {
		return none, err
	}

	return stack.RefFromBranch(current)
}
// RefFromBranch walks the stack in chain order and returns the first ref
// whose Branch equals branch. When no ref matches, a zero StackRef and
// an error naming the branch are returned.
func (s Stack) RefFromBranch(branch string) (StackRef, error) {
	for candidate := range s.Iter() {
		if candidate.Branch != branch {
			continue
		}
		return candidate, nil
	}

	return StackRef{}, errors.New("Could not find stack ref for branch: " + branch)
}
// Empty reports whether the stack ref lacks an associated SHA (commit).
// A ref without a SHA is considered invalid.
func (r StackRef) Empty() bool {
	return len(r.SHA) == 0
}
// IsFirst reports whether the stack ref is the first of the stack,
// i.e. whether it does not reference any previous ref.
func (r StackRef) IsFirst() bool {
	return len(r.Prev) == 0
}
// IsLast reports whether the stack ref is the last of the stack,
// i.e. whether it does not reference any next ref.
func (r StackRef) IsLast() bool {
	return len(r.Next) == 0
}
// Subject returns the stack ref description in a form suitable as a
// commit subject and for other places where space is limited.
//
// Only the first line of the description is used. If that line is longer
// than 72 characters it is cut to 69 characters followed by "...", so the
// result never exceeds 72 characters. Length is measured in Unicode code
// points, so multi-byte characters are never split.
func (r StackRef) Subject() string {
	const (
		maxLen   = 72
		ellipsis = "..."
	)

	// strings.Cut returns the whole string when there is no newline.
	first, _, _ := strings.Cut(r.Description, "\n")

	// The byte length is an upper bound for the character count, so
	// short lines are returned without any conversion.
	if len(first) <= maxLen {
		return first
	}

	runes := []rune(first)
	if len(runes) <= maxLen {
		return first
	}
	return string(runes[:maxLen-len(ellipsis)]) + ellipsis
}