main.go (283 lines of code) (raw):
// Copyright 2021 the Cloud Run Proxy Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Package main is the entrypoint for cloud-run-proxy. It starts the proxy
// server.
package main
import (
"context"
"crypto/tls"
"crypto/x509"
"errors"
"flag"
"fmt"
"net"
"net/http"
"net/http/httputil"
"net/url"
"os"
"os/signal"
"strings"
"syscall"
"time"
"github.com/GoogleCloudPlatform/cloud-run-proxy/internal/version"
"golang.org/x/net/http2"
"golang.org/x/net/http2/h2c"
"golang.org/x/oauth2"
"golang.org/x/oauth2/google"
"google.golang.org/api/idtoken"
)
// contextKey is the type used for keys this package stores on a request
// context, keeping them distinct from plain string keys.
type contextKey string

const (
	// contextKeyError is the context key under which the proxy director records
	// an error for the response modifier to report.
	contextKeyError contextKey = "error"

	// ADCHintMessage is appended to token errors to point gcloud users at the
	// command that refreshes Application Default Credentials.
	ADCHintMessage = "If you are trying to authenticate using gcloud, try running `gcloud auth login --update-adc` first then restart the proxy."
)

// userAgent is the User-Agent value identifying this proxy, built from the
// name, version and OS/arch exposed by the version package.
var userAgent = version.Name + "/" + version.Version + " (" + version.OSArch + ")"
// Command-line flags. Each variable points at the parsed value once
// flag.Parse has run.
var (
	// flagHost is the remote Cloud Run host (or URL) to proxy to. Required.
	flagHost = flag.String("host", "", "Cloud Run host for which to proxy")

	// flagBind is the local listen address in host:port form.
	flagBind = flag.String("bind", "127.0.0.1:8080", "local host:port on which to listen")

	// flagAudience overrides the token audience; empty means "use the host".
	flagAudience = flag.String("audience", "", "override JWT audience value (aud)")

	// flagToken supplies a static token instead of discovering credentials.
	flagToken = flag.String("token", "", "override OIDC token")

	// flagPrependUserAgent controls whether the proxy's own User-Agent is
	// prepended to the one sent by the client.
	flagPrependUserAgent = flag.Bool("prepend-user-agent", true, "prepend a custom User-Agent header to requests")

	// flagServerUpTime bounds how long the proxy runs; it is parsed with
	// time.ParseDuration by realMain.
	flagServerUpTime = flag.String("server-up-time", "", "duration the proxy server will run. For example, 1h, 1m30s - empty means forever")

	// flagHttp2 enables HTTP/2 handling on both sides of the proxy.
	flagHttp2 = flag.Bool("http2", false, "handle http2 requests (allows grpc calls)")

	// flagAuthorizationHeader names the header that carries the bearer token.
	flagAuthorizationHeader = flag.String("authorization-header", "X-Serverless-Authorization", "header to provide the bearer token")
)
// main wires SIGINT/SIGTERM into a cancellable context, runs realMain, and
// exits with status 1 after printing the error if realMain fails.
func main() {
	ctx, stop := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM)
	defer stop()

	err := realMain(ctx)
	if err == nil {
		return
	}

	// os.Exit skips deferred calls, so release the signal handler explicitly.
	stop()
	fmt.Fprintln(os.Stderr, err.Error())
	os.Exit(1)
}
// realMain parses and validates the flags, builds the proxy and its server,
// and serves until the server fails, the optional -server-up-time elapses, or
// ctx is cancelled. On the latter two it attempts a graceful shutdown bounded
// to 15 seconds. It returns nil on a clean stop or a version request.
func realMain(ctx context.Context) error {
	// Answer version requests before touching the flag package.
	for _, arg := range os.Args {
		switch arg {
		case "-v", "-version", "--version":
			fmt.Fprintln(os.Stdout, version.HumanVersion)
			return nil
		}
	}

	flag.Parse()

	// Collect every validation problem so the user sees them all at once.
	var merr error
	if *flagHost == "" {
		merr = errors.Join(merr, fmt.Errorf("missing -host"))
	}
	if *flagBind == "" {
		merr = errors.Join(merr, fmt.Errorf("missing -bind"))
	}
	if *flagAuthorizationHeader == "" {
		merr = errors.Join(merr, fmt.Errorf("missing -authorization-header"))
	}
	var upTime time.Duration
	if *flagServerUpTime != "" {
		parsed, perr := time.ParseDuration(*flagServerUpTime)
		if perr != nil {
			merr = errors.Join(merr, fmt.Errorf("error parsing -server-up-time: %w", perr))
		}
		upTime = parsed
	}
	if merr != nil {
		return merr
	}

	// Remote side: the Cloud Run URL being proxied.
	host, err := smartBuildHost(*flagHost)
	if err != nil {
		return fmt.Errorf("failed to parse host URL: %w", err)
	}

	// The audience defaults to the host URL; an explicit -audience wins (for
	// example when the service is reached through a load balancer).
	audience := host.String()
	if *flagAudience != "" {
		audience = *flagAudience
	}

	tokenSource, err := findTokenSource(ctx, *flagToken, audience)
	if err != nil {
		return fmt.Errorf("failed to find token source: %w", err)
	}

	// Local side: the address the proxy listens on.
	bindHost, bindPort, err := net.SplitHostPort(*flagBind)
	if err != nil {
		return fmt.Errorf("failed to parse bind address: %w", err)
	}
	bind := &url.URL{Scheme: "http", Host: net.JoinHostPort(bindHost, bindPort)}

	proxy := buildProxy(host, bind, tokenSource, *flagHttp2, nil)
	server := createServer(bind, proxy, *flagHttp2)

	// Serve in the background; only unexpected listen/serve errors are reported.
	serveErr := make(chan error, 1)
	go func() {
		fmt.Fprintf(os.Stderr, "%s proxies to %s\n", bind, host)
		lerr := server.ListenAndServe()
		if lerr == nil || errors.Is(lerr, http.ErrServerClosed) {
			return
		}
		select {
		case serveErr <- lerr:
		default:
		}
	}()

	// A nil channel never becomes ready, so the timer case is inert unless
	// -server-up-time was supplied.
	var expired <-chan time.Time
	if *flagServerUpTime != "" {
		expired = time.After(upTime)
	}

	select {
	case err := <-serveErr:
		return fmt.Errorf("server error: %w", err)
	case <-expired:
	case <-ctx.Done():
		fmt.Fprint(os.Stderr, "\nserver is shutting down...\n")
	}

	// Graceful shutdown uses a fresh context because ctx may already be done.
	shutdownCtx, release := context.WithTimeout(context.Background(), 15*time.Second)
	defer release()
	if err := server.Shutdown(shutdownCtx); err != nil {
		return fmt.Errorf("failed to shutdown server: %w", err)
	}
	return nil
}
// buildProxy returns a reverse proxy that forwards requests to host and
// attaches a bearer token obtained from tokenSource to each of them. bind is
// the local address; it is used to rewrite redirects that point at host. When
// enableHttp2 is set the proxy uses an HTTP/2 transport for outgoing
// connections, trusting caCertificate as the root CA if it is non-nil.
func buildProxy(host, bind *url.URL, tokenSource oauth2.TokenSource, enableHttp2 bool, caCertificate *x509.Certificate) *httputil.ReverseProxy {
	proxy := httputil.NewSingleHostReverseProxy(host)

	if enableHttp2 {
		transport := &http2.Transport{}
		if caCertificate != nil {
			roots := x509.NewCertPool()
			roots.AddCert(caCertificate)
			transport.TLSClientConfig = &tls.Config{
				RootCAs: roots,
			}
		}
		proxy.Transport = transport
	}

	// recordError stores err on the request's context so that ModifyResponse
	// below can report it.
	recordError := func(r *http.Request, err error) {
		*r = *r.WithContext(context.WithValue(r.Context(), contextKeyError, err))
	}

	defaultDirector := proxy.Director
	proxy.Director = func(r *http.Request) {
		// The stock director handles the URL rewriting.
		defaultDirector(r)

		// The stock director leaves the host untouched, but Cloud Run requires
		// it to match the target.
		r.Header.Set("Host", host.Host)
		r.Host = host.Host

		token, err := tokenSource.Token()
		if err != nil {
			recordError(r, fmt.Errorf("failed to get token: %w\n\n%s", err, ADCHintMessage))
			return
		}

		// The token sources used by this program carry the id_token in the
		// AccessToken field.
		bearer := token.AccessToken
		if bearer == "" {
			recordError(r, fmt.Errorf("missing id_token"))
			return
		}

		if *flagPrependUserAgent {
			ua := userAgent
			if existing := r.Header.Get("User-Agent"); existing != "" {
				ua += " " + existing
			}
			r.Header.Set("User-Agent", ua)
		}

		r.Header.Set(*flagAuthorizationHeader, "Bearer "+bearer)
	}

	proxy.ModifyResponse = func(resp *http.Response) error {
		// Redirects that target the proxied host are rewritten to the local
		// address so the client keeps talking to the proxy.
		if loc := resp.Header.Get("Location"); loc != "" {
			if target, err := url.Parse(loc); err == nil && target.Host == host.Host {
				target.Scheme = bind.Scheme
				target.Host = bind.Host
				resp.Header.Set("Location", target.String())
			}
		}

		// Report any error the director recorded for this request.
		if err, ok := resp.Request.Context().Value(contextKeyError).(error); ok && err != nil {
			return fmt.Errorf("[PROXY ERROR] %w", err)
		}
		return nil
	}

	// Every proxy error is returned to the client as a 500 with its message.
	proxy.ErrorHandler = func(w http.ResponseWriter, _ *http.Request, err error) {
		http.Error(w, err.Error(), http.StatusInternalServerError)
	}

	return proxy
}
// Create server and wraps proxy with h2c handler if http2 is enabled
func createServer(bind *url.URL, proxy *httputil.ReverseProxy, enableHttp2 bool) *http.Server {
var handler http.Handler
handler = proxy
if enableHttp2 {
http2server := &http2.Server{}
handler = h2c.NewHandler(proxy, http2server)
}
// Create server.
return &http.Server{
Addr: bind.Host,
Handler: handler,
}
}
// findTokenSource returns the token source the proxy should use.
//
// A non-empty rawToken is wrapped in a static source and audience is ignored.
// Otherwise the idtoken package is tried first. If that fails because the
// Application Default Credentials are not a service account (the gcloud
// "authorized_user" case), the default ADC token source is used instead,
// wrapped so that its id_token is exposed as the access token. Sources
// discovered from the environment are wrapped in oauth2.ReuseTokenSource.
func findTokenSource(ctx context.Context, rawToken, audience string) (oauth2.TokenSource, error) {
	// An explicitly supplied token (usually from the flag) always wins.
	if rawToken != "" {
		return oauth2.StaticTokenSource(&oauth2.Token{AccessToken: rawToken}), nil
	}

	source, err := idtoken.NewTokenSource(ctx, audience)
	if err == nil {
		return oauth2.ReuseTokenSource(nil, source), nil
	}

	// Anything other than the "not a service account" failure is unexpected.
	if !strings.Contains(err.Error(), "credential must be service_account") {
		return nil, fmt.Errorf("failed to get idtoken source: %w", err)
	}

	// ADC exists but came from gcloud user credentials, which the idtoken
	// package does not accept. Fall back to the default ADC lookup, whose tokens
	// carry the id_token in a different place.
	adc, err := google.DefaultTokenSource(ctx)
	if err != nil {
		return nil, fmt.Errorf("failed to get default token source: %w", err)
	}
	return oauth2.ReuseTokenSource(nil, &idTokenFromDefaultTokenSource{TokenSource: adc}), nil
}
// idTokenFromDefaultTokenSource adapts a token source whose tokens carry an
// "id_token" extra field into one that returns that value as the access token.
type idTokenFromDefaultTokenSource struct {
	TokenSource oauth2.TokenSource
}

// Token fetches a token from the wrapped source and returns a new token whose
// AccessToken is the string stored under the "id_token" extra field, keeping
// the original expiry. It returns an error if the wrapped source fails or the
// field is absent or not a string.
func (s *idTokenFromDefaultTokenSource) Token() (*oauth2.Token, error) {
	base, err := s.TokenSource.Token()
	if err != nil {
		return nil, err
	}

	if raw, isString := base.Extra("id_token").(string); isString {
		return &oauth2.Token{
			AccessToken: raw,
			Expiry:      base.Expiry,
		}, nil
	}
	return nil, fmt.Errorf("missing id_token")
}
// smartBuildHost parses host into a URL. The input may be a full URL
// (https://foo.bar), a scheme-relative URL (//foo.bar/path), or a bare host
// with an optional path (foo.bar or foo.bar/path). When no scheme is given,
// https is assumed.
//
// Surrounding whitespace on the host and path is trimmed, and a path of "/" is
// normalized to the empty path. An error is returned if the input cannot be
// parsed or the result has no host.
func smartBuildHost(host string) (*url.URL, error) {
	u, err := url.Parse(host)
	if err != nil {
		return nil, fmt.Errorf("failed to parse url: %w", err)
	}

	if u.Scheme == "" {
		u.Scheme = "https"

		// For a bare "foo.bar/path" url.Parse puts everything into Path, so the
		// first segment is the host. A scheme-relative "//foo.bar/path" already
		// has Host populated and must not be overwritten.
		if u.Host == "" {
			name, rest, _ := strings.Cut(u.Path, "/")
			u.Host = name
			u.Path = ""
			if rest != "" {
				// Keep the path absolute; without the leading slash the URL's
				// RequestURI would be malformed.
				u.Path = "/" + rest
			}
			// RawPath still describes the unsplit input and no longer matches
			// Path, so drop it rather than leave stale data behind.
			u.RawPath = ""
		}
	}

	u.Host = strings.TrimSpace(u.Host)
	if u.Host == "" {
		return nil, fmt.Errorf("invalid url %q (missing host)", host)
	}

	u.Path = strings.TrimSpace(u.Path)
	if u.Path == "/" {
		u.RawPath = ""
		u.Path = ""
	}

	return u, nil
}