main.go (338 lines of code) (raw):

// Licensed to the Apache Software Foundation (ASF) under one // or more contributor license agreements. See the NOTICE file // distributed with this work for additional information // regarding copyright ownership. The ASF licenses this file // to you under the Apache License, Version 2.0 (the // "License"); you may not use this file except in compliance // with the License. You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, // software distributed under the License is distributed on an // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY // KIND, either express or implied. See the License for the // specific language governing permissions and limitations // under the License. package main import ( "flag" "fmt" "math" "os" "regexp" "strconv" "strings" "sync" "time" "doris-streamloader/loader" file "doris-streamloader/reader" "doris-streamloader/report" "doris-streamloader/utils" log "github.com/sirupsen/logrus" ) const ( fileBufferSize = 16 * 1024 * 1024 // 16MB bufferSize = 1 * 1024 * 1024 // 1MB defaultQueueSize = 100 defaultTimeout = 60 * 60 * 10 defaultBatchRows = 4096 defaultBatchBytes = 943718400 defaultMaxBytesPerTask = 9663676416 defaultReportDuration = 1 defaultMaxRetryTimes = 3 defaultRetryInterval = 60 defaultDiskThroughput = 800 // 800MB/s defaultStreamLoadThroughput = 100 // 100MB/s ) var ( sourceFilePaths string url string dbName string tableName string userName string password string compress bool header string headers map[string]string timeout int batchRows int batchBytes int maxRowsPerTask int maxBytesPerTask int debug bool workers int enableConcurrency bool diskThroughput int streamLoadThroughput int checkUTF8 bool reportDuration int retry string maxRetryTimes int retryInterval int retryInfo map[int]int showVersion bool queueSize int lineDelimiter byte = '\n' bufferPool = sync.Pool{ New: func() interface{} { return make([]byte, 0, bufferSize) }, } loadInfo *loader.LoadInfo loadResp *loader.Resp startTime time.Time ) func initFlags() { flag.StringVar(&sourceFilePaths, "source_file", "", "source file paths") flag.StringVar(&url, "url", "", "url") flag.StringVar(&dbName, "db", "", "db name") flag.StringVar(&tableName, "table", "", "table name") flag.StringVar(&userName, "u", "root", "username") flag.StringVar(&password, "p", "", "password") flag.StringVar(&header, "header", "", "header") flag.BoolVar(&compress, "compress", false, "compress") flag.IntVar(&timeout, "timeout", defaultTimeout, "connect/read/write timeout seconds for rw and wait fe/be header") // 10h flag.IntVar(&batchRows, "batch", defaultBatchRows, "batch row size") flag.IntVar(&batchBytes, "batch_byte", defaultBatchBytes, "batch byte size") flag.IntVar(&maxBytesPerTask, "max_byte_per_task", defaultMaxBytesPerTask, "max byte per task") // 100G flag.IntVar(&workers, "workers", 0, "workers") flag.IntVar(&diskThroughput, "disk_throughput", defaultDiskThroughput, "disk throughput") flag.IntVar(&streamLoadThroughput, "streamload_throughput", defaultStreamLoadThroughput, "estimate streamload throughput") flag.BoolVar(&checkUTF8, "check_utf8", true, "check utf8") flag.IntVar(&reportDuration, "report_duration", defaultReportDuration, "report duration") // 1s flag.StringVar(&retry, "auto_retry", "", "retry failure") flag.IntVar(&maxRetryTimes, "auto_retry_times", defaultMaxRetryTimes, "retry failure") flag.IntVar(&retryInterval, "auto_retry_interval", defaultRetryInterval, "retry failure") flag.BoolVar(&debug, "debug", false, "enable debug") flag.BoolVar(&showVersion, "version", false, "Display the version") flag.IntVar(&queueSize, "queue_size", defaultQueueSize, "memory queue size") flag.Parse() paramCheck() loadInfo = &loader.LoadInfo{ SourceFilePaths: sourceFilePaths, Url: url, DbName: dbName, TableName: tableName, UserName: userName, Password: password, Compress: compress, Headers: headers, Timeout: timeout, BatchRows: batchRows, BatchBytes: batchBytes, MaxBytesPerTask: maxBytesPerTask, Debug: debug, Workers: workers, DiskThroughput: diskThroughput, StreamLoadThroughput: streamLoadThroughput, CheckUTF8: checkUTF8, ReportDuration: reportDuration, NeedRetry: true, RetryTimes: maxRetryTimes, RetryInterval: retryInterval, } loadResp = &loader.Resp{ Status: "Success", TotalRows: 0, FailLoadRows: 0, LoadedRows: 0, FilteredRows: 0, UnselectedRows: 0, LoadBytes: 0, LoadTimeMs: 0, LoadFiles: []string{}, } logLevel := "info" // debug print flags if debug { logLevel = "debug" fmt.Println("source_file: ", sourceFilePaths) fmt.Println("url: ", url) fmt.Println("db: ", dbName) fmt.Println("table: ", tableName) fmt.Println("username: ", userName) fmt.Println("password: ", password) // print headers for k, v := range headers { fmt.Printf("header: %s:%s\n", k, v) } fmt.Println("compress: ", compress) fmt.Println("timeout: ", timeout) fmt.Println("batch_row_size: ", batchRows) fmt.Println("batch_byte_size: ", batchBytes) fmt.Println("max_rows_per_task: ", maxRowsPerTask) fmt.Println("max_bytes_per_task: ", maxBytesPerTask) fmt.Println("debug: ", debug) fmt.Println("workers: ", workers) fmt.Println("check_utf8: ", checkUTF8) fmt.Println("report_duration: ", reportDuration) fmt.Println("retry_info: ", retry) fmt.Println("retry_times: ", maxRetryTimes) fmt.Println("retry_interval: ", retryInterval) fmt.Println("queue_size: ", queueSize) } utils.InitLog(logLevel) } // Restore hex escape sequences like \xNN to their corresponding characters func restoreHexEscapes(s1 string) (string, error) { if s1 == `\n` { return "\n", nil } re := regexp.MustCompile(`\\x([0-9A-Fa-f]{2})`) return re.ReplaceAllStringFunc(s1, func(match string) string { hexValue := match[2:] // Remove the \x prefix decValue, err := strconv.ParseInt(hexValue, 16, 0) if err != nil { return match } return string(rune(decValue)) }), nil } //go:generate go run gen_version.go func paramCheck() { if showVersion { commitHash := GitCommit version := Version fmt.Printf("version %s, git commit %s\n", version, commitHash) os.Exit(0) } // split retry "a,b;c,d" into {a:b; c:d} retryInfo = make(map[int]int) if retry != "" { for _, v := range strings.Split(retry, ";") { kv := strings.Split(v, ",") workerIndex, err := strconv.ParseInt(kv[0], 20, 64) if err != nil { log.Errorf("bad retry info, err %v", err) os.Exit(1) } taskIndex, err := strconv.ParseInt(kv[1], 20, 64) if err != nil { log.Errorf("bad retry info, err %v", err) os.Exit(1) } retryInfo[int(workerIndex)] = int(taskIndex) } } // check url if url == "" { log.Errorf("url is empty") os.Exit(1) } // check source file path if sourceFilePaths == "" { log.Errorf("source file path is empty") os.Exit(1) } if dbName == "" { log.Errorf("db name is empty") os.Exit(1) } if tableName == "" { log.Errorf("table name is empty") os.Exit(1) } // split header "a:b?c:d" into {a:b, c:d} enableConcurrency = true if header != "" { headers = make(map[string]string) for _, v := range strings.Split(header, "?") { if v == "" { continue } kv := strings.Split(v, ":") if strings.ToLower(kv[0]) == "format" && strings.ToLower(kv[1]) != "csv" { enableConcurrency = false } if strings.ToLower(kv[0]) == "line_delimiter" { restored, err := restoreHexEscapes(kv[1]) if err != nil || len(restored) != 1 { log.Errorf("line_delimiter invalid: %s", kv[1]) os.Exit(1) } else { lineDelimiter = restored[0] } } if len(kv) > 2 { headers[kv[0]] = strings.Join(kv[1:], ":") } else { headers[kv[0]] = kv[1] } } } if timeout <= 0 { log.Warnf("timeout invalid: %d, replace with default value: %d", timeout, defaultTimeout) timeout = defaultTimeout } if batchRows <= 0 { log.Warnf("batchRows invalid: %d, replace with default value: %d", batchRows, defaultBatchRows) batchRows = defaultBatchRows } if batchBytes <= 0 { log.Warnf("batchBytes invalid: %d, replace with default value: %d", batchBytes, defaultBatchBytes) batchBytes = defaultBatchBytes } if reportDuration <= 0 { log.Warnf("reportDuration invalid: %d, replace with default value: %d", reportDuration, defaultReportDuration) reportDuration = defaultReportDuration } if maxBytesPerTask <= 0 { log.Warnf("maxBytesPerTask invalid: %d, replace with default value: %d", maxBytesPerTask, defaultMaxBytesPerTask) maxBytesPerTask = defaultMaxBytesPerTask } if maxRetryTimes < 0 { log.Warnf("maxRetryTimes invalid: %d, replace with default value: %d", maxRetryTimes, defaultMaxRetryTimes) maxRetryTimes = defaultMaxRetryTimes } if retryInterval < 0 { log.Warnf("retryInterval invalid: %d, replace with default value: %d", retryInterval, defaultRetryInterval) retryInterval = defaultRetryInterval } if queueSize <= 0 { log.Warnf("queueSize invalid: %d, replace with default value: %d", queueSize, defaultQueueSize) queueSize = defaultQueueSize } } func calculateAndCheckWorkers(reader *file.FileReader, size int64) { if workers > 0 { return } if !enableConcurrency { loadInfo.Workers = 1 workers = 1 return } ratio := float64(size) / float64(maxBytesPerTask) tmpWorkers := 0 if ratio > 0.0 && ratio <= 0.001 { tmpWorkers = 1 } else if ratio > 0.001 && ratio <= 0.01 { tmpWorkers = 2 } else if ratio > 0.01 && ratio <= 0.1 { tmpWorkers = 4 } else { tmpWorkers = 8 } tmpWorkers = int(math.Min(float64(tmpWorkers), float64(diskThroughput)/float64(streamLoadThroughput))) log.Infof("worker number is %d, which is <= 0, trigger automatic inference. Final worker number is %d", workers, tmpWorkers) workers = tmpWorkers loadInfo.Workers = workers } func createStreamLoad(report *report.Reporter, queues []chan []byte) *loader.StreamLoad { // create StreamLoadOption && StreamLoad from flags streamLoadOption := loader.StreamLoadOption{ Compress: compress, CheckUTF8: checkUTF8, Timeout: timeout, } return loader.NewStreamLoad(url, dbName, tableName, userName, password, headers, queues, &bufferPool, streamLoadOption, report, loadResp) } func createQueues(queues *[]chan []byte) { for i := 0; i < workers; i++ { *queues = append(*queues, make(chan []byte, queueSize)) } } func main() { initFlags() retryCount := 0 for { // create queue by worker size var queues []chan []byte // create file reader fileSize := int64(0) reader := file.NewFileReader(sourceFilePaths, batchRows, batchBytes, fileBufferSize, &queues, &bufferPool, &fileSize) calculateAndCheckWorkers(reader, fileSize) createQueues(&queues) reporter := report.NewReporter(reportDuration, fileSize, uint64(workers)) streamLoad := createStreamLoad(reporter, queues) if retryCount == 0 { startTime = time.Now() } maxRowsPerTask = math.MinInt32 streamLoad.Load(workers, maxRowsPerTask, maxBytesPerTask, &retryInfo) reporter.Report() defer reporter.CloseWait() reader.Read(reporter, workers, maxBytesPerTask, &retryInfo, loadResp, retryCount, lineDelimiter) reader.Close() streamLoad.Wait(loadInfo, retryCount, &retryInfo, startTime) if !loadInfo.NeedRetry || retryCount >= loadInfo.RetryTimes-1 { break } time.Sleep(time.Duration(loadInfo.RetryInterval) * time.Second) retryCount++ } }