tools/lambda-compat/main.go (464 lines of code) (raw):
package lambdacompat
/*
Copyright 2022 Google LLC
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
import (
"bufio"
"bytes"
"context"
"encoding/json"
"fmt"
"io/ioutil"
"net/http"
"os"
"os/exec"
"path/filepath"
"strings"
"sync"
"text/template"
"time"
"github.com/google/uuid"
"github.com/rs/zerolog"
"github.com/rs/zerolog/log"
aws "github.com/aws/aws-sdk-go-v2/aws"
awsconfig "github.com/aws/aws-sdk-go-v2/config"
awssts "github.com/aws/aws-sdk-go-v2/service/sts"
awsststypes "github.com/aws/aws-sdk-go-v2/service/sts/types"
"google.golang.org/api/idtoken"
)
// LambdaCompat is the minimal contract of the compatibility layer: something
// that can be started and reports an error when it stops.
type LambdaCompat interface {
// Start runs the layer and returns the error that made it stop.
Start() error
}
// LambdaCompatCommand describes one child process launched by the layer
// (see Run).
type LambdaCompatCommand struct {
// Context is handed to exec.CommandContext when the command is run.
Context context.Context
// Command is the executable to run.
Command string
// Args are the arguments passed to Command.
Args []string
// Environment holds extra "KEY=value" entries appended to the parent
// process environment.
Environment []string
}
// LambdaResponse is what the runtime reported for a single invocation.
type LambdaResponse struct {
// Result is the raw body the runtime posted to the response or error
// endpoint.
Result []byte
// Error is non-nil when the runtime posted to the error endpoint or when
// its body could not be read.
Error error
}
// LambdaRequest is a single invocation travelling from the public HTTP
// handler to the emulated Lambda runtime API and back.
type LambdaRequest struct {
// RequestId is the identifier sent to the runtime in the
// Lambda-Runtime-Aws-Request-Id header (a UUID string).
RequestId string
// Body is the event payload delivered to the runtime.
Body []byte
// Result receives the runtime's answer for this invocation.
Result chan LambdaResponse
}
// LambdaCompatServer holds the state shared by the public request listener,
// the emulated Lambda runtime API and the supervised runtime process.
type LambdaCompatServer struct {
// port is where incoming requests are served; the runtime API listens on
// port+1.
port int
// Command is the command line of the runtime process: element 0 is the
// executable, the rest are its arguments.
Command []string
// commands is overwritten by startCommand with the command being run.
commands []LambdaCompatCommand
// requestChan queues invocations until the runtime asks for the next one.
requestChan chan *LambdaRequest
// requests indexes in-flight invocations by RequestId so that a posted
// response or error can be matched to its request.
requests map[string]*LambdaRequest
// refreshChan releases a runtime waiting for the next invocation just
// before the process is killed for a credential refresh.
refreshChan chan bool
// Region, ProjectNumber and Service are combined into the function ARN
// reported to the runtime; Region is also exported as AWS_REGION.
Region string
ProjectNumber string
Service string
// Audience is the ID token audience; when empty no AWS credentials are
// requested.
Audience string
// RoleArn is the AWS role assumed with the ID token; when empty the
// credentials are never considered expired.
RoleArn string
// JsonTransform, when non-nil, is executed on every incoming request to
// produce the payload given to the runtime.
JsonTransform *template.Template
// Processing is held while a request is being handled and is taken by the
// credential watcher before it kills the runtime process.
Processing sync.Mutex
}
// LambdaRunHandler is the http.Handler for incoming requests that are to be
// turned into Lambda invocations.
type LambdaRunHandler struct {
// server is the shared compatibility-layer state.
server *LambdaCompatServer
}
// LambdaCompatHandler is the http.Handler for the emulated runtime invocation
// endpoints (next, response and error).
type LambdaCompatHandler struct {
// server is the shared compatibility-layer state.
server *LambdaCompatServer
}
// LambdaErrorHandler is the http.Handler for the emulated runtime
// initialization-error endpoint.
type LambdaErrorHandler struct {
// server is the shared compatibility-layer state.
server *LambdaCompatServer
}
// LambdaRestartError is returned by LambdaCompatCommand.Run when the child
// process was stopped on purpose and should simply be launched again.
type LambdaRestartError struct{}

// Error implements the error interface with a fixed message.
func (*LambdaRestartError) Error() string {
	const msg = "process needs to be restarted"
	return msg
}
// awsToken caches the most recent temporary AWS credentials obtained through
// STS; it stays nil until a refresh has succeeded.
var awsToken *awsststypes.Credentials = nil
// getHttpRequest builds a zerolog dictionary describing r for use under an
// "httpRequest" log key. Method, URL and remote address are always present;
// user agent and referer are added only when non-empty, and status only when
// it is greater than zero.
func getHttpRequest(r *http.Request, status int) *zerolog.Event {
	entry := zerolog.Dict()
	entry = entry.Str("requestMethod", r.Method)
	entry = entry.Str("requestUrl", r.URL.String())
	entry = entry.Str("remoteIp", r.RemoteAddr)

	if agent := r.UserAgent(); agent != "" {
		entry = entry.Str("userAgent", agent)
	}
	if referer := r.Referer(); referer != "" {
		entry = entry.Str("referer", referer)
	}
	if status > 0 {
		entry = entry.Int("status", status)
	}
	return entry
}
// marshalJson renders data as a JSON string. It is exposed to transformation
// templates as "ToJson"; values that cannot be encoded yield an empty string.
func marshalJson(data interface{}) string {
	encoded, err := json.Marshal(data)
	if err != nil {
		return ""
	}
	return string(encoded)
}
// ServeHTTP turns an incoming HTTP request into a Lambda invocation and
// relays the runtime's answer to the caller.
//
// The request body (optionally rewritten by the server's JSON transformation
// template) is queued for the runtime; the handler then blocks until the
// runtime posts a response or an error. A successful invocation is answered
// with the runtime's payload, a failed one with status 500 and the error
// payload. Requests are handled one at a time: the server's Processing lock
// is held for the whole call.
func (h LambdaRunHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	h.server.Processing.Lock()
	defer h.server.Processing.Unlock()

	body, err := ioutil.ReadAll(r.Body)
	if err != nil {
		log.Error().Err(err).Dict("httpRequest", getHttpRequest(r, http.StatusBadRequest)).Msg("Error reading body")
		http.Error(w, "can't read body", http.StatusBadRequest)
		return
	}

	if h.server.JsonTransform != nil {
		var jsonBody interface{}
		if err := json.Unmarshal(body, &jsonBody); err != nil {
			log.Error().Dict("httpRequest", getHttpRequest(r, http.StatusBadRequest)).Str("body", string(body)).Err(err).Msg("Failed to unmarshal JSON body for transformation")
			http.Error(w, "failed to unmarshal JSON body", http.StatusBadRequest)
			return
		}
		templateVars := map[string]interface{}{
			"Body":       jsonBody,
			"URL":        r.URL,
			"Method":     r.Method,
			"RemoteAddr": r.RemoteAddr,
			"Headers":    r.Header,
		}
		var buf bytes.Buffer
		if err := h.server.JsonTransform.Execute(&buf, templateVars); err != nil {
			log.Error().Dict("httpRequest", getHttpRequest(r, http.StatusBadRequest)).Str("body", string(body)).Err(err).Msg("JSON transformation template failed")
			http.Error(w, "failed to render transformed JSON body", http.StatusBadRequest)
			// Stop here: the error response has been written and buf may
			// hold a partially rendered payload that must not be invoked.
			return
		}
		body = buf.Bytes()
	}

	req := LambdaRequest{
		RequestId: uuid.New().String(),
		Body:      body,
		Result:    make(chan LambdaResponse),
	}
	h.server.requests[req.RequestId] = &req
	// Forget the request once it has been answered so the map does not grow
	// without bound and a late duplicate response is rejected instead of
	// blocking on a channel nobody reads. Runs before the Processing unlock.
	defer delete(h.server.requests, req.RequestId)

	log.Info().Str("spanId", req.RequestId).Dict("httpRequest", getHttpRequest(r, 0)).Msg("Starting to process new request")
	h.server.requestChan <- &req
	output := <-req.Result

	if output.Error != nil {
		log.Error().RawJSON("error", output.Result).Str("spanId", req.RequestId).Dict("httpRequest", getHttpRequest(r, http.StatusInternalServerError)).Msg("Request failed")
		w.WriteHeader(http.StatusInternalServerError)
		w.Write(output.Result)
		return
	}
	log.Info().Str("spanId", req.RequestId).Dict("httpRequest", getHttpRequest(r, http.StatusOK)).Msg("Request processed successfully")
	w.Write(output.Result)
}
// ServeHTTP emulates the Lambda runtime invocation endpoints. The action is
// taken from the last path segment:
//
//   - "next" blocks until an invocation is queued and returns its payload
//     with the Lambda-Runtime-* headers. It returns with an empty reply when
//     a credential refresh is signalled, and gives up when the runtime's
//     connection goes away.
//   - "response" and "error" read the posted body and deliver it to the
//     request whose id is the preceding path segment, answering 202.
//
// Anything else, including an unknown request id, is answered with 400.
func (h LambdaCompatHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	path := strings.Split(r.URL.Path, "/")
	action := path[len(path)-1]

	if action == "next" {
		select {
		case req := <-h.server.requestChan:
			log.Info().Str("spanId", req.RequestId).Dict("httpRequest", getHttpRequest(r, http.StatusOK).Int("requestSize", len(req.Body))).Msg("Sending next invocation")
			w.Header().Set("Lambda-Runtime-Aws-Request-Id", req.RequestId)
			arn := fmt.Sprintf("arn:aws:lambda:%s:%s:function:%s", h.server.Region, h.server.ProjectNumber, h.server.Service)
			w.Header().Set("Lambda-Runtime-Invoked-Function-Arn", arn)
			// The advertised deadline is fixed at five minutes from now.
			deadlineMs := (1000 * 300) + time.Now().UnixNano()/int64(time.Millisecond)
			w.Header().Set("Lambda-Runtime-Deadline-Ms", fmt.Sprintf("%d", deadlineMs))
			w.Header().Set("Content-Type", "application/json")
			if len(req.Body) > 0 {
				w.Write(req.Body)
			} else { // Matches SAM CLI behaviour
				w.Write([]byte("{}"))
			}
			return
		case <-h.server.refreshChan:
			// The runtime is about to be restarted; release its long poll.
			return
		case <-r.Context().Done():
			// The runtime went away while waiting. Without this case the
			// handler would stay parked and could later swallow an
			// invocation meant for a restarted runtime.
			return
		}
	}

	if (action == "response" || action == "error") && len(path) >= 2 {
		requestId := path[len(path)-2]
		if req, ok := h.server.requests[requestId]; ok {
			body, err := ioutil.ReadAll(r.Body)
			if err != nil {
				log.Error().Err(err).Dict("httpRequest", getHttpRequest(r, http.StatusBadRequest)).Msg("Error reading body")
				req.Result <- LambdaResponse{
					Result: body,
					Error:  err,
				}
				http.Error(w, "can't read body", http.StatusBadRequest)
				return
			}

			var invocationErr error
			if action == "error" {
				invocationErr = fmt.Errorf("invocation returned error")
			}
			req.Result <- LambdaResponse{
				Result: body,
				Error:  invocationErr,
			}
			w.WriteHeader(http.StatusAccepted)
			return
		}
	}
	w.WriteHeader(http.StatusBadRequest)
}
// ServeHTTP receives a runtime initialization error report, logs the posted
// body together with the Lambda-Runtime-Function-Error-Type header and
// acknowledges it with 202. An unreadable body is answered with 400.
func (h LambdaErrorHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	payload, readErr := ioutil.ReadAll(r.Body)
	if readErr != nil {
		log.Error().Err(readErr).Dict("httpRequest", getHttpRequest(r, http.StatusBadRequest)).Msg("Error reading body")
		http.Error(w, "can't read body", http.StatusBadRequest)
		return
	}

	errorType := r.Header.Get("Lambda-Runtime-Function-Error-Type")
	log.Error().
		RawJSON("error", payload).
		Str("errorType", errorType).
		Msg("Runtime initialization failed")
	w.WriteHeader(http.StatusAccepted)
}
// createRunServer builds the HTTP server that accepts incoming requests on
// the given port and routes every path to a LambdaRunHandler.
//
// The receiver is a pointer so that the handler refers to the caller's server
// rather than to a private copy: a value receiver would duplicate the
// Processing mutex, leaving the handler with a lock nobody else shares.
func (s *LambdaCompatServer) createRunServer(port int) *http.Server {
	mux := http.NewServeMux()
	mux.Handle("/", LambdaRunHandler{server: s})
	return &http.Server{
		Addr:    fmt.Sprintf(":%d", port),
		Handler: mux,
	}
}
// createCompatServer builds the HTTP server that emulates the Lambda runtime
// API (version 2018-06-01) on the given port: the invocation endpoints are
// served by LambdaCompatHandler and the init error endpoint by
// LambdaErrorHandler.
//
// The receiver is a pointer so that the handlers refer to the caller's server
// rather than to a private copy of it (and of its Processing mutex).
func (s *LambdaCompatServer) createCompatServer(port int) *http.Server {
	mux := http.NewServeMux()
	mux.Handle("/2018-06-01/runtime/invocation/", LambdaCompatHandler{server: s})
	mux.Handle("/2018-06-01/runtime/init/error", LambdaErrorHandler{server: s})
	return &http.Server{
		Addr:    fmt.Sprintf(":%d", port),
		Handler: mux,
	}
}
// isTokenExpired reports whether fresh AWS credentials are needed.
//
// Without a role ARN no credentials are managed, so the answer is always
// false. Otherwise credentials count as expired when none have been obtained
// yet, when they carry no expiration time, or when they expire within the
// next five minutes. The margin ensures the runtime is restarted with new
// credentials before the old ones stop working, not after.
func (s *LambdaCompatServer) isTokenExpired() bool {
	if s.RoleArn == "" {
		return false
	}
	if awsToken == nil || awsToken.Expiration == nil {
		return true
	}
	const refreshMargin = 5 * time.Minute
	return time.Now().Add(refreshMargin).After(*awsToken.Expiration)
}
// refreshIdToken makes sure the cached AWS credentials are usable.
//
// When the credentials are missing or about to expire it obtains a Google ID
// token for audience aud, exchanges it through STS
// AssumeRoleWithWebIdentity for temporary credentials of s.RoleArn and stores
// them in awsToken. The role session name is derived from the token's email
// claim. It returns an error when no role ARN is configured, when any of the
// token or STS calls fail, or when the ID token carries no email claim.
func (s *LambdaCompatServer) refreshIdToken(ctx context.Context, aud string) error {
	if s.RoleArn == "" {
		return fmt.Errorf("no role ARN defined (set AWS_ROLE_ARN environment variable)")
	}
	if !s.isTokenExpired() {
		return nil
	}

	log.Info().Str("audience", s.Audience).Str("roleArn", s.RoleArn).Msg("Getting AWS session token")
	ts, err := idtoken.NewTokenSource(ctx, aud)
	if err != nil {
		return err
	}
	tok, err := ts.Token()
	if err != nil {
		return err
	}
	// The token itself is a bearer credential and is deliberately kept out
	// of the logs.
	log.Info().Str("tokenType", tok.TokenType).Str("expiry", tok.Expiry.String()).Msg("Token")

	// Validation is mainly performed to retrieve token details
	payload, err := idtoken.Validate(ctx, tok.AccessToken, aud)
	if err != nil {
		return err
	}
	log.Debug().Str("issuer", payload.Issuer).
		Str("audience", payload.Audience).
		Int64("expires", payload.Expires).
		Int64("issuedAt", payload.IssuedAt).
		Str("Subject", payload.Subject).
		Fields(map[string]interface{}{"claims": payload.Claims}).
		Msg("ID token")

	configs := []func(*awsconfig.LoadOptions) error{
		awsconfig.WithRegion(s.Region),
	}
	if e := log.Debug(); e.Enabled() {
		configs = append(configs, awsconfig.WithClientLogMode(aws.LogRetries|aws.LogRequestWithBody))
	}
	cfg, err := awsconfig.LoadDefaultConfig(ctx, configs...)
	if err != nil {
		return err
	}
	stsSvc := awssts.NewFromConfig(cfg)

	email, ok := payload.Claims["email"].(string)
	if !ok || email == "" {
		return fmt.Errorf("ID token has no email claim to derive the role session name from")
	}
	sessionName := strings.Replace(email, ".iam.gserviceaccount.com", "", 1)
	// Role session names are limited to 64 characters.
	if len(sessionName) > 64 {
		sessionName = sessionName[:64]
	}

	output, err := stsSvc.AssumeRoleWithWebIdentity(ctx, &awssts.AssumeRoleWithWebIdentityInput{
		RoleArn:          aws.String(s.RoleArn),
		RoleSessionName:  aws.String(sessionName),
		DurationSeconds:  aws.Int32(3600),
		WebIdentityToken: &tok.AccessToken,
	})
	if err != nil {
		return err
	}
	awsToken = output.Credentials
	return nil
}
// NewLambdaCompatServer assembles a server for the given runtime command
// line, listening port, ARN components (region, projectNum, service), ID
// token audience and AWS role ARN.
//
// jsonTransform is an optional path to a text/template file applied to every
// incoming request body; the template may call ToJson to encode a value. The
// function panics when that file cannot be parsed.
func NewLambdaCompatServer(command []string, port int, region string, projectNum string, service string, audience string, roleArn string, jsonTransform string) *LambdaCompatServer {
	var transform *template.Template
	if jsonTransform != "" {
		funcs := template.FuncMap{"ToJson": marshalJson}
		name := filepath.Base(jsonTransform)
		parsed, err := template.New(name).Funcs(funcs).ParseFiles(jsonTransform)
		if err != nil {
			panic(err)
		}
		transform = parsed
	}

	server := new(LambdaCompatServer)
	server.Command = command
	server.port = port
	server.requestChan = make(chan *LambdaRequest, 100)
	server.requests = make(map[string]*LambdaRequest, 0)
	server.refreshChan = make(chan bool)
	server.Region = region
	server.ProjectNumber = projectNum
	server.Service = service
	server.Audience = audience
	server.RoleArn = roleArn
	server.JsonTransform = transform
	return server
}
// Run starts the command, forwards its stdout and stderr line by line to the
// parent's stdout and stderr, and waits for it to exit.
//
// While the command runs, the server's AWS credentials are checked every five
// seconds. Once they are reported expired the watcher takes the server's
// Processing lock (so no request is in flight), releases a runtime waiting
// for the next invocation through refreshChan and kills the process. In that
// case Run unlocks Processing and returns *LambdaRestartError so the caller
// can launch the command again; otherwise it returns the command's exit
// error (nil on a clean exit).
func (c LambdaCompatCommand) Run(s *LambdaCompatServer) error {
	cmd := exec.CommandContext(c.Context, c.Command, c.Args...)
	// os.Environ returns a fresh slice, so appending to it is safe.
	cmd.Env = append(os.Environ(), c.Environment...)

	stdout, err := cmd.StdoutPipe()
	if err != nil {
		return fmt.Errorf("error getting stdout pipe: %w", err)
	}
	stderr, err := cmd.StderrPipe()
	if err != nil {
		return fmt.Errorf("error getting stderr pipe: %w", err)
	}
	if err := cmd.Start(); err != nil {
		return err
	}

	// Forward both output streams. cmd.Wait closes the pipes, so it must not
	// be called until the readers have drained them (or the process has been
	// killed on purpose).
	var readers sync.WaitGroup
	forward := func(lines *bufio.Scanner, dst *os.File) {
		defer readers.Done()
		// Allow lines of up to 1 MiB instead of the scanner's 64 KiB default.
		lines.Buffer(make([]byte, 0, 64*1024), 1024*1024)
		for lines.Scan() {
			fmt.Fprintf(dst, "%s\n", lines.Text())
		}
	}
	readers.Add(2)
	go forward(bufio.NewScanner(stdout), os.Stdout)
	go forward(bufio.NewScanner(stderr), os.Stderr)
	drained := make(chan struct{})
	go func() {
		readers.Wait()
		close(drained)
	}()

	// mu guards the hand-over between the watcher and this goroutine:
	// killed is set only while the watcher holds s.Processing, and never
	// after exited has been set.
	var (
		mu     sync.Mutex
		exited bool
		killed bool
	)
	stop := make(chan struct{})       // closed when the process has exited
	killSignal := make(chan struct{}) // closed right before the process is killed
	ticker := time.NewTicker(5 * time.Second)
	defer ticker.Stop()

	go func() {
		for {
			select {
			case <-stop:
				return
			case <-ticker.C:
			}
			if !s.isTokenExpired() {
				continue
			}
			log.Info().Msg("AWS session token expired, refreshing and restarting command")
			s.Processing.Lock()
			mu.Lock()
			if exited {
				// The process is already gone; nothing to kill.
				mu.Unlock()
				s.Processing.Unlock()
				return
			}
			killed = true
			mu.Unlock()
			select {
			case s.refreshChan <- true:
			case <-stop:
			}
			close(killSignal)
			// A failure here means the process has already exited.
			_ = cmd.Process.Kill()
			return
		}
	}()

	select {
	case <-drained:
	case <-killSignal:
	}
	waitErr := cmd.Wait()

	mu.Lock()
	exited = true
	wasKilled := killed
	mu.Unlock()
	close(stop)

	if wasKilled {
		// The watcher acquired Processing before killing; release it now
		// that the process is gone.
		s.Processing.Unlock()
		return &LambdaRestartError{}
	}
	return waitErr
}
// Start brings up the compatibility layer: the request listener on s.port,
// the emulated Lambda runtime API on s.port+1, and the runtime command, which
// is launched again every time it exits cleanly or asks for a restart.
//
// Start only returns when the command fails with an error other than
// *LambdaRestartError. A listener that cannot be served is fatal and
// terminates the process.
func (s LambdaCompatServer) Start() error {
	log.Info().Msg("Cloud Run Lambda compatibility layer starting...")
	ctx := context.Background()

	go func() {
		log.Info().Int("port", s.port).Msg("Listening to incoming requests")
		server := s.createRunServer(s.port)
		if err := server.ListenAndServe(); err != nil {
			// A zerolog event is only emitted (and Fatal only exits) once
			// Msg is called.
			log.Fatal().Err(err).Msg("Request listener failed")
		}
	}()
	go func() {
		log.Info().Int("port", s.port+1).Msg("Emulating Lambda environment")
		server := s.createCompatServer(s.port + 1)
		if err := server.ListenAndServe(); err != nil {
			log.Fatal().Err(err).Msg("Lambda runtime API listener failed")
		}
	}()

	lambdaEmulationAPI := fmt.Sprintf("127.0.0.1:%d", s.port+1)
	for {
		err := s.startCommand(ctx, lambdaEmulationAPI)
		if err == nil {
			continue
		}
		if _, restart := err.(*LambdaRestartError); restart {
			continue
		}
		return err
	}
}
// startCommand runs the runtime command once and returns when it exits.
//
// When an audience is configured the AWS credentials are refreshed first; a
// refresh failure is logged and the command is started anyway. The command
// receives AWS_LAMBDA_RUNTIME_API (lambdaEmulationAPI), AWS_REGION and
// AWS_DEFAULT_REGION, plus the temporary AWS credentials when available. It
// returns an error when no command is configured, and otherwise whatever
// LambdaCompatCommand.Run returns.
func (s *LambdaCompatServer) startCommand(ctx context.Context, lambdaEmulationAPI string) error {
	if len(s.Command) == 0 {
		return fmt.Errorf("no command to run")
	}
	if s.Audience != "" {
		if err := s.refreshIdToken(ctx, s.Audience); err != nil {
			log.Error().Err(err).Msg("Failed to get OIDC token")
		}
	}

	environment := []string{
		fmt.Sprintf("AWS_LAMBDA_RUNTIME_API=%s", lambdaEmulationAPI),
		fmt.Sprintf("AWS_REGION=%s", s.Region),
		fmt.Sprintf("AWS_DEFAULT_REGION=%s", s.Region),
	}
	if awsToken != nil {
		// aws.ToString tolerates nil fields instead of dereferencing them.
		environment = append(environment,
			fmt.Sprintf("AWS_ACCESS_KEY_ID=%s", aws.ToString(awsToken.AccessKeyId)),
			fmt.Sprintf("AWS_SECRET_ACCESS_KEY=%s", aws.ToString(awsToken.SecretAccessKey)),
			fmt.Sprintf("AWS_SESSION_TOKEN=%s", aws.ToString(awsToken.SessionToken)),
		)
	}

	// Everything after the executable is an argument. The bound comes from
	// s.Command itself, not from the process's own argument count.
	s.commands = []LambdaCompatCommand{{
		Context:     ctx,
		Command:     s.Command[0],
		Args:        s.Command[1:],
		Environment: environment,
	}}
	return s.commands[0].Run(s)
}