internal/git/pktline/pktline.go (141 lines of code) (raw):

// Package pktline implements utility functions for working with the Git // pkt-line format. See // https://git-scm.com/docs/protocol-common#_pkt_line_format package pktline import ( "bufio" "bytes" "fmt" "io" "strconv" "sync" ) const ( // MaxSidebandData is the maximum number of bytes that fits into one Git // pktline side-band-64k packet. MaxSidebandData = MaxPktSize - 5 // MaxPktSize is the maximum size of content of a Git pktline side-band-64k // packet, including size of length and band number // https://gitlab.com/gitlab-org/git/-/blob/v2.30.0/pkt-line.h#L216 MaxPktSize = 65520 ) // NewScanner returns a bufio.Scanner that splits on Git pktline boundaries func NewScanner(r io.Reader) *bufio.Scanner { scanner := bufio.NewScanner(r) scanner.Buffer(make([]byte, MaxPktSize), MaxPktSize) scanner.Split(pktLineSplitter) return scanner } // Data returns the packet pkt without its length header. The length // header is not validated. Returns an empty slice when pkt is a magic packet such // as '0000'. func Data(pkt []byte) []byte { return pkt[4:] } // Payload returns the pktline's data. It verifies that the length header matches what we expect as // data. func Payload(pkt []byte) ([]byte, error) { if len(pkt) < 4 { return nil, fmt.Errorf("packet too small") } if IsFlush(pkt) { return nil, fmt.Errorf("flush packets do not have a payload") } lengthHeader := string(pkt[:4]) length, err := strconv.ParseUint(lengthHeader, 16, 16) if err != nil { return nil, fmt.Errorf("parsing length header %q: %w", lengthHeader, err) } if uint64(len(pkt)) != length { return nil, fmt.Errorf("packet length %d does not match header length %d", len(pkt), length) } return pkt[4:], nil } // IsFlush detects the special flush packet '0000' func IsFlush(pkt []byte) bool { return bytes.Equal(pkt, PktFlush()) } // WriteString writes a string with pkt-line framing func WriteString(w io.Writer, str string) (int, error) { pktLen := len(str) + 4 if pktLen > MaxPktSize { return 0, fmt.Errorf("string too large: %d bytes", len(str)) } _, err := fmt.Fprintf(w, "%04x%s", pktLen, str) return len(str), err } // WriteFlush writes a pkt flush packet. func WriteFlush(w io.Writer) error { _, err := w.Write(PktFlush()) return err } // WriteDelim writes a pkt delim packet. func WriteDelim(w io.Writer) error { _, err := w.Write(PktDelim()) return err } // PktDone returns the bytes for a "done" packet. func PktDone() []byte { return []byte("0009done\n") } // PktDelim returns the bytes for a "delim" packet. func PktDelim() []byte { return []byte("0001") } // PktFlush returns the bytes for a "flush" packet. func PktFlush() []byte { return []byte("0000") } func pktLineSplitter(data []byte, atEOF bool) (advance int, token []byte, err error) { if len(data) < 4 { if atEOF && len(data) > 0 { return 0, nil, fmt.Errorf("pktLineSplitter: incomplete length prefix on %q", data) } return 0, nil, nil // want more data } // We have at least 4 bytes available so we can decode the 4-hex digit // length prefix of the packet line. pktLength64, err := strconv.ParseInt(string(data[:4]), 16, 0) if err != nil { return 0, nil, fmt.Errorf("pktLineSplitter: decode length: %w", err) } // Cast is safe because we requested an int-size number from strconv.ParseInt pktLength := int(pktLength64) if pktLength < 0 || pktLength > MaxPktSize { return 0, nil, fmt.Errorf("pktLineSplitter: invalid length: %d", pktLength) } if pktLength < 4 { // Special case: magic empty packet 0000, 0001, 0002 or 0003. return 4, data[:4], nil } if len(data) < pktLength { // data contains incomplete packet if atEOF { return 0, nil, io.ErrUnexpectedEOF } return 0, nil, nil // want more data } return pktLength, data[:pktLength], nil } // SidebandWriter multiplexes byte streams into a single side-band-64k stream. type SidebandWriter struct { w io.Writer m sync.Mutex buf [MaxPktSize]byte // Use a buffer to coalesce header and payload into one write syscall } // NewSidebandWriter instantiates a new SidebandWriter. func NewSidebandWriter(w io.Writer) *SidebandWriter { return &SidebandWriter{w: w} } func (sw *SidebandWriter) writeBand(band byte, data []byte) (int, error) { sw.m.Lock() defer sw.m.Unlock() n := 0 for len(data) > 0 { const headerSize = 5 chunkSize := copy(sw.buf[headerSize:], data) header := chunkSize + headerSize copy(sw.buf[:4], fmt.Sprintf("%04x", header)) sw.buf[4] = band if _, err := sw.w.Write(sw.buf[:header]); err != nil { return n, err } data = data[chunkSize:] n += chunkSize } return n, nil } // Writer returns an io.Writer that writes into the multiplexed stream. // Writers for different bands can be used concurrently. func (sw *SidebandWriter) Writer(band byte) io.Writer { return writerFunc(func(p []byte) (int, error) { return sw.writeBand(band, p) }) } type writerFunc func([]byte) (int, error) func (wf writerFunc) Write(p []byte) (int, error) { return wf(p) } type invalidSidebandPacketError struct{ pkt string } func (err *invalidSidebandPacketError) Error() string { return fmt.Sprintf("invalid sideband packet: %q", err.pkt) } // EachSidebandPacket iterates over a side-band-64k pktline stream. For // each packet, it will call fn with the band ID and the packet. Fn must // not retain the packet. func EachSidebandPacket(r io.Reader, fn func(byte, []byte) error) error { scanner := NewScanner(r) for scanner.Scan() { data := Data(scanner.Bytes()) if len(data) == 0 { return &invalidSidebandPacketError{scanner.Text()} } if err := fn(data[0], data[1:]); err != nil { return err } } return scanner.Err() }