openwhisk/forward_proxy.go (169 lines of code) (raw):
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package openwhisk
import (
"bytes"
"crypto/md5"
"encoding/hex"
"encoding/json"
"fmt"
"io"
"net/http"
"net/http/httputil"
"net/url"
"strconv"
"strings"
"github.com/google/uuid"
)
// OW_CODE_HASH is the key under which ForwardInitRequest stores the MD5 hash
// of the action code in the init payload's Env map before forwarding it.
const OW_CODE_HASH = "__OW_CODE_HASH"
// ForwardRunRequest relays a run request to the remote action proxy that was
// configured by the most recent ForwardInitRequest.
//
// The incoming JSON body is decoded, stamped with the action code hash
// recorded at init time, re-encoded and sent to the remote proxy. When the
// remote proxy answers 200, its body is decoded as a RemoteRunResponse: Out is
// written to ap.outFile, Err is written to ap.errFile (unless it is only the
// OutputGuard marker), and only the Response field is returned to the caller.
// Any other status code is passed through untouched.
//
// Error responses sent by this handler:
//   - 500 if no init has been forwarded yet, or the body cannot be re-encoded;
//   - 400 if the run body is not valid JSON;
//   - 502 if the remote proxy is unreachable or its 200 body cannot be decoded.
func (ap *ActionProxy) ForwardRunRequest(w http.ResponseWriter, r *http.Request) {
	// Read the configuration once so the whole request uses a single target.
	proxyData := ap.clientProxyData
	if proxyData == nil {
		sendError(w, http.StatusInternalServerError, "Send init first")
		return
	}

	var runRequest runRequest
	if err := json.NewDecoder(r.Body).Decode(&runRequest); err != nil {
		sendError(w, http.StatusBadRequest, fmt.Sprintf("Error decoding run body while forwarding request: %v", err))
		return
	}

	newBody := runRequest
	newBody.ActionCodeHash = proxyData.ActionCodeHash

	var buf bytes.Buffer
	if err := json.NewEncoder(&buf).Encode(newBody); err != nil {
		// Failing to re-encode a value we just decoded is a server-side
		// problem, not a malformed client request.
		sendError(w, http.StatusInternalServerError, fmt.Sprintf("Error encoding updated run body: %v", err))
		return
	}
	bodyLen := buf.Len()
	r.Body = io.NopCloser(bytes.NewReader(buf.Bytes()))

	director := func(req *http.Request) {
		req.Header = r.Header.Clone()
		// The body was rewritten, so the original Content-Length no longer applies.
		req.Header.Set("Content-Length", strconv.Itoa(bodyLen))
		req.ContentLength = int64(bodyLen)
		req.URL.Scheme = proxyData.ProxyURL.Scheme
		req.URL.Host = proxyData.ProxyURL.Host
		req.Host = proxyData.ProxyURL.Host
	}
	proxy := &httputil.ReverseProxy{Director: director}
	proxy.ErrorHandler = func(w http.ResponseWriter, r *http.Request, err error) {
		Debug("Error forwarding run request: %v", err)
		sendError(w, http.StatusBadGateway, "Error forwarding run request. Check logs for details.")
	}
	proxy.ModifyResponse = func(response *http.Response) error {
		if response.StatusCode != http.StatusOK {
			Debug("Remote response status code: %d", response.StatusCode)
			return nil
		}
		// response.Body is swapped out below, after which ReverseProxy only
		// closes the replacement. Close the upstream body here so the
		// underlying connection is released instead of leaking.
		upstreamBody := response.Body
		defer upstreamBody.Close()

		var remoteResponse RemoteRunResponse
		if err := json.NewDecoder(upstreamBody).Decode(&remoteResponse); err != nil {
			Debug("Error decoding remote response: %v", err)
			return err
		}
		// Write the remote logs to the client logs.
		if _, err := ap.outFile.WriteString(remoteResponse.Out); err != nil {
			Debug("Error writing remote response out to client: %v", err)
		}
		// Avoid spamming just the output guard if there is no error string.
		if remoteResponse.Err != OutputGuard {
			if _, err := ap.errFile.WriteString(remoteResponse.Err); err != nil {
				Debug("Error writing remote response err to client: %v", err)
			}
		}
		// Return only the action's response payload, with a matching length.
		response.Body = io.NopCloser(bytes.NewReader(remoteResponse.Response))
		response.ContentLength = int64(len(remoteResponse.Response))
		response.Header.Set("Content-Length", strconv.Itoa(len(remoteResponse.Response)))
		return nil
	}

	Debug("Forwarding run request to %s", proxyData.ProxyURL.String())
	proxy.ServeHTTP(w, r)
	if f, ok := w.(http.Flusher); ok {
		f.Flush()
	}
}
// ForwardInitRequest relays an init request to the remote action proxy named
// in the init payload's "main" value, and records the resulting configuration
// in ap.clientProxyData for later run requests.
//
// The "main" value is parsed by parseMainFlag ("<main>@<proxy>" or "<proxy>").
// Before forwarding, the body is rewritten: Main is replaced with the remote
// function name, ProxiedActionID is set to a fresh UUID, and the MD5 hash of
// the action code is stored in Env under OW_CODE_HASH.
//
// Error responses sent by this handler:
//   - 400 if the body is not valid JSON or "main" cannot be parsed;
//   - 500 if the rewritten body cannot be encoded;
//   - 502 if the remote proxy cannot be reached.
func (ap *ActionProxy) ForwardInitRequest(w http.ResponseWriter, r *http.Request) {
	var initRequest initRequest
	if err := json.NewDecoder(r.Body).Decode(&initRequest); err != nil {
		sendError(w, http.StatusBadRequest, fmt.Sprintf("Error decoding init body while forwarding request: %v", err))
		return
	}
	Debug("Decoded init request: len: %d - main: %s", r.ContentLength, initRequest.Value.Main)

	proxyData, err := parseMainFlag(initRequest.Value.Main)
	if err != nil {
		sendError(w, http.StatusBadRequest, err.Error())
		return
	}

	// Populate the new configuration completely before publishing it.
	codeHash := calculateCodeHash(initRequest.Value.Code)
	proxyData.ProxyActionID = uuid.New().String()
	proxyData.ActionCodeHash = codeHash
	Debug("Set code hash: %s", codeHash)

	newBody := initRequest
	newBody.Value.Main = proxyData.MainFunc
	newBody.ProxiedActionID = proxyData.ProxyActionID
	if newBody.Value.Env == nil {
		newBody.Value.Env = make(map[string]interface{})
	}
	newBody.Value.Env[OW_CODE_HASH] = codeHash

	var buf bytes.Buffer
	if err := json.NewEncoder(&buf).Encode(newBody); err != nil {
		// Failing to re-encode a value we just decoded is a server-side
		// problem; the previous proxy configuration is left in place.
		sendError(w, http.StatusInternalServerError, fmt.Sprintf("Error encoding updated init body: %v", err))
		return
	}

	// Publish only now, so ap.clientProxyData is never left holding a
	// partially populated value.
	ap.clientProxyData = proxyData

	bodyLen := buf.Len()
	r.Body = io.NopCloser(bytes.NewReader(buf.Bytes()))

	director := func(req *http.Request) {
		req.Header = r.Header.Clone()
		// The body was rewritten, so the original Content-Length no longer applies.
		req.Header.Set("Content-Length", strconv.Itoa(bodyLen))
		req.ContentLength = int64(bodyLen)
		req.URL.Scheme = proxyData.ProxyURL.Scheme
		req.URL.Host = proxyData.ProxyURL.Host
		req.Host = proxyData.ProxyURL.Host
	}
	proxy := &httputil.ReverseProxy{Director: director}
	proxy.ErrorHandler = func(w http.ResponseWriter, r *http.Request, err error) {
		Debug("Error forwarding init request: %v", err)
		sendError(w, http.StatusBadGateway, "Error forwarding init request. Check logs for details.")
	}

	Debug("Forwarding init request to %s", proxyData.ProxyURL.String())
	proxy.ServeHTTP(w, r)
	if f, ok := w.(http.Flusher); ok {
		f.Flush()
	}
}
// parseMainFlag splits the init "main" value into the function name to invoke
// on the remote side and the address of the remote action proxy.
//
// Accepted forms are "<main>@<proxy>" and "<proxy>"; "@<proxy>" leaves the
// function name empty. A value containing more than one "@" is rejected. The
// proxy part is resolved with parseMainURL, whose errors are returned as-is.
func parseMainFlag(mainAtProxy string) (*ClientProxyData, error) {
	var (
		data   ClientProxyData
		rawURL string
	)
	switch parts := strings.Split(mainAtProxy, "@"); len(parts) {
	case 1:
		rawURL = parts[0]
	case 2:
		data.MainFunc, rawURL = parts[0], parts[1]
	default:
		return nil, fmt.Errorf("invalid value for --main flag. Must be in the form of <main>@<proxy> or @<proxy>")
	}

	proxyURL, err := parseMainURL(rawURL)
	if err != nil {
		return nil, err
	}
	data.ProxyURL = *proxyURL
	Debug("Parsed main flag. Main: %s, Proxy: %s", data.MainFunc, data.ProxyURL.String())
	return &data, nil
}
// parseMainURL parses the remote proxy address taken from the init "main"
// value.
//
// If input contains no "://" scheme separator, "https://" is prepended, so
// "host:port" becomes "https://host:port". The parsed URL must name a host:
// the forwarders copy only the scheme and host onto outgoing requests, so a
// host-less URL could never be reached and would otherwise surface later as
// an opaque 502 instead of a clear init error.
func parseMainURL(input string) (*url.URL, error) {
	if input == "" {
		return nil, fmt.Errorf("empty URL")
	}
	rawURL := input
	if !strings.Contains(rawURL, "://") {
		rawURL = "https://" + rawURL
	}
	parsedURL, err := url.Parse(rawURL)
	if err != nil {
		return nil, err
	}
	// Catches e.g. "http://" and inputs that parse as opaque URLs.
	if parsedURL.Host == "" {
		return nil, fmt.Errorf("invalid proxy URL %q: missing host", input)
	}
	return parsedURL, nil
}
// calculateCodeHash returns the MD5 digest of code, encoded as a
// 32-character lowercase hexadecimal string.
func calculateCodeHash(code string) string {
	digest := md5.New()
	// hash.Hash writes never return an error, so the result is safely ignored.
	_, _ = io.WriteString(digest, code)
	return hex.EncodeToString(digest.Sum(nil))
}