openwhisk/forward_proxy.go (169 lines of code) (raw):

/* * Licensed to the Apache Software Foundation (ASF) under one or more * contributor license agreements. See the NOTICE file distributed with * this work for additional information regarding copyright ownership. * The ASF licenses this file to You under the Apache License, Version 2.0 * (the "License"); you may not use this file except in compliance with * the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package openwhisk import ( "bytes" "crypto/md5" "encoding/hex" "encoding/json" "fmt" "io" "net/http" "net/http/httputil" "net/url" "strconv" "strings" "github.com/google/uuid" ) const OW_CODE_HASH = "__OW_CODE_HASH" func (ap *ActionProxy) ForwardRunRequest(w http.ResponseWriter, r *http.Request) { if ap.clientProxyData == nil { sendError(w, http.StatusInternalServerError, "Send init first") return } var runRequest runRequest err := json.NewDecoder(r.Body).Decode(&runRequest) if err != nil { sendError(w, http.StatusBadRequest, fmt.Sprintf("Error decoding run body while forwarding request: %v", err)) return } newBody := runRequest newBody.ActionCodeHash = ap.clientProxyData.ActionCodeHash var buf bytes.Buffer err = json.NewEncoder(&buf).Encode(newBody) if err != nil { sendError(w, http.StatusBadRequest, fmt.Sprintf("Error encoding updated init body: %v", err)) return } bodyLen := buf.Len() r.Body = io.NopCloser(bytes.NewBuffer(buf.Bytes())) director := func(req *http.Request) { req.Header = r.Header.Clone() // Reset content length with the new body req.Header.Set("Content-Length", strconv.Itoa(bodyLen)) req.ContentLength = int64(bodyLen) req.URL.Scheme = ap.clientProxyData.ProxyURL.Scheme req.URL.Host = ap.clientProxyData.ProxyURL.Host req.Host = ap.clientProxyData.ProxyURL.Host } proxy := &httputil.ReverseProxy{Director: director} proxy.ErrorHandler = func(w http.ResponseWriter, r *http.Request, err error) { Debug("Error forwarding run request: %v", err) sendError(w, http.StatusBadGateway, "Error forwarding run request. Check logs for details.") } proxy.ModifyResponse = func(response *http.Response) error { if response.StatusCode == http.StatusOK { // Decode the response var remoteReponse RemoteRunResponse err := json.NewDecoder(response.Body).Decode(&remoteReponse) if err != nil { Debug("Error decoding remote response: %v", err) return err } // Write the logs to the client logs. if _, err := ap.outFile.WriteString(remoteReponse.Out); err != nil { Debug("Error writing remote response out to client: %v", err) } // Avoid spamming just the output guard if there is no error string if remoteReponse.Err != OutputGuard { if _, err := ap.errFile.WriteString(remoteReponse.Err); err != nil { Debug("Error writing remote response err to client: %v", err) } } // Keep the response body only response.Body = io.NopCloser(bytes.NewReader(remoteReponse.Response)) // recalculate the content length response.ContentLength = int64(len(remoteReponse.Response)) response.Header.Set("Content-Length", strconv.Itoa(len(remoteReponse.Response))) } else { Debug("Remote response status code: %d", response.StatusCode) } return nil } Debug("Forwarding run request with to %s", ap.clientProxyData.ProxyURL.String()) proxy.ServeHTTP(w, r) if f, ok := w.(http.Flusher); ok { f.Flush() } } func (ap *ActionProxy) ForwardInitRequest(w http.ResponseWriter, r *http.Request) { var initRequest initRequest err := json.NewDecoder(r.Body).Decode(&initRequest) if err != nil { sendError(w, http.StatusBadRequest, fmt.Sprintf("Error decoding init body while forwarding request: %v", err)) return } Debug("Decoded init request: len: %d - main: %s", r.ContentLength, initRequest.Value.Main) proxyData, err := parseMainFlag(initRequest.Value.Main) if err != nil { sendError(w, http.StatusBadRequest, err.Error()) return } // set the proxy data ap.clientProxyData = proxyData ap.clientProxyData.ProxyActionID = uuid.New().String() newBody := initRequest newBody.Value.Main = ap.clientProxyData.MainFunc newBody.ProxiedActionID = ap.clientProxyData.ProxyActionID codeHash := calculateCodeHash(initRequest.Value.Code) if newBody.Value.Env == nil { newBody.Value.Env = make(map[string]interface{}) } newBody.Value.Env[OW_CODE_HASH] = codeHash ap.clientProxyData.ActionCodeHash = codeHash Debug("Set code hash: %s", codeHash) var buf bytes.Buffer err = json.NewEncoder(&buf).Encode(newBody) if err != nil { sendError(w, http.StatusBadRequest, fmt.Sprintf("Error encoding updated init body: %v", err)) return } bodyLen := buf.Len() r.Body = io.NopCloser(bytes.NewBuffer(buf.Bytes())) director := func(req *http.Request) { req.Header = r.Header.Clone() // Reset content length with the new body req.Header.Set("Content-Length", strconv.Itoa(bodyLen)) req.ContentLength = int64(bodyLen) req.URL.Scheme = ap.clientProxyData.ProxyURL.Scheme req.URL.Host = ap.clientProxyData.ProxyURL.Host req.Host = ap.clientProxyData.ProxyURL.Host } proxy := &httputil.ReverseProxy{Director: director} proxy.ErrorHandler = func(w http.ResponseWriter, r *http.Request, err error) { Debug("Error forwarding init request: %v", err) sendError(w, http.StatusBadGateway, "Error forwarding init request. Check logs for details.") } Debug("Forwarding init request to %s", ap.clientProxyData.ProxyURL.String()) proxy.ServeHTTP(w, r) if f, ok := w.(http.Flusher); ok { f.Flush() } } func parseMainFlag(mainAtProxy string) (*ClientProxyData, error) { proxyData := ClientProxyData{} splitedMainAtProxy := strings.Split(mainAtProxy, "@") var extractedURL string if len(splitedMainAtProxy) == 2 { proxyData.MainFunc = splitedMainAtProxy[0] extractedURL = splitedMainAtProxy[1] } else if len(splitedMainAtProxy) == 1 { extractedURL = splitedMainAtProxy[0] } else { return nil, fmt.Errorf("invalid value for --main flag. Must be in the form of <main>@<proxy> or @<proxy>") } parsedUrl, err := parseMainURL(extractedURL) if err != nil { return nil, err } proxyData.ProxyURL = *parsedUrl Debug("Parsed main flag. Main: %s, Proxy: %s", proxyData.MainFunc, proxyData.ProxyURL.String()) return &proxyData, nil } func parseMainURL(input string) (*url.URL, error) { if input == "" { return nil, fmt.Errorf("empty URL") } // Check if the input has a scheme, otherwise "https" if !strings.Contains(input, "://") { input = "https://" + input } // Parse the input URL parsedURL, err := url.Parse(input) if err != nil { return nil, err } return parsedURL, nil } func calculateCodeHash(code string) string { hash := md5.Sum([]byte(code)) return hex.EncodeToString(hash[:]) }