tools/service-account-provider/main.go (365 lines of code) (raw):

// Copyright 2021 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package main import ( "bytes" "context" "fmt" "io/ioutil" "log" "net/http" "net/url" "os" "strconv" "strings" "time" "encoding/json" "github.com/s12v/go-jwks" "github.com/square/go-jose" "cloud.google.com/go/storage" "github.com/form3tech-oss/jwt-go" "golang.org/x/oauth2" "google.golang.org/api/impersonate" "google.golang.org/api/iterator" "gopkg.in/yaml.v3" ) type Branch struct { Ref string `yaml:"ref"` ServiceAccounts []string `yaml:"service_accounts"` AllowedScopes []string `yaml:"allowed_scopes"` } type Pipeline struct { Name string `yaml:"name"` DefaultServiceAccount string `yaml:"default_sa"` Branches []Branch `yaml:"branches"` } type Config struct { Issuers []struct { Name string `yaml:"name"` JwksUrl string `yaml:"jwks_url"` } `yaml:"issuers"` Pipelines []Pipeline `yaml:"pipelines"` } var config *Config func main() { var err error var port int64 if os.Getenv("PORT") != "" { port, err = strconv.ParseInt(os.Getenv("PORT"), 10, 64) if err != nil { log.Panicf("Can't parse value of PORT env variable %s %v", os.Getenv("PORT"), err) } } else { port = 8080 } err = getConfig() if err != nil { fmt.Printf("Unable to read config %v", err) log.Panicf("Unable to read config %v", err) } go scheduleConfigRefresh() http.HandleFunc("/access", handler) log.Fatal(http.ListenAndServe(fmt.Sprintf(":%d", port), nil)) } func handler(w http.ResponseWriter, r *http.Request) { sa := r.URL.Query().Get("sa") scopesString := r.URL.Query().Get("scopes") if scopesString == "" { scopesString = "https://www.googleapis.com/auth/cloud-platform" } scopes := strings.Split(scopesString, ",") var err error var lifetime float64 lifetimeString := r.URL.Query().Get("lifetime") if lifetimeString == "" { lifetime = 60 } else { lifetime, err = strconv.ParseFloat(lifetimeString, 32) if err != nil { w.WriteHeader(http.StatusBadRequest) fmt.Fprintf(w, "%v", err) return } } jwt := r.Header.Get("Gitlab-Token") if jwt != "" { claims, err := validateJwtAndExtractClaims(jwt) if err != nil { w.WriteHeader(http.StatusUnauthorized) fmt.Fprintf(w, "The provided jwt is invalid %v", err) return } if sa == "" { pipeline, err := retrievePipeline(claims) if err != nil { w.WriteHeader(http.StatusNotFound) fmt.Fprintf(w, "no default SA for the project path defined %s, provide one via sa query attribute", claims.ProjectPath) return } sa = pipeline.DefaultServiceAccount if sa == "" { w.WriteHeader(http.StatusNotFound) fmt.Fprintf(w, "no default SA for the project path defined %s, provide one via sa query attribute", claims.ProjectPath) return } } if hasAccess(sa, claims) { token, err := getAccesToken(sa, scopes, time.Duration(lifetime)*time.Minute) if err != nil { w.WriteHeader(http.StatusBadRequest) fmt.Fprintf(w, "%v", err) } else { fmt.Fprintf(w, "%s", token.AccessToken) } } else { w.WriteHeader(http.StatusForbidden) fmt.Fprintf(w, "JWT not allowed to access sa") } } else { w.WriteHeader(http.StatusUnauthorized) fmt.Fprintf(w, "No jwt provided") } } func retrievePipeline(claims *GitlabClaims) (*Pipeline, error) { current := -1 var currentPipeline *Pipeline for i := 0; i < len(config.Pipelines); i++ { if strings.HasPrefix(claims.ProjectPath, config.Pipelines[i].Name) { if current < len(config.Pipelines[i].Name) { current = len(config.Pipelines[i].Name) currentPipeline = &config.Pipelines[i] } } } if current == -1 { return nil, fmt.Errorf("no pipeline for project path %s", claims.ProjectPath) } else { return currentPipeline, nil } } func hasAccess(serviceAccount string, claims *GitlabClaims) bool { for i := 0; i < len(config.Pipelines); i++ { if strings.HasPrefix(claims.ProjectPath, config.Pipelines[i].Name) { for j := 0; j < len(config.Pipelines[i].Branches); j++ { if config.Pipelines[i].Branches[j].Ref == "*" || config.Pipelines[i].Branches[j].Ref == claims.Ref { if arrayContains(config.Pipelines[i].Branches[j].ServiceAccounts, serviceAccount) { return true } } } } } return false } func arrayContains(array []string, element string) bool { for _, ele := range array { if ele == element { return true } } return false } type GitlabClaims struct { NamespaceId string `json:"namespace_id"` NamespacePath string `json:"namespace_path"` ProjectId string `json:"project_id"` ProjectPath string `json:"project_path"` UserId string `json:"user_id"` UserLogin string `json:"user_login"` UserEmail string `json:"user_email"` PipelineId string `json:"pipeline_id"` JobId string `json:"job_id"` Ref string `json:"ref"` RefType string `json:"ref_type"` RefProtected string `json:"ref_protected"` jwt.StandardClaims } func validateJwtAndExtractClaims(tokenString string) (*GitlabClaims, error) { log.Printf("token %v", tokenString) var claims jwt.Claims = &GitlabClaims{} token, err := jwt.ParseWithClaims(tokenString, claims, func(token *jwt.Token) (interface{}, error) { claims, ok := token.Claims.(*GitlabClaims) if claims.RefProtected != "true" { return nil, fmt.Errorf("unprotected refs are not allowed to access service accounts ref=%s", claims.Ref) } if ok { for i := 0; len(config.Issuers) > i; i++ { if claims.Issuer == config.Issuers[i].Name { var jwksClient jwks.JWKSClient if strings.HasPrefix(config.Issuers[i].JwksUrl, "http") { jwksSource := jwks.NewWebSource(config.Issuers[i].JwksUrl) jwksClient = jwks.NewDefaultClient( jwksSource, time.Hour, // Refresh keys every 1 hour 12*time.Hour, // Expire keys after 12 hours ) } else { jwksSource := NewFileSource(config.Issuers[i].JwksUrl) jwksClient = jwks.NewDefaultClient( jwksSource, time.Hour, // Refresh keys every 1 hour 12*time.Hour, // Expire keys after 12 hours ) } var jwk *jose.JSONWebKey kid := token.Header["kid"] kidString := fmt.Sprintf("%v", kid) jwk, err := jwksClient.GetEncryptionKey(kidString) if err != nil { log.Fatal(err) } return jwk.Public().Key, nil } } return nil, fmt.Errorf("issuer not configured %s", claims.Issuer) } return nil, fmt.Errorf("token not okay %s", tokenString) }) if claims, ok := token.Claims.(*GitlabClaims); ok && token.Valid { return claims, nil } else { return nil, err } } func getAccesToken(sa string, scopes []string, lifetime time.Duration) (*oauth2.Token, error) { ctx := context.Background() ts, err := impersonate.CredentialsTokenSource(ctx, impersonate.CredentialsConfig{ TargetPrincipal: sa, Scopes: scopes, Lifetime: lifetime, }) if err != nil { return nil, err } var token *oauth2.Token token, err = ts.Token() if err != nil { return nil, err } return token, nil } func getConfig() error { config = &Config{} configFileLocation := os.Getenv("GCS_CONFIG_LINK") // gs://bucketname/objectname if configFileLocation == "" { return fmt.Errorf("Environment Variable GCS_CONFIG_LINK not set") } u, err := url.Parse(configFileLocation) if err != nil { return fmt.Errorf("Environment Variable GCS_CONFIG_LINK not properly formated %v", err) } if u.Scheme == "gs" { ctx := context.Background() client, err := storage.NewClient(ctx) if err != nil { log.Fatal(err) } defer client.Close() query := &storage.Query{Prefix: u.Path[1:]} bucket := client.Bucket(u.Host) it := bucket.Objects(ctx, query) for { attrs, err := it.Next() if err == iterator.Done { break } if err != nil { log.Fatal(err) } log.Printf("Loading config %s", attrs.Name) rc, err := bucket.Object(attrs.Name).NewReader(ctx) if err != nil { return fmt.Errorf("Object(%q).NewReader: %v", attrs.Name, err) } defer rc.Close() dat, err := ioutil.ReadAll(rc) if err != nil { return fmt.Errorf("ioutil.ReadAll: %v", err) } var singleConfig Config err = yaml.Unmarshal(dat, &singleConfig) if err != nil { log.Fatalf("error parsing config file %s %v", attrs.Name, err) } mergeConfig(config, &singleConfig) //names = append(names, attrs.Name) } return nil } else { return fmt.Errorf("Schema of provided location unsupported (currently only gs://)") } } func mergeConfig(target *Config, source *Config) { target.Issuers = append(target.Issuers, source.Issuers...) for _, sourcePipeline := range source.Pipelines { merged := false for _, targetPipeline := range target.Pipelines { if targetPipeline.Name == sourcePipeline.Name { mergePipelines(&targetPipeline, &sourcePipeline) merged = true } } if !merged { target.Pipelines = append(target.Pipelines, sourcePipeline) } } } func mergePipelines(target *Pipeline, source *Pipeline) { target.DefaultServiceAccount = source.DefaultServiceAccount for _, sourceBranch := range source.Branches { merged := false for _, targetBranch := range target.Branches { if targetBranch.Ref == sourceBranch.Ref { mergeBranches(&targetBranch, &sourceBranch) merged = true } } if !merged { target.Branches = append(target.Branches, sourceBranch) } } } func mergeBranches(target *Branch, source *Branch) { target.AllowedScopes = append(target.AllowedScopes, source.AllowedScopes...) target.ServiceAccounts = append(target.ServiceAccounts, source.ServiceAccounts...) } /* Refresh */ func scheduleConfigRefresh() { var err error configRefreshIntervalString := os.Getenv("CONFIG_REFRESH_INTERVAL") if configRefreshIntervalString == "0" || configRefreshIntervalString == "-1" { return } else { var configRefreshInterval int64 = 5 if configRefreshIntervalString != "" { configRefreshInterval, err = strconv.ParseInt(configRefreshIntervalString, 0, 64) if err != nil { log.Panicf("Unable to parse config refresh interval %v", err) } } ticker := time.NewTicker(time.Duration(configRefreshInterval) * time.Minute) for { select { case <-ticker.C: getConfig() } } } } func NewFileSource(filePath string) *FileSource { return &FileSource{ FilePath: filePath, } } type FileSource struct { FilePath string } func (s *FileSource) JSONWebKeySet() (*jose.JSONWebKeySet, error) { dat, err := ioutil.ReadFile(s.FilePath) if err != nil { return nil, err } jsonWebKeySet := new(jose.JSONWebKeySet) if err = json.NewDecoder(bytes.NewReader(dat)).Decode(jsonWebKeySet); err != nil { return nil, err } return jsonWebKeySet, err }