src/scip-lib/scanner/scan.go (445 lines of code) (raw):

package scanner

import (
	"bytes"
	"errors"
	"fmt"
	"io"
	"os"
	"path/filepath"
	"runtime"
	"strings"
	"sync"

	"github.com/sourcegraph/scip/bindings/go/scip"
	"google.golang.org/protobuf/encoding/protowire"
	"google.golang.org/protobuf/proto"
)

// IndexScanner defines the ways a SCIP index can be read: from a single file,
// from the index files in a folder, from an arbitrary reader, or one
// serialized Document at a time.
type IndexScanner interface {
	ScanIndexFile(file string) error
	ScanIndexFolder(folder string, parallel bool) error
	ScanIndexReader(reader ScipReader) error
	ScanDocumentReader(relativeDocPath string, reader ScipReader) error
}

// ScipReader defines the minimum interface required for a reader
// to be passed into the IndexScanner. Seek is required because fields the
// scanner is not interested in are skipped by seeking past them instead of
// reading them.
type ScipReader interface {
	io.Reader
	io.Seeker
}

// IndexScannerImpl implements the IndexScanner interface.
// A consumer can instantiate this implementation and set any subset of the
// Match* callbacks, and (optionally) the matching Visit* callback.
// After instantiating the struct, the consumer SHALL call .InitBuffers()
// (ScanIndexReader also initializes the buffers lazily when Pool is nil).
//
// Performance considerations:
//  1. When multiple IndexScannerImpl's are expected to be instantiated at the
//     same time, a shared BufferPool can be provided to the struct's Pool
//     property. This improves performance and memory consumption.
//  2. MatchSymbol and MatchOccurrence receive the symbol moniker as a []byte
//     for performance reasons. That slice is a pooled buffer which is handed
//     back to Pool as soon as the callback returns: do not retain it, copy it
//     if it is needed later.
//  3. A Visit* callback will only be called if the Visit* callback is set AND
//     the matching Match* callback is set and returns true.
//     VisitExternalSymbol is currently never called.
//  4. Should the consumer only require the monikers (symbol IDs) or document
//     paths, recording them inside the Match* callback and returning false is
//     the recommended method: nothing is unmarshalled in that case.
//     MatchDocumentPath is called even when VisitDocument is nil.
//  5. The message passed to a Visit* callback is reused for the next match
//     within the same scan, so copy anything that must outlive the callback.
//  6. Consider the size of the indices: there's an order of magnitude more
//     occurrences than symbols, and more symbols than documents. So if a
//     consumer Matches & Visits every occurrence or document, the performance
//     gains over a traditional proto parser will be limited.
//  7. ScanIndexFolder with parallel=true scans several files concurrently
//     using the same IndexScannerImpl, so the callbacks may be invoked from
//     several goroutines at once.
//     NOTE(review): this also assumes Pool is safe for concurrent use — confirm.
type IndexScannerImpl struct { MatchDocumentPath func(string) bool MatchSymbol func([]byte) bool MatchOccurrence func([]byte) bool // TODO(IDE-1335): Implement MatchExternalSymbol // MatchExternalSymbol func([]byte) bool VisitDocument func(*scip.Document) VisitSymbol func(string, *scip.SymbolInformation) VisitOccurrence func(string, *scip.Occurrence) VisitExternalSymbol func(*scip.SymbolInformation) Pool *BufferPool MaxConcurrency int } var _ IndexScanner = &IndexScannerImpl{} // See https://protobuf.dev/programming-guides/encoding/#varints const maxVarintBytes = 10 const ( indexMetadataFieldNumber = 1 indexDocumentsFieldNumber = 2 indexExternalSymbolsFieldNumber = 3 metadataProjectRootFieldNumber = 3 documentRelativePathFieldNumber = 1 documentOccurrencesFieldNumber = 2 documentSymbolsFieldNumber = 3 documentLanguageFieldNumber = 4 symbolInformationSymbolField = 1 occurrenceSymbolField = 2 ) // InitBuffers initializes the relevant buffers, and pools func (is *IndexScannerImpl) InitBuffers() { if is.Pool == nil { is.Pool = NewBufferPool(1024, 12) } } // ScanIndexReader incrementally processes an index by reading input from the io.Reader // It skips over any other fields defined in the SCIP index. 
func (is *IndexScannerImpl) ScanIndexReader(r ScipReader) error { if is.Pool == nil { is.InitBuffers() } document := scip.Document{} for { fieldNumber, fieldType, err := is.consumeBytesFieldMeta(r) if err == io.EOF { return nil } if err != nil { return errors.Join(fmt.Errorf("failed to consume field"), err) } if fieldNumber != indexDocumentsFieldNumber { err := is.skipField(r) if err != nil { return errors.Join(fmt.Errorf("failed to skip %q", indexFieldName(fieldNumber)), err) } continue } dataLen, err := is.consumeLen(r, fieldType) if err != nil { return errors.Join(fmt.Errorf("failed to consume length"), err) } dataBuf := is.Pool.Get(dataLen) err = is.consumeFieldData(r, dataLen, dataBuf) if err != nil { is.Pool.Put(dataBuf) return errors.Join(fmt.Errorf("failed to consume %q data", indexFieldName(fieldNumber)), err) } switch fieldNumber { case indexDocumentsFieldNumber: docPath, err := is.parseDocumentPath(dataBuf) if err != nil { is.Pool.Put(dataBuf) return errors.Join(fmt.Errorf("failed to parse document path"), err) } // Unmarshall the doc when all conditions match if is.MatchDocumentPath != nil && is.MatchDocumentPath(docPath) && is.VisitDocument != nil { if err := proto.Unmarshal(dataBuf, &document); err != nil { is.Pool.Put(dataBuf) return errors.Join(fmt.Errorf("failed to read document"), err) } is.VisitDocument(&document) } // Individually scna through the document to parse symbols/occurrences docReader := bytes.NewReader(dataBuf) if err := is.ScanDocumentReader(docPath, docReader); err != nil { is.Pool.Put(dataBuf) return errors.Join(fmt.Errorf("failed to scan document"), err) } } is.Pool.Put(dataBuf) } } // ScanDocumentReader scans through an individual document. This requires the path of the // document only to pass into VisitSymbol/Occurrence methods. 
func (is *IndexScannerImpl) ScanDocumentReader(docPath string, r ScipReader) error { symbolInfo := scip.SymbolInformation{} occurrence := scip.Occurrence{} for { fieldNumber, fieldType, err := is.consumeBytesFieldMeta(r) if err == io.EOF { return nil } if err != nil { return errors.Join(fmt.Errorf("failed to consume field"), err) } if !is.shouldParseDocumentField(fieldNumber) { err := is.skipField(r) if err != nil { return errors.Join(fmt.Errorf("failed to skip %q", documentFieldName(fieldNumber)), err) } continue } dataLen, err := is.consumeLen(r, fieldType) if err != nil { return errors.Join(fmt.Errorf("failed to consume length"), err) } dataBuf := is.Pool.Get(dataLen) err = is.consumeFieldData(r, dataLen, dataBuf) if err != nil { is.Pool.Put(dataBuf) return errors.Join(fmt.Errorf("failed to consume %q data", documentFieldName(fieldNumber)), err) } switch fieldNumber { case documentSymbolsFieldNumber: syByte, err := is.parseSymbolMoniker(dataBuf) if err != nil { is.Pool.Put(dataBuf) return errors.Join(fmt.Errorf("failed to parse symbol moniker"), err) } match := is.MatchSymbol(syByte) is.Pool.Put(syByte) if match && is.VisitSymbol != nil { if err := proto.Unmarshal(dataBuf, &symbolInfo); err != nil { is.Pool.Put(dataBuf) return errors.Join(fmt.Errorf("failed to read %q", documentFieldName(fieldNumber)), err) } is.VisitSymbol(docPath, &symbolInfo) } case documentOccurrencesFieldNumber: syByte, err := is.parseOccurrenceSymbol(dataBuf) if err != nil { is.Pool.Put(dataBuf) return errors.Join(fmt.Errorf("failed to parse symbol moniker"), err) } match := is.MatchOccurrence(syByte) is.Pool.Put(syByte) if match && is.VisitOccurrence != nil { if err := proto.Unmarshal(dataBuf, &occurrence); err != nil { is.Pool.Put(dataBuf) return errors.Join(fmt.Errorf("failed to read %q", documentFieldName(fieldNumber)), err) } is.VisitOccurrence(docPath, &occurrence) } } is.Pool.Put(dataBuf) } } // ScanIndexFile scans an individual index file. 
func (is *IndexScannerImpl) ScanIndexFile(path string) error { reader, err := os.Open(path) if err != nil { return errors.Join(fmt.Errorf("failed to open file"), err) } defer reader.Close() return is.ScanIndexReader(reader) } // ScanIndexFolder scans an entire folder's indices. Good to use when the consumer // needs to scan through a bunch of indices for specific symbols. func (is *IndexScannerImpl) ScanIndexFolder(path string, parallel bool) error { entries, err := os.ReadDir(path) if err != nil { return errors.Join(fmt.Errorf("failed to read directory"), err) } var wg sync.WaitGroup errChan := make(chan error, len(entries)) processFile := func(filePath string) { defer wg.Done() // Skip non-SCIP files if filepath.Ext(filePath) != ".scip" { return } if err := is.ScanIndexFile(filePath); err != nil { errChan <- err } } if parallel { maxWorkers := runtime.NumCPU() if is.MaxConcurrency > 0 { maxWorkers = is.MaxConcurrency } sem := make(chan struct{}, maxWorkers) for _, entry := range entries { if entry.IsDir() { continue } wg.Add(1) fullPath := filepath.Join(path, entry.Name()) sem <- struct{}{} go func(path string) { processFile(path) <-sem }(fullPath) } } else { for _, entry := range entries { if entry.IsDir() { continue } wg.Add(1) fullPath := filepath.Join(path, entry.Name()) processFile(fullPath) } } wg.Wait() close(errChan) var errs []error for err := range errChan { errs = append(errs, err) } if len(errs) > 0 { errStrings := make([]string, len(errs)) for i, err := range errs { errStrings[i] = err.Error() } return errors.New(strings.Join(errStrings, "\n")) } return nil } // parseDocumentPath reads the relative path of the current document from the blob func (is *IndexScannerImpl) parseDocumentPath(docData []byte) (string, error) { r := bytes.NewReader(docData) for { fieldNumber, fieldType, err := is.consumeBytesFieldMeta(r) if err == io.EOF { return "", nil } if err != nil { return "", errors.Join(fmt.Errorf("failed to consume field"), err) } if fieldNumber != 
documentRelativePathFieldNumber { err := is.skipField(r) if err != nil { return "", errors.Join(fmt.Errorf("failed to skip %q", documentFieldName(fieldNumber)), err) } continue } dataLen, err := is.consumeLen(r, fieldType) if err != nil { return "", errors.Join(fmt.Errorf("failed to consume length"), err) } dataBuf := is.Pool.Get(dataLen) defer is.Pool.Put(dataBuf) err = is.consumeFieldData(r, dataLen, dataBuf) if err != nil { return "", errors.Join(fmt.Errorf("failed to consume %q data", documentFieldName(fieldNumber)), err) } // Return immediately after reading the document path dataStr := string(dataBuf) return dataStr, nil } } // parseSymbolMoniker reads the symbol moniker from a symbolData blob // dataBuf is directly returned from pool and should be returned by consumer. func (is *IndexScannerImpl) parseSymbolMoniker(symbolData []byte) ([]byte, error) { r := bytes.NewReader(symbolData) for { fieldNumber, fieldType, err := is.consumeBytesFieldMeta(r) if err == io.EOF { return nil, nil } if err != nil { return nil, errors.Join(fmt.Errorf("failed to consume field"), err) } if fieldNumber != symbolInformationSymbolField { err := is.skipField(r) if err != nil { return nil, errors.Join(fmt.Errorf("failed to skip field %d in SymbolInfo", fieldNumber), err) } continue } dataLen, err := is.consumeLen(r, fieldType) if err != nil { return nil, errors.Join(fmt.Errorf("failed to consume length"), err) } dataBuf := is.Pool.Get(dataLen) err = is.consumeFieldData(r, dataLen, dataBuf) if err != nil { return nil, errors.Join(fmt.Errorf("failed to consume field %d data in SymbolInfo", fieldNumber), err) } // Return immediately after reading the symbol string return dataBuf, nil } } // parseOccurrenceSymbol reads the symbol moniker from an occurrence blob // dataBuf is directly returned from pool and should be returned by consumer. 
func (is *IndexScannerImpl) parseOccurrenceSymbol(occData []byte) ([]byte, error) { r := bytes.NewReader(occData) for { fieldNumber, fieldType, err := is.consumeBytesFieldMeta(r) if err == io.EOF { return nil, nil } if err != nil { return nil, errors.Join(fmt.Errorf("failed to consume field"), err) } if fieldNumber != occurrenceSymbolField { err := is.skipField(r) if err != nil { return nil, errors.Join(fmt.Errorf("failed to skip field %d in Occurrence", fieldNumber), err) } continue } dataLen, err := is.consumeLen(r, fieldType) if err != nil { return nil, errors.Join(fmt.Errorf("failed to consume length"), err) } dataBuf := is.Pool.Get(dataLen) err = is.consumeFieldData(r, dataLen, dataBuf) if err != nil { return nil, errors.Join(fmt.Errorf("failed to consume field %d data in Occurrence", fieldNumber), err) } // Return immediately after reading the symbol string return dataBuf, nil } } // shouldParseDocumentField returns true if a field in a document should be parsed func (is *IndexScannerImpl) shouldParseDocumentField(fieldNumber protowire.Number) bool { switch fieldNumber { case documentSymbolsFieldNumber: return is.MatchSymbol != nil case documentOccurrencesFieldNumber: return is.MatchOccurrence != nil } return false } // consumeBytesFieldMeta reads a field from the reader and returns the field number and type. 
func (is *IndexScannerImpl) consumeBytesFieldMeta(r ScipReader) (protowire.Number, protowire.Type, error) { tagBuf := is.Pool.Get(1) defer is.Pool.Put(tagBuf) numRead, err := r.Read(tagBuf) if err == io.EOF { return 0, 0, io.EOF } if err != nil { return 0, 0, errors.Join(fmt.Errorf("failed to read from index reader"), err) } if numRead == 0 { return 0, 0, errors.New("read 0 bytes from index") } fieldNumber, fieldType, errCode := protowire.ConsumeTag(tagBuf) if errCode < 0 { return 0, 0, errors.Join(fmt.Errorf("failed to consume tag"), protowire.ParseError(errCode)) } if fieldType != protowire.BytesType { return 0, 0, fmt.Errorf("expected bytes type for field %d", fieldNumber) } return fieldNumber, fieldType, nil } // skipField skips over a field in the reader func (is *IndexScannerImpl) skipField(r ScipReader) error { lenBuf := is.Pool.Get(maxVarintBytes)[:0] dataLen, err := readVarint(r, &lenBuf) is.Pool.Put(lenBuf) if err != nil { return errors.Join(fmt.Errorf("failed to read length"), err) } _, err = r.Seek(int64(dataLen), io.SeekCurrent) if err != nil { return errors.Join(fmt.Errorf("failed to skip field"), err) } return nil } // consumeLen reads the length of a field from the reader func (is *IndexScannerImpl) consumeLen(r ScipReader, fieldType protowire.Type) (int, error) { if fieldType != protowire.BytesType { return 0, errors.New("expected LEN type tag") } lenBuf := is.Pool.Get(maxVarintBytes)[:0] dataLen, err := readVarint(r, &lenBuf) is.Pool.Put(lenBuf) if err != nil { return 0, errors.Join(fmt.Errorf("failed to read length"), err) } return int(dataLen), err } // readAndParseField reads a field from the reader and parses it. 
func (is *IndexScannerImpl) consumeFieldData(r ScipReader, dataLen int, dataBuf []byte) error { if dataLen > 0 { numRead, err := r.Read(dataBuf) if err != nil { return errors.Join(fmt.Errorf("failed to read data"), err) } if numRead != dataLen { return fmt.Errorf("expected to read %d bytes based on LEN but read %d bytes", dataLen, numRead) } } return nil } // readVarint attempts to read a varint, using scratchBuf for temporary storage // scratchBuf should be able to accommodate any varint size // based on its capacity, and be cleared before readVarint is called. // Varints < 128 fit in 1 byte, which means 4 bits are available for field // numbers. The SCIP types have less than 15 fields, so the tag will fit in 1 byte. func readVarint(r ScipReader, scratchBuf *[]byte) (uint64, error) { nextByteBuf := make([]byte, 1, 1) for i := 0; i < cap(*scratchBuf); i++ { numRead, err := r.Read(nextByteBuf) if err != nil { return 0, errors.Join(fmt.Errorf("failed to read %d-th byte of Varint", i), err) } if numRead == 0 { return 0, fmt.Errorf("failed to read %d-th byte of Varint", i) } nextByte := nextByteBuf[0] *scratchBuf = append(*scratchBuf, nextByte) if nextByte <= 127 { // https://protobuf.dev/programming-guides/encoding/#varints // Continuation bit is not set, so Varint must've ended break } } value, errCode := protowire.ConsumeVarint(*scratchBuf) if errCode < 0 { return value, protowire.ParseError(errCode) } return value, nil } func indexFieldName(i protowire.Number) string { if i == indexMetadataFieldNumber { return "metadata" } else if i == indexDocumentsFieldNumber { return "documents" } else if i == indexExternalSymbolsFieldNumber { return "external_symbols" } return "<unknown>" } func documentFieldName(i protowire.Number) string { if i == documentRelativePathFieldNumber { return "relative_path" } else if i == documentOccurrencesFieldNumber { return "occurrences" } else if i == documentSymbolsFieldNumber { return "symbols" } else if i == documentLanguageFieldNumber { 
return "language" } return "<unknown>" } func metadataFieldName(i protowire.Number) string { if i == metadataProjectRootFieldNumber { return "project_root" } return "<unknown>" }