mito.go (509 lines of code) (raw):
// Licensed to Elasticsearch B.V. under one or more contributor
// license agreements. See the NOTICE file distributed with
// this work for additional information regarding copyright
// ownership. Elasticsearch B.V. licenses this file to you under
// the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
// Package mito provides the logic for a main function and test infrastructure
// for a CEL-based message stream processor.
//
// The majority of the logic resides in the lib package.
package mito
import (
"compress/gzip"
"context"
"crypto/tls"
"encoding/json"
"errors"
"flag"
"fmt"
"io"
"net/http"
"net/url"
"os"
"reflect"
"regexp"
runtimedebug "runtime/debug"
"strings"
"github.com/goccy/go-yaml"
"github.com/google/cel-go/cel"
"github.com/google/cel-go/checker/decls"
"github.com/google/cel-go/interpreter"
"golang.org/x/oauth2"
"golang.org/x/oauth2/clientcredentials"
"golang.org/x/oauth2/endpoints"
"golang.org/x/oauth2/google"
"google.golang.org/protobuf/encoding/protojson"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/types/known/structpb"
"github.com/elastic/mito/internal/httplog"
"github.com/elastic/mito/internal/rc"
"github.com/elastic/mito/lib"
)
// root is the name under which the input data is exposed to the CEL program.
const (
	root = "state"
)
// Main runs the mito command. It parses the command line flags, builds the
// CEL environment from the selected libraries and the optional run control
// configuration, and evaluates the CEL program named by the single
// positional argument. Each result is printed to stdout. Evaluation is
// repeated, with the previous result as the new input, while the result is
// a map holding a true want_more field and the -max_executions bound has
// not been reached.
//
// The returned value is the process exit code: 0 on success, 1 for internal
// errors (coverage handling), and 2 for usage, configuration, input or
// evaluation errors.
func Main() int {
	flag.Usage = func() {
		fmt.Fprintf(flag.CommandLine.Output(), `Usage of %s:
%[1]s [opts] <src.cel>
`, os.Args[0])
		flag.PrintDefaults()
	}
	use := flag.String("use", "all", "libraries to use")
	data := flag.String("data", "", "path to a JSON object holding input (exposed as the label "+root+")")
	maxExecutions := flag.Int("max_executions", -1, "maximum number of evaluations, or no maximum if -1")
	cfgPath := flag.String("cfg", "", "path to a YAML file holding run control configuration (see pkg.go.dev/github.com/elastic/mito/cmd/mito)")
	insecure := flag.Bool("insecure", false, "disable TLS verification in the HTTP client")
	logTrace := flag.Bool("log_requests", false, "log request traces to stderr (go1.21+)")
	maxTraceBody := flag.Int("max_log_body", 1000, "maximum length of body logged in request traces (go1.21+)")
	fold := flag.Bool("fold", false, "apply constant folding optimisation")
	dumpState := flag.String("dump", "", "dump eval state ('always' or 'error')")
	coverage := flag.String("coverage", "", "file to write an execution coverage report to (prefix if multiple executions are run)")
	version := flag.Bool("version", false, "print version and exit")
	flag.Parse()
	if *version {
		return printVersion()
	}
	if len(flag.Args()) != 1 {
		flag.Usage()
		return 2
	}

	libs := []cel.EnvOption{
		cel.OptionalTypes(cel.OptionalTypesVersion(lib.OptionalTypesVersion)),
	}
	if *cfgPath != "" {
		f, err := os.Open(*cfgPath)
		if err != nil {
			fmt.Fprintln(os.Stderr, err)
			return 2
		}
		defer f.Close()
		dec := yaml.NewDecoder(f)
		var cfg Config
		err = dec.Decode(&cfg)
		if err != nil {
			fmt.Fprintln(os.Stderr, err)
			return 2
		}
		if len(cfg.Globals) != 0 {
			libs = append(libs, lib.Globals(cfg.Globals))
		}
		if len(cfg.Regexps) != 0 {
			regexps := make(map[string]*regexp.Regexp)
			for name, expr := range cfg.Regexps {
				re, err := regexp.Compile(expr)
				if err != nil {
					fmt.Fprintln(os.Stderr, err)
					return 2
				}
				regexps[name] = re
			}
			libs = append(libs, lib.Regexp(regexps))
		}
		if len(cfg.XSDs) != 0 {
			xsds := make(map[string]string)
			for name, path := range cfg.XSDs {
				b, err := os.ReadFile(path)
				if err != nil {
					fmt.Fprintln(os.Stderr, err)
					return 2
				}
				xsds[name] = string(b)
			}
			xml, err := lib.XML(nil, xsds)
			if err != nil {
				fmt.Fprintln(os.Stderr, err)
				return 2
			}
			libMap["xml"] = xml
		}
		var client *http.Client
		httpOptions := lib.HTTPOptions{
			Headers:     cfg.HTTPHeaders,
			MaxBodySize: cfg.MaxBodySize,
		}
		if cfg.Auth != nil {
			switch auth := cfg.Auth; {
			case auth.Basic != nil && auth.OAuth2 != nil:
				fmt.Fprintln(os.Stderr, "configured basic authentication and OAuth2")
				return 2
			case auth.Basic != nil:
				httpOptions.BasicAuth = auth.Basic
			case auth.OAuth2 != nil:
				client, err = oAuth2Client(*auth.OAuth2)
				if err != nil {
					fmt.Fprintln(os.Stderr, err)
					return 2
				}
			}
		}
		if client != nil || !httpOptions.IsZero() {
			ctx := context.Background()
			libMap["http"] = lib.HTTPWithContextOpts(ctx, traceReqs(setClientInsecure(client, *insecure), *logTrace, *maxTraceBody), httpOptions)
		}
		// The command line flag takes precedence over the configuration.
		if *maxExecutions == -1 && cfg.MaxExecutions != nil {
			*maxExecutions = *cfg.MaxExecutions
		}
	}
	// Fill in default http and xml libraries when the configuration did
	// not provide specialised ones.
	if libMap["http"] == nil {
		libMap["http"] = lib.HTTP(traceReqs(setClientInsecure(nil, *insecure), *logTrace, *maxTraceBody), nil, nil)
	}
	if libMap["xml"] == nil {
		xml, err := lib.XML(nil, nil)
		if err != nil {
			// Report the failure rather than exiting silently.
			fmt.Fprintln(os.Stderr, err)
			return 2
		}
		libMap["xml"] = xml
	}
	if *use == "all" {
		for _, l := range libMap {
			libs = append(libs, l)
		}
	} else {
		for _, u := range strings.Split(*use, ",") {
			l, ok := libMap[u]
			if !ok {
				fmt.Fprintf(os.Stderr, "no lib %q\n", u)
				return 2
			}
			libs = append(libs, l)
		}
	}

	b, err := os.ReadFile(flag.Args()[0])
	if err != nil {
		fmt.Fprintln(os.Stderr, err)
		return 2
	}

	var input interface{}
	if *data != "" {
		b, err := os.ReadFile(*data)
		if err != nil {
			fmt.Fprintln(os.Stderr, err)
			return 2
		}
		err = json.Unmarshal(b, &input)
		if err != nil {
			fmt.Fprintln(os.Stderr, err)
			return 2
		}
		input = map[string]interface{}{root: input}
	}

	var cov lib.Coverage
	for n := 0; *maxExecutions < 0 || n < *maxExecutions; n++ {
		res, val, dump, c, err := eval(string(b), root, input, *fold, *dumpState != "", *coverage != "", libs...)
		if err := cov.Merge(c); err != nil {
			fmt.Fprintf(os.Stderr, "internal error merging coverage: %v\n", err)
			return 1
		}
		if *dumpState == "always" {
			fmt.Fprint(os.Stderr, dump)
		}
		if err != nil {
			if *dumpState == "error" {
				fmt.Fprint(os.Stderr, dump)
			}
			fmt.Fprintln(os.Stderr, err)
			return 2
		}
		fmt.Println(res)

		// Check if we want more. This can happen when we have a map
		// and the map has a true boolean field, want_more.
		state, ok := val.(map[string]any)
		if !ok {
			break
		}
		if more, _ := state["want_more"].(bool); !more {
			break
		}
		// Feed the result back under the same label as the initial input.
		input = map[string]any{root: val}
	}

	if *coverage != "" {
		f, err := os.Create(*coverage)
		if err != nil {
			fmt.Fprintf(os.Stderr, "internal error opening coverage file: %v\n", err)
			return 1
		}
		_, err = f.WriteString(cov.String() + "\n")
		// Sync is best-effort: it may fail for non-regular destinations
		// such as pipes or terminals, which are still usable here.
		_ = f.Sync()
		// A failed Close can mean the report was not fully written.
		if cerr := f.Close(); err == nil {
			err = cerr
		}
		if err != nil {
			fmt.Fprintf(os.Stderr, "internal error writing coverage file: %v\n", err)
			return 1
		}
	}
	return 0
}
// printVersion writes the main module version of the running binary to
// stdout, followed by the VCS revision and its modification state when the
// build recorded a revision. It returns the process exit code: 1 when no
// build information is available (reported on stderr), 0 otherwise.
func printVersion() int {
	info, ok := runtimedebug.ReadBuildInfo()
	if !ok {
		fmt.Fprintln(os.Stderr, "no build info")
		return 1
	}

	var rev, mod string
	for _, setting := range info.Settings {
		if setting.Key == "vcs.revision" {
			rev = setting.Value
		} else if setting.Key == "vcs.modified" {
			mod = setting.Value
		}
	}

	ver := info.Main.Version
	switch {
	case rev == "":
		fmt.Println(ver)
	case mod == "true":
		fmt.Println(ver, rev, "(modified)")
	case mod == "false":
		fmt.Println(ver, rev)
	default:
		// Unexpected modification state; print it verbatim.
		fmt.Println(ver, rev, mod)
	}
	return 0
}
// setClientInsecure returns an http.Client that will skip TLS certificate
// verification when insecure is true. If c is nil and insecure is true
// http.DefaultClient and http.DefaultTransport are used and will be mutated.
//
// Any TLS configuration already present on the transport is kept; only
// InsecureSkipVerify is changed. If the client's transport is not an
// *http.Transport (for example a wrapping round tripper), c is returned
// unchanged and certificate verification is NOT disabled.
func setClientInsecure(c *http.Client, insecure bool) *http.Client {
	if !insecure {
		return c
	}
	if c == nil {
		c = http.DefaultClient
	}
	if c.Transport == nil {
		c.Transport = http.DefaultTransport
	}
	t, ok := c.Transport.(*http.Transport)
	if !ok {
		return c
	}
	// Clone rather than replace so that existing settings such as RootCAs,
	// client certificates and protocol version bounds are retained.
	cfg := t.TLSClientConfig.Clone() // nil when no config was set
	if cfg == nil {
		cfg = &tls.Config{}
	}
	cfg.InsecureSkipVerify = true
	t.TLSClientConfig = cfg
	return c
}
// traceReqs returns c with its transport wrapped in an httplog logging round
// tripper when trace is true; maxBody bounds the length of logged bodies.
// When trace is false c is returned as is. A nil c with trace enabled means
// http.DefaultClient is used and mutated, and a nil transport is replaced by
// a wrapped http.DefaultTransport.
func traceReqs(c *http.Client, trace bool, maxBody int) *http.Client {
	if trace {
		if c == nil {
			c = http.DefaultClient
		}
		base := c.Transport
		if base == nil {
			base = http.DefaultTransport
		}
		c.Transport = httplog.NewLoggingRoundTripper(base, maxBody)
	}
	return c
}
var (
libMap = map[string]cel.EnvOption{
"collections": lib.Collections(),
"crypto": lib.Crypto(),
"json": lib.JSON(nil),
"time": lib.Time(),
"try": lib.Try(),
"debug": lib.Debug(debug),
"file": lib.File(mimetypes),
"mime": lib.MIME(mimetypes),
"http": nil, // This will be populated by Main.
"limit": lib.Limit(limitPolicies),
"strings": lib.Strings(),
"printf": lib.Printf(),
"xml": nil, // This will be populated by Main.
}
mimetypes = map[string]interface{}{
"text/rot13": func(r io.Reader) io.Reader { return rot13{r} },
"text/upper": toUpper,
"application/gzip": func(r io.Reader) (io.Reader, error) { return gzip.NewReader(r) },
"text/csv; header=present": lib.CSVHeader,
"text/csv; header=absent": lib.CSVNoHeader,
"application/x-ndjson": lib.NDJSON,
"application/zip": lib.Zip,
}
limitPolicies = map[string]lib.LimitPolicy{
"okta": lib.OktaRateLimit,
"draft": lib.DraftRateLimit,
}
)
// debug is the logging callback handed to lib.Debug. It writes the tag and
// value to stderr, at ERROR level when value is an error and at DEBUG level
// otherwise.
func debug(tag string, value any) {
	var level string
	switch value.(type) {
	case error:
		level = "ERROR"
	default:
		level = "DEBUG"
	}
	fmt.Fprintf(os.Stderr, "%s: logging %q: %v\n", level, tag, value)
}
// eval compiles src with the given libraries and runs it once against input.
// It returns the JSON rendering of the result, the native result value, an
// evaluation state dump (only when details is true, otherwise nil), the
// coverage collector built by compile, and any error. Compilation errors
// are reported with a "failed program instantiation" prefix and no other
// results.
func eval(src, root string, input interface{}, fold, details, coverage bool, libs ...cel.EnvOption) (string, any, *lib.Dump, *lib.Coverage, error) {
	prg, ast, cov, err := compile(src, root, fold, details, coverage, libs...)
	if err != nil {
		return "", nil, nil, nil, fmt.Errorf("failed program instantiation: %v", err)
	}
	res, val, det, runErr := run(prg, ast, false, input)
	if !details {
		return res, val, nil, cov, runErr
	}
	// The dump is built even when evaluation failed so it can be reported.
	return res, val, lib.NewDump(ast, det), cov, runErr
}
// compile builds a CEL environment that declares root as a dynamically typed
// variable and includes libs, then compiles src into a program.
//
// When fold is true the checked AST is passed through the constant folding
// optimizer before the program is built. When coverage is true a coverage
// collector is attached to the program and returned; otherwise the returned
// *lib.Coverage is nil. When details is true the program is built with
// cel.OptTrackState so evaluation state is recorded in the eval details.
//
// Errors are wrapped with %w so callers can inspect the underlying cause.
func compile(src, root string, fold, details, coverage bool, libs ...cel.EnvOption) (cel.Program, *cel.Ast, *lib.Coverage, error) {
	opts := append([]cel.EnvOption{
		cel.Declarations(decls.NewVar(root, decls.Dyn)),
	}, libs...)
	env, err := cel.NewEnv(opts...)
	if err != nil {
		return nil, nil, nil, fmt.Errorf("failed to create env: %w", err)
	}

	ast, iss := env.Compile(src)
	if err := iss.Err(); err != nil {
		return nil, nil, nil, fmt.Errorf("failed compilation: %w", err)
	}

	if fold {
		folder, err := cel.NewConstantFoldingOptimizer()
		if err != nil {
			return nil, nil, nil, fmt.Errorf("failed folding optimization: %w", err)
		}
		ast, iss = cel.NewStaticOptimizer(folder).Optimize(env, ast)
		if err := iss.Err(); err != nil {
			return nil, nil, nil, fmt.Errorf("failed optimization: %w", err)
		}
	}

	var cov *lib.Coverage
	var progOpts []cel.ProgramOption
	if coverage {
		cov = lib.NewCoverage(ast)
		progOpts = []cel.ProgramOption{cov.ProgramOption()}
	}
	if details {
		progOpts = append(progOpts, cel.EvalOptions(cel.OptTrackState))
	}
	prg, err := env.Program(ast, progOpts...)
	if err != nil {
		return nil, nil, nil, fmt.Errorf("failed program instantiation: %w", err)
	}
	return prg, ast, cov, nil
}
// run evaluates prg against input, using an empty activation when input is
// nil. It converts the result to a native Go value through structpb.Value.
//
// It returns the JSON rendering of the result, the native value, the
// evaluation details and any error. With fast the rendering is produced by
// protojson; otherwise encoding/json is used with tab indentation, no HTML
// escaping and trailing newlines trimmed. Evaluation errors are wrapped in
// lib.DecoratedError together with ast before being reported.
func run(prg cel.Program, ast *cel.Ast, fast bool, input interface{}) (string, any, *cel.EvalDetails, error) {
	activation := input
	if activation == nil {
		activation = interpreter.EmptyActivation()
	}
	out, det, err := prg.Eval(activation)
	if err != nil {
		return "", nil, det, fmt.Errorf("failed eval: %v", lib.DecoratedError{AST: ast, Err: err})
	}

	native, err := out.ConvertToNative(reflect.TypeOf(&structpb.Value{}))
	if err != nil {
		return "", nil, det, fmt.Errorf("failed proto conversion: %v", err)
	}
	msg := native.(*structpb.Value)
	val := msg.AsInterface()

	if !fast {
		var sb strings.Builder
		enc := json.NewEncoder(&sb)
		enc.SetEscapeHTML(false)
		enc.SetIndent("", "\t")
		encErr := enc.Encode(val)
		return strings.TrimRight(sb.String(), "\n"), val, det, encErr
	}

	b, err := protojson.MarshalOptions{}.Marshal(msg)
	if err != nil {
		return "", nil, det, fmt.Errorf("failed native conversion: %v", err)
	}
	return string(b), val, det, nil
}
// rot13 is provided for testing purposes. It is an io.Reader that applies
// the ROT13 substitution to the ASCII letters read from the wrapped reader,
// preserving case; all other bytes pass through unchanged.
type rot13 struct {
	r io.Reader
}

// Read reads from the wrapped reader into p and rotates each ASCII letter
// among the n bytes read by 13 places. It returns the wrapped reader's n
// and error unchanged.
func (r rot13) Read(p []byte) (int, error) {
	n, err := r.r.Read(p)
	buf := p[:n]
	for i := range buf {
		c := buf[i]
		switch {
		case c >= 'a' && c <= 'z':
			buf[i] = 'a' + (c-'a'+13)%26
		case c >= 'A' && c <= 'Z':
			buf[i] = 'A' + (c-'A'+13)%26
		}
	}
	return n, err
}
// toUpper converts the ASCII lowercase letters in p to uppercase in place,
// leaving every other byte untouched.
func toUpper(p []byte) {
	for i := range p {
		if c := p[i]; c >= 'a' && c <= 'z' {
			p[i] = c - ('a' - 'A')
		}
	}
}
// Config is the run control configuration decoded from the -cfg YAML file.
type Config = rc.Config

// AuthConfig is the HTTP authentication configuration type from the rc
// package.
type AuthConfig = rc.AuthConfig

// OAuth2 is the OAuth2 client configuration consumed by oAuth2Client.
type OAuth2 = rc.OAuth2Config
// oAuth2Client returns an *http.Client whose requests are authenticated
// according to cfg. The flow is selected by cfg.Provider (case-insensitive):
//
//   - "" with User or Password set: resource owner password credentials
//     against cfg.TokenURL.
//   - "" otherwise, or "azure": client credentials against cfg.TokenURL.
//     For "azure" an empty TokenURL defaults to the Azure AD endpoint for
//     cfg.AzureTenantID, and a non-empty AzureResource is sent as the
//     "resource" endpoint parameter.
//   - "google": Google application default credentials, or a JWT config
//     when GoogleJWTFile or GoogleJWTJSON is set. Application default
//     credentials must be discoverable in either case.
//
// Any other provider is an error. Token requests are made with a separate
// zero-value http.Client installed in the context.
func oAuth2Client(cfg OAuth2) (*http.Client, error) {
	ctx := context.WithValue(context.Background(), oauth2.HTTPClient, &http.Client{})
	switch prov := strings.ToLower(cfg.Provider); prov {
	case "":
		if cfg.User != "" || cfg.Password != "" {
			var clientSecret string
			if cfg.ClientSecret != nil {
				clientSecret = *cfg.ClientSecret
			}
			oauth2cfg := &oauth2.Config{
				ClientID:     cfg.ClientID,
				ClientSecret: clientSecret,
				Endpoint: oauth2.Endpoint{
					TokenURL:  cfg.TokenURL,
					AuthStyle: oauth2.AuthStyleAutoDetect,
				},
			}
			token, err := oauth2cfg.PasswordCredentialsToken(ctx, cfg.User, cfg.Password)
			if err != nil {
				return nil, fmt.Errorf("oauth2: error loading credentials using user and password: %w", err)
			}
			return oauth2cfg.Client(ctx, token), nil
		}
		fallthrough
	case "azure":
		// Use the configured token URL; only the azure provider has a
		// default to fall back to when none is configured.
		tokenURL := cfg.TokenURL
		if prov == "azure" {
			if tokenURL == "" {
				tokenURL = endpoints.AzureAD(cfg.AzureTenantID).TokenURL
			}
			if cfg.AzureResource != "" {
				if cfg.EndpointParams == nil {
					cfg.EndpointParams = make(url.Values)
				}
				cfg.EndpointParams.Set("resource", cfg.AzureResource)
			}
		}
		var clientSecret string
		if cfg.ClientSecret != nil {
			clientSecret = *cfg.ClientSecret
		}
		return (&clientcredentials.Config{
			ClientID:       cfg.ClientID,
			ClientSecret:   clientSecret,
			TokenURL:       tokenURL,
			Scopes:         cfg.Scopes,
			EndpointParams: cfg.EndpointParams,
		}).Client(ctx), nil
	case "google":
		creds, err := google.FindDefaultCredentials(ctx, cfg.Scopes...)
		if err != nil {
			return nil, fmt.Errorf("oauth2: error loading default credentials: %w", err)
		}
		cfg.GoogleCredentialsJSON = string(creds.JSON)
		if cfg.GoogleJWTFile != "" {
			b, err := os.ReadFile(cfg.GoogleJWTFile)
			if err != nil {
				return nil, fmt.Errorf("oauth2: error reading google jwt file: %w", err)
			}
			cfg.GoogleJWTJSON = string(b)
		}
		if cfg.GoogleJWTJSON != "" {
			if !json.Valid([]byte(cfg.GoogleJWTJSON)) {
				// Do not echo the document: it is expected to hold
				// service account key material.
				return nil, errors.New("invalid google jwt: malformed JSON")
			}
			googCfg, err := google.JWTConfigFromJSON([]byte(cfg.GoogleJWTJSON), cfg.Scopes...)
			if err != nil {
				return nil, fmt.Errorf("oauth2: error loading jwt credentials: %w", err)
			}
			googCfg.Subject = cfg.GoogleDelegatedAccount
			return googCfg.Client(ctx), nil
		}
		creds, err = google.CredentialsFromJSON(ctx, []byte(cfg.GoogleCredentialsJSON), cfg.Scopes...)
		if err != nil {
			return nil, fmt.Errorf("oauth2: error loading credentials: %w", err)
		}
		return oauth2.NewClient(ctx, creds.TokenSource), nil
	default:
		return nil, errors.New("oauth2: unknown provider")
	}
}