internal/command/root.go (128 lines of code) (raw):

// Licensed to Elasticsearch B.V. under one or more agreements. // Elasticsearch B.V. licenses this file to you under the Apache 2.0 License. // See the LICENSE file in the project root for more information. package command import ( "context" "fmt" "os" "os/signal" "strings" "time" "github.com/spf13/cobra" "github.com/spf13/pflag" "go.uber.org/multierr" "go.uber.org/zap" "golang.org/x/sys/unix" "github.com/elastic/go-concert/ctxtool/osctx" "github.com/elastic/go-concert/timed" "github.com/elastic/stream/internal/httpserver" "github.com/elastic/stream/internal/log" "github.com/elastic/stream/internal/output" // Register outputs. _ "github.com/elastic/stream/internal/output/azureblobstorage" _ "github.com/elastic/stream/internal/output/azureeventhub" _ "github.com/elastic/stream/internal/output/gcppubsub" _ "github.com/elastic/stream/internal/output/gcs" _ "github.com/elastic/stream/internal/output/kafka" _ "github.com/elastic/stream/internal/output/lumberjack" _ "github.com/elastic/stream/internal/output/tcp" _ "github.com/elastic/stream/internal/output/tls" _ "github.com/elastic/stream/internal/output/udp" _ "github.com/elastic/stream/internal/output/webhook" ) func Execute() error { c := make(chan os.Signal, 1) signal.Notify(c, os.Interrupt) ctx, cancel := context.WithCancel(context.Background()) go func() { <-c cancel() }() return ExecuteContext(ctx) } func ExecuteContext(ctx context.Context) error { logger, err := log.NewLogger() if err != nil { return nil } rootCmd := &cobra.Command{Use: "stream", SilenceUsage: true} // Global flags. var opts output.Options rootCmd.PersistentFlags().StringVar(&opts.Addr, "addr", "", "destination address") rootCmd.PersistentFlags().DurationVar(&opts.Delay, "delay", 0, "delay start after start-signal") rootCmd.PersistentFlags().StringVarP(&opts.Protocol, "protocol", "p", "tcp", "protocol ("+strings.Join(output.Available(), "/")+")") rootCmd.PersistentFlags().IntVar(&opts.Retries, "retry", 10, "connection retry attempts for tcp based protocols") rootCmd.PersistentFlags().StringVarP(&opts.StartSignal, "start-signal", "s", "", "wait for start signal") rootCmd.PersistentFlags().BoolVar(&opts.InsecureTLS, "insecure", false, "disable tls verification") rootCmd.PersistentFlags().IntVar(&opts.RateLimit, "rate-limit", 500*1024, "bytes per second rate limit for UDP output") rootCmd.PersistentFlags().IntVar(&opts.MaxLogLineSize, "max-log-line-size", 500*1024, "max size of a single log line in bytes") // Webhook output flags. rootCmd.PersistentFlags().StringVar(&opts.WebhookOptions.ContentType, "webhook-content-type", "application/json", "webhook Content-Type") rootCmd.PersistentFlags().StringArrayVar(&opts.WebhookOptions.Headers, "webhook-header", nil, "webhook header to add to request (e.g. Header=Value)") rootCmd.PersistentFlags().StringVar(&opts.WebhookOptions.Password, "webhook-password", "", "webhook password for basic authentication") rootCmd.PersistentFlags().StringVar(&opts.WebhookOptions.Username, "webhook-username", "", "webhook username for basic authentication") rootCmd.PersistentFlags().DurationVar(&opts.WebhookOptions.Timeout, "webhook-timeout", time.Second, "webhook request timeout (zero is no timeout)") // GCP Pubsub output flags. rootCmd.PersistentFlags().StringVar(&opts.GCPPubsubOptions.Project, "gcppubsub-project", "test", "GCP Pubsub project name") rootCmd.PersistentFlags().StringVar(&opts.GCPPubsubOptions.Topic, "gcppubsub-topic", "topic", "GCP Pubsub topic name") rootCmd.PersistentFlags().StringVar(&opts.GCPPubsubOptions.Subscription, "gcppubsub-subscription", "subscription", "GCP Pubsub subscription name") rootCmd.PersistentFlags().BoolVar(&opts.GCPPubsubOptions.Clear, "gcppubsub-clear", true, "GCP Pubsub clear flag") // Azure BlobStorage output flags. rootCmd.PersistentFlags().StringVar(&opts.AzureBlobStorageOptions.Container, "azure-blob-storage-container", "testcontainer", "Azure Blob Storage container name") rootCmd.PersistentFlags().StringVar(&opts.AzureBlobStorageOptions.Blob, "azure-blob-storage-blob", "testblob", "Azure Blob Storage blob name") rootCmd.PersistentFlags().StringVar(&opts.AzureBlobStorageOptions.Port, "azure-blob-storage-port", "10000", "HTTP port used to connect to the blob storage, used for emulators and CI") // Azure EventHub output flags. rootCmd.PersistentFlags().StringVar(&opts.AzureEventHubOptions.FullyQualifiedNamespace, "azure-event-hub-namespace", "myeventhub.servicebus.windows.net", "Azure Eventhub namespace") rootCmd.PersistentFlags().StringVar(&opts.AzureEventHubOptions.EventHubName, "azure-event-hub-name", "test-eventhub-seis", "Azure Eventhub name") rootCmd.PersistentFlags().StringVar(&opts.AzureEventHubOptions.ConnectionString, "azure-event-hub-connection-string", "connectionstring", "Azure Eventhub connection string") // Kafka Pubsub output flags. rootCmd.PersistentFlags().StringVar(&opts.KafkaOptions.Topic, "kafka-topic", "test", "Kafka topic name") // GCS output flags. rootCmd.PersistentFlags().StringVar(&opts.GcsOptions.Bucket, "gcs-bucket", "testbucket", "GCS Bucket name") rootCmd.PersistentFlags().StringVar(&opts.GcsOptions.Object, "gcs-object", "testobject", "GCS Object name") rootCmd.PersistentFlags().StringVar(&opts.GcsOptions.ObjectContentType, "gcs-content-type", "application/json", "The Content type of the object to be uploaded to GCS.") rootCmd.PersistentFlags().StringVar(&opts.GcsOptions.ProjectID, "gcs-projectid", "testproject", "GCS Project name") // Lumberjack output flags. rootCmd.PersistentFlags().BoolVar(&opts.LumberjackOptions.ParseJSON, "lumberjack-parse-json", false, "Parse the input data as JSON and send the structured data as a Lumberjack batch.") // Sub-commands. rootCmd.AddCommand(newLogRunner(&opts, logger)) rootCmd.AddCommand(newPCAPRunner(&opts, logger)) httpOpts := httpserver.Options{Options: &opts} httpCommand := newHTTPServerRunner(&httpOpts, logger) httpCommand.PersistentFlags().DurationVar(&httpOpts.ReadTimeout, "read-timeout", 5*time.Second, "HTTP Server read timeout") httpCommand.PersistentFlags().DurationVar(&httpOpts.WriteTimeout, "write-timeout", 5*time.Second, "HTTP Server write timeout") httpCommand.PersistentFlags().StringVar(&httpOpts.TLSCertificate, "tls-cert", "", "Path to the TLS certificate") httpCommand.PersistentFlags().StringVar(&httpOpts.TLSKey, "tls-key", "", "Path to the TLS key file") httpCommand.PersistentFlags().StringVar(&httpOpts.ConfigPath, "config", "", "Path to the config file") httpCommand.PersistentFlags().BoolVar(&httpOpts.ExitOnUnmatchedRule, "exit-on-unmatched-rule", false, "If set to true it will exit if no rule matches a request") rootCmd.AddCommand(httpCommand) rootCmd.AddCommand(versionCmd) // Add common start-up delay logic. rootCmd.PersistentPreRunE = func(cmd *cobra.Command, args []string) error { return multierr.Combine( waitForStartSignal(&opts, cmd.Context(), logger), waitForDelay(&opts, cmd.Context(), logger), ) } // Automatically set flags based on environment variables. rootCmd.PersistentFlags().VisitAll(setFlagFromEnv) return rootCmd.ExecuteContext(ctx) } func waitForStartSignal(opts *output.Options, parent context.Context, logger *zap.Logger) error { if opts.StartSignal == "" { return nil } num := unix.SignalNum(opts.StartSignal) if num == 0 { return fmt.Errorf("unknown signal %v", opts.StartSignal) } // Wait for the signal or the command context to be done. logger.Sugar().Infow("Waiting for signal.", "start-signal", opts.StartSignal) startCtx, _ := osctx.WithSignal(parent, os.Signal(num)) <-startCtx.Done() return nil } func waitForDelay(opts *output.Options, parent context.Context, logger *zap.Logger) error { if opts.Delay <= 0 { return nil } logger.Sugar().Infow("Delaying connection.", "delay", opts.Delay) if err := timed.Wait(parent, opts.Delay); err != nil { return fmt.Errorf("delay waiting period was interrupted: %w", err) } return nil } func setFlagFromEnv(flag *pflag.Flag) { envVar := strings.ToUpper(flag.Name) envVar = strings.ReplaceAll(envVar, "-", "_") envVar = "STREAM_" + envVar flag.Usage = fmt.Sprintf("%v [env %v]", flag.Usage, envVar) if value := os.Getenv(envVar); value != "" { flag.Value.Set(value) } }