lambda/handler.go (255 lines of code) (raw):
// Copyright 2017 Amazon.com, Inc. or its affiliates. All Rights Reserved.
package lambda
import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"io/ioutil" // nolint:staticcheck
"reflect"
"strings"
"github.com/aws/aws-lambda-go/lambda/handlertrace"
)
// Handler is implemented by values that process a raw invocation: Invoke
// receives the request payload bytes and returns the response bytes, or an
// error if the invocation failed.
type Handler interface {
	Invoke(ctx context.Context, payload []byte) ([]byte, error)
}
// handlerOptions is the configured handler built by newHandler. It collects the
// settings applied by Option values and embeds the handlerFunc produced by
// reflectHandler, whose promoted Invoke method makes *handlerOptions a Handler.
type handlerOptions struct {
	// handlerFunc is the adapted user handler; its Invoke method is promoted.
	handlerFunc
	// baseContext is the base context for invocations (set by WithContext);
	// newHandler layers the contextValues pairs on top of it.
	baseContext context.Context
	// contextValues holds the key/value pairs added by WithContextValue.
	contextValues map[interface{}]interface{}
	// jsonRequestUseNumber enables json.Decoder.UseNumber for incoming events.
	jsonRequestUseNumber bool
	// jsonRequestDisallowUnknownFields enables json.Decoder.DisallowUnknownFields
	// for incoming events.
	jsonRequestDisallowUnknownFields bool
	// jsonResponseEscapeHTML is passed to json.Encoder.SetEscapeHTML for responses.
	jsonResponseEscapeHTML bool
	// jsonResponseIndentPrefix and jsonResponseIndentValue are passed to
	// json.Encoder.SetIndent; when both are empty reflectHandler strips the
	// encoder's trailing newline.
	jsonResponseIndentPrefix string
	jsonResponseIndentValue string
	// enableSIGTERM records that WithEnableSIGTERM was used; newHandler then
	// calls enableSIGTERM with sigtermCallbacks.
	enableSIGTERM bool
	// sigtermCallbacks accumulates the callbacks passed to WithEnableSIGTERM.
	sigtermCallbacks []func()
}
// Option customizes how a handler built by NewHandlerWithOptions decodes
// events, encodes responses, and sets up its context. Options are applied in
// the order they are passed.
type Option func(*handlerOptions)
// WithContext returns an Option that sets the base context for all
// invocations of the handler. Pairs added with WithContextValue are layered on
// top of this base context, regardless of the order the options are passed in.
//
// Usage:
//
//	lambda.StartWithOptions(
//		func(ctx context.Context) (string, error) {
//			return ctx.Value("foo").(string), nil
//		},
//		lambda.WithContext(context.WithValue(context.Background(), "foo", "bar")),
//	)
func WithContext(ctx context.Context) Option {
	return func(h *handlerOptions) {
		h.baseContext = ctx
	}
}
// WithContextValue returns an Option that adds a key/value pair to the handler
// context. If a base context was set using WithContext, that base is used as
// the parent. Passing the same key twice keeps the last value.
//
// Usage:
//
//	lambda.StartWithOptions(
//		func(ctx context.Context) (string, error) {
//			return ctx.Value("foo").(string), nil
//		},
//		lambda.WithContextValue("foo", "bar"),
//	)
func WithContextValue(key interface{}, value interface{}) Option {
	return func(h *handlerOptions) {
		// newHandler initializes this map before applying options, but guard
		// anyway so the option cannot panic on a zero-value handlerOptions.
		if h.contextValues == nil {
			h.contextValues = map[interface{}]interface{}{}
		}
		h.contextValues[key] = value
	}
}
// WithSetEscapeHTML sets the SetEscapeHTML argument on the underlying JSON
// encoder used for handler responses. By default responses are encoded without
// HTML escaping.
//
// Usage:
//
//	lambda.StartWithOptions(
//		func() (string, error) {
//			return "<html><body>hello!</body></html>", nil
//		},
//		lambda.WithSetEscapeHTML(true),
//	)
func WithSetEscapeHTML(escapeHTML bool) Option {
	return func(h *handlerOptions) {
		h.jsonResponseEscapeHTML = escapeHTML
	}
}
// WithSetIndent sets the SetIndent arguments on the underlying JSON encoder
// used for handler responses. When prefix or indent is non-empty, the
// encoder's trailing newline is kept in the response; otherwise it is stripped.
//
// Usage:
//
//	lambda.StartWithOptions(
//		func(event any) (any, error) {
//			return event, nil
//		},
//		lambda.WithSetIndent(">", " "),
//	)
func WithSetIndent(prefix, indent string) Option {
	return func(h *handlerOptions) {
		h.jsonResponseIndentPrefix = prefix
		h.jsonResponseIndentValue = indent
	}
}
// WithUseNumber sets the UseNumber option on the underlying JSON decoder used
// for incoming events, so numbers decoded into interface{} values become
// json.Number instead of float64.
//
// Usage:
//
//	lambda.StartWithOptions(
//		func(event any) (any, error) {
//			return event, nil
//		},
//		lambda.WithUseNumber(true),
//	)
func WithUseNumber(useNumber bool) Option {
	return func(h *handlerOptions) {
		h.jsonRequestUseNumber = useNumber
	}
}
// WithDisallowUnknownFields sets the DisallowUnknownFields option on the
// underlying JSON decoder used for incoming events. When enabled, decoding
// into a struct fails if the event contains object keys that do not match any
// non-ignored, exported field of that struct.
//
// Usage:
//
//	lambda.StartWithOptions(
//		func(event any) (any, error) {
//			return event, nil
//		},
//		lambda.WithDisallowUnknownFields(true),
//	)
func WithDisallowUnknownFields(disallowUnknownFields bool) Option {
	return func(h *handlerOptions) {
		h.jsonRequestDisallowUnknownFields = disallowUnknownFields
	}
}
// WithEnableSIGTERM enables SIGTERM behavior within the Lambda platform on container spindown.
// SIGKILL will occur ~500ms after SIGTERM.
// Optionally, one or more callback functions to run on SIGTERM may be provided.
// Callbacks accumulate if the option is passed more than once.
//
// Usage:
//
//	lambda.StartWithOptions(
//		func(event any) (any, error) {
//			return event, nil
//		},
//		lambda.WithEnableSIGTERM(func() {
//			log.Print("function container shutting down...")
//		}),
//	)
func WithEnableSIGTERM(callbacks ...func()) Option {
	return func(h *handlerOptions) {
		h.sigtermCallbacks = append(h.sigtermCallbacks, callbacks...)
		h.enableSIGTERM = true
	}
}
// handlerTakesContext reports whether the handler's first parameter is a
// context.Context. It returns an error when the parameter list cannot form a
// valid handler signature: more than two parameters, a two-parameter handler
// whose first parameter is not context.Context, or a single non-empty
// interface parameter that is not context.Context.
func handlerTakesContext(handler reflect.Type) (bool, error) {
	ctxType := reflect.TypeOf((*context.Context)(nil)).Elem()
	// isContext reports whether t is an interface type mutually assignable
	// with context.Context.
	isContext := func(t reflect.Type) bool {
		return t.Kind() == reflect.Interface && ctxType.Implements(t) && t.Implements(ctxType)
	}

	numIn := handler.NumIn()
	if numIn == 0 {
		return false, nil
	}
	if numIn > 2 {
		return false, fmt.Errorf("handlers may not take more than two arguments, but handler takes %d", numIn)
	}

	first := handler.In(0)
	if numIn == 2 {
		if !isContext(first) {
			return false, fmt.Errorf("handler takes two arguments, but the first is not Context. got %s", first.Kind())
		}
		return true, nil
	}

	// Exactly one parameter: a concrete type or an empty interface (such as
	// func(event any)) is the event itself.
	if first.Kind() != reflect.Interface || first.NumMethod() == 0 {
		return false, nil
	}
	if !isContext(first) {
		return false, fmt.Errorf("handler takes an interface, but it is not context.Context: %q", first.Name())
	}
	return true, nil
}
// validateReturns checks that the handler's results take one of the supported
// shapes: no results, a single error, or a value followed by an error.
func validateReturns(handler reflect.Type) error {
	errType := reflect.TypeOf((*error)(nil)).Elem()
	numOut := handler.NumOut()

	if numOut > 2 {
		return fmt.Errorf("handler may not return more than two values")
	}
	if numOut == 2 && !handler.Out(1).Implements(errType) {
		return fmt.Errorf("handler returns two values, but the second does not implement error")
	}
	if numOut == 1 && !handler.Out(0).Implements(errType) {
		return fmt.Errorf("handler returns a single value, but it does not implement error")
	}
	return nil
}
// NewHandler builds a Handler around handlerFunc using the default options.
// The returned Handler JSON-decodes the request, delegates to handlerFunc, and
// JSON-encodes the response. handlerFunc must satisfy the signature rules
// documented by Start. When it does not, no panic occurs here: the returned
// Handler reports the validation error when invoked.
func NewHandler(handlerFunc interface{}) Handler {
	var defaults []Option
	return NewHandlerWithOptions(handlerFunc, defaults...)
}
// NewHandlerWithOptions builds a Handler around handlerFunc, configured by the
// given options. The returned Handler JSON-decodes the request, delegates to
// handlerFunc, and JSON-encodes the response. handlerFunc must satisfy the
// signature rules documented by Start. When it does not, the returned Handler
// reports the validation error when invoked.
func NewHandlerWithOptions(handlerFunc interface{}, options ...Option) Handler {
	configured := newHandler(handlerFunc, options...)
	return configured
}
// newHandler assembles the *handlerOptions behind every Handler produced by
// this file.
//
// Steps:
//   - apply the options in order;
//   - layer the collected context values onto the base context;
//   - enable SIGTERM handling when requested;
//   - adapt handlerFunc with reflectHandler.
//
// A value that is already a *handlerOptions is returned unchanged, and any
// options passed alongside it are ignored.
func newHandler(handlerFunc interface{}, options ...Option) *handlerOptions {
	if existing, ok := handlerFunc.(*handlerOptions); ok {
		return existing
	}

	opts := &handlerOptions{
		baseContext:   context.Background(),
		contextValues: make(map[interface{}]interface{}),
	}
	for _, apply := range options {
		apply(opts)
	}

	// WithContextValue pairs are applied after every option has run, so they
	// always sit on top of whatever base context WithContext selected.
	ctx := opts.baseContext
	for key, value := range opts.contextValues {
		ctx = context.WithValue(ctx, key, value)
	}
	opts.baseContext = ctx

	if opts.enableSIGTERM {
		enableSIGTERM(opts.sigtermCallbacks)
	}
	opts.handlerFunc = reflectHandler(handlerFunc, opts)
	return opts
}
// handlerFunc is the internal calling convention: it takes the raw payload and
// produces the response as an io.Reader.
type handlerFunc func(context.Context, []byte) (io.Reader, error)

// Invoke adapts a handlerFunc to the Handler interface by draining the
// returned reader into a byte slice (back-compat for the rpc mode).
func (h handlerFunc) Invoke(ctx context.Context, payload []byte) ([]byte, error) {
	reader, err := h(ctx, payload)
	if err != nil {
		return nil, err
	}

	// Readers that also hold resources (ex: net.Conn, os.File) are closed
	// before the next invoke so they do not leak.
	if closer, isCloser := reader.(io.Closer); isCloser {
		defer closer.Close()
	}

	// Buffer-backed responses expose their bytes directly, avoiding a copy.
	if buf, isJSONBuf := reader.(*jsonOutBuffer); isJSONBuf {
		return buf.Bytes(), nil
	}
	if buf, isBuf := reader.(*bytes.Buffer); isBuf {
		return buf.Bytes(), nil
	}

	data, err := ioutil.ReadAll(reader)
	if err != nil {
		return nil, err
	}
	return data, nil
}
// errorHandler returns a handlerFunc that ignores its inputs and always fails
// with err. It is used to surface handler validation problems at invoke time.
func errorHandler(err error) handlerFunc {
	fail := func(context.Context, []byte) (io.Reader, error) {
		return nil, err
	}
	return fail
}
// jsonOutBuffer is the response buffer that reflectHandler's JSON encoder
// writes into. It embeds *bytes.Buffer and additionally reports a content type.
type jsonOutBuffer struct {
	*bytes.Buffer
}

// ContentType returns contentTypeJSON, the content type advertised for
// JSON-encoded responses.
func (*jsonOutBuffer) ContentType() string {
	return contentTypeJSON
}
// reflectHandler adapts an arbitrary handler value to the internal handlerFunc
// calling convention.
//
// Values that already implement Handler are wrapped so that their []byte
// result is exposed as an io.Reader. Otherwise f must be a non-nil function
// whose signature passes handlerTakesContext and validateReturns.
//
// On each invocation, the returned handlerFunc:
//   - JSON-decodes the payload into the event argument, when f takes one;
//   - calls f via reflection;
//   - JSON-encodes the first return value, using the decoder/encoder settings
//     stored in h.
//
// Validation problems do not panic. Instead, the returned handlerFunc (built
// with errorHandler) reports the problem on every invocation.
//
// The output buffer is allocated once and reset at the start of every call,
// so the returned function must not be invoked concurrently. A response
// buffer it returns is overwritten by the next invocation.
func reflectHandler(f interface{}, h *handlerOptions) handlerFunc {
	if f == nil {
		return errorHandler(errors.New("handler is nil"))
	}

	// back-compat: types with receiver `Invoke(context.Context, []byte) ([]byte, error)` need the return bytes wrapped
	if handler, ok := f.(Handler); ok {
		return func(ctx context.Context, payload []byte) (io.Reader, error) {
			b, err := handler.Invoke(ctx, payload)
			if err != nil {
				return nil, err
			}
			return bytes.NewBuffer(b), nil
		}
	}

	handler := reflect.ValueOf(f)
	handlerType := reflect.TypeOf(f)
	if handlerType.Kind() != reflect.Func {
		return errorHandler(fmt.Errorf("handler kind %s is not %s", handlerType.Kind(), reflect.Func))
	}
	// A typed nil function value (e.g. `var fn func() error`) slips past the
	// f == nil check above, and handler.Call would panic on it.
	if handler.IsNil() {
		return errorHandler(errors.New("handler is nil"))
	}

	takesContext, err := handlerTakesContext(handlerType)
	if err != nil {
		return errorHandler(err)
	}
	if err := validateReturns(handlerType); err != nil {
		return errorHandler(err)
	}

	out := &jsonOutBuffer{bytes.NewBuffer(nil)}
	return func(ctx context.Context, payload []byte) (io.Reader, error) {
		out.Reset()
		in := bytes.NewBuffer(payload)
		decoder := json.NewDecoder(in)
		if h.jsonRequestUseNumber {
			decoder.UseNumber()
		}
		if h.jsonRequestDisallowUnknownFields {
			decoder.DisallowUnknownFields()
		}
		encoder := json.NewEncoder(out)
		encoder.SetEscapeHTML(h.jsonResponseEscapeHTML)
		encoder.SetIndent(h.jsonResponseIndentPrefix, h.jsonResponseIndentValue)

		trace := handlertrace.FromContext(ctx)

		// construct arguments
		var args []reflect.Value
		if takesContext {
			args = append(args, reflect.ValueOf(ctx))
		}
		// The event, when present, is always the last parameter: the only one
		// of a context-free single-argument handler, or the second of two.
		if (handlerType.NumIn() == 1 && !takesContext) || handlerType.NumIn() == 2 {
			eventType := handlerType.In(handlerType.NumIn() - 1)
			event := reflect.New(eventType)
			if err := decoder.Decode(event.Interface()); err != nil {
				return nil, err
			}
			if nil != trace.RequestEvent {
				trace.RequestEvent(ctx, event.Elem().Interface())
			}
			args = append(args, event.Elem())
		}

		response := handler.Call(args)

		// return the error, if any
		if len(response) > 0 {
			errValue := response[len(response)-1]
			// A handler declared as, e.g., func() (T, *MyError) reports success
			// with a nil *MyError. Converting that through Interface() would
			// produce a non-nil error interface holding a nil pointer, so
			// nil-able result kinds are checked on the reflect.Value first.
			// For results declared as the error interface itself this is the
			// same as the plain nil check.
			isNil := false
			switch errValue.Kind() {
			case reflect.Chan, reflect.Func, reflect.Interface, reflect.Map, reflect.Ptr, reflect.Slice:
				isNil = errValue.IsNil()
			}
			if !isNil {
				if errVal, ok := errValue.Interface().(error); ok && errVal != nil {
					return nil, errVal
				}
			}
		}

		// set the response value, if any
		var val interface{}
		if len(response) > 1 {
			val = response[0].Interface()
			if nil != trace.ResponseEvent {
				trace.ResponseEvent(ctx, val)
			}
		}

		// encode to JSON
		if err := encoder.Encode(val); err != nil {
			// if response is not JSON serializable, but the response type is a reader, return it as-is
			if reader, ok := val.(io.Reader); ok {
				return reader, nil
			}
			return nil, err
		}

		// if response value is an io.Reader, return it as-is
		if reader, ok := val.(io.Reader); ok {
			// back-compat, don't return the reader if the value serialized to a non-empty json
			if strings.HasPrefix(out.String(), "{}") {
				return reader, nil
			}
		}

		// back-compat, strip the encoder's trailing newline unless WithSetIndent was used
		if h.jsonResponseIndentValue == "" && h.jsonResponseIndentPrefix == "" {
			out.Truncate(out.Len() - 1)
		}
		return out, nil
	}
}