aws_signing_helper/credentials.go (144 lines of code) (raw):

package aws_signing_helper import ( "context" "crypto/tls" "encoding/base64" "errors" "fmt" "log" "net/http" "runtime" "github.com/aws/aws-sdk-go-v2/aws" "github.com/aws/aws-sdk-go-v2/aws/arn" "github.com/aws/aws-sdk-go-v2/config" "github.com/aws/rolesanywhere-credential-helper/rolesanywhere" "github.com/aws/smithy-go/middleware" smithyhttp "github.com/aws/smithy-go/transport/http" ) type CredentialsOpts struct { PrivateKeyId string CertificateId string CertificateBundleId string CertIdentifier CertIdentifier UseLatestExpiringCertificate bool RoleArn string ProfileArnStr string TrustAnchorArnStr string SessionDuration int Region string Endpoint string NoVerifySSL bool WithProxy bool Debug bool Version string LibPkcs11 string ReusePin bool TpmKeyPassword string NoTpmKeyPassword bool ServerTTL int RoleSessionName string } // Middleware to set a custom user agent header func createCredHelperUserAgentMiddleware(userAgent string) middleware.BuildMiddleware { return middleware.BuildMiddlewareFunc("UserAgent", func( ctx context.Context, input middleware.BuildInput, next middleware.BuildHandler, ) (middleware.BuildOutput, middleware.Metadata, error) { if req, ok := input.Request.(*smithyhttp.Request); ok { req.Header.Set("User-Agent", userAgent) } return next.HandleBuild(ctx, input) }) } // Function to create session and generate credentials func GenerateCredentials(opts *CredentialsOpts, signer Signer, signatureAlgorithm string) (CredentialProcessOutput, error) { // Assign values to region and endpoint if they haven't already been assigned trustAnchorArn, err := arn.Parse(opts.TrustAnchorArnStr) if err != nil { return CredentialProcessOutput{}, err } profileArn, err := arn.Parse(opts.ProfileArnStr) if err != nil { return CredentialProcessOutput{}, err } if trustAnchorArn.Region != profileArn.Region { return CredentialProcessOutput{}, errors.New("trust anchor and profile regions don't match") } if opts.Region == "" { opts.Region = trustAnchorArn.Region } var logMode aws.ClientLogMode = 0 if Debug { logMode = aws.LogSigning | aws.LogRetries | aws.LogRequestWithBody | aws.LogResponseWithBody | aws.LogRequestEventMessage | aws.LogResponseEventMessage } // Custom HTTP client with proxy and TLS settings var tr *http.Transport if opts.WithProxy { tr = &http.Transport{ TLSClientConfig: &tls.Config{MinVersion: tls.VersionTLS12, InsecureSkipVerify: opts.NoVerifySSL}, Proxy: http.ProxyFromEnvironment, } } else { tr = &http.Transport{ TLSClientConfig: &tls.Config{MinVersion: tls.VersionTLS12, InsecureSkipVerify: opts.NoVerifySSL}, } } httpClient := &http.Client{Transport: tr} ctx := context.TODO() cfg, err := config.LoadDefaultConfig(ctx, config.WithRegion(opts.Region), config.WithHTTPClient(httpClient), config.WithClientLogMode(logMode)) if err != nil { return CredentialProcessOutput{}, err } // Override endpoint if specified if opts.Endpoint != "" { cfg.BaseEndpoint = aws.String(opts.Endpoint) } // Set a custom user agent userAgentStr := fmt.Sprintf("CredHelper/%s (%s; %s; %s)", opts.Version, runtime.Version(), runtime.GOOS, runtime.GOARCH) cfg.APIOptions = append(cfg.APIOptions, func(stack *middleware.Stack) error { stack.Build.Remove("UserAgent") return stack.Build.Add(createCredHelperUserAgentMiddleware(userAgentStr), middleware.After) }) // Add custom request signer, implementing SigV4-X509 certificate, err := signer.Certificate() if err != nil { return CredentialProcessOutput{}, errors.New("unable to find certificate") } certificateChain, err := signer.CertificateChain() if err != nil { // If the chain couldn't be found, don't include it in the request if Debug { log.Println(err) } } cfg.APIOptions = append(cfg.APIOptions, func(stack *middleware.Stack) error { // Remove middleware related to SigV4 signing stack.Finalize.Remove("Signing") stack.Finalize.Remove("setLegacyContextSigningOptions") stack.Finalize.Remove("GetIdentity") // Add middleware for SigV4-X509 signing stack.Finalize.Add(middleware.FinalizeMiddlewareFunc("Signing", CreateRequestSignFinalizeFunction(signer, opts.Region, signatureAlgorithm, certificate, certificateChain)), middleware.After) return nil }) // Create the Roles Anywhere client using the above-constructed Config rolesAnywhereClient := rolesanywhere.NewFromConfig(cfg) certificateStr := base64.StdEncoding.EncodeToString(certificate.Raw) durationSeconds := int32(opts.SessionDuration) createSessionRequest := rolesanywhere.CreateSessionInput{ Cert: &certificateStr, ProfileArn: &opts.ProfileArnStr, TrustAnchorArn: &opts.TrustAnchorArnStr, DurationSeconds: &(durationSeconds), InstanceProperties: nil, RoleArn: &opts.RoleArn, SessionName: nil, } if opts.RoleSessionName != "" { createSessionRequest.RoleSessionName = &opts.RoleSessionName } output, err := rolesAnywhereClient.CreateSession(ctx, &createSessionRequest) if err != nil { return CredentialProcessOutput{}, err } if len(output.CredentialSet) == 0 { msg := "unable to obtain temporary security credentials from CreateSession" return CredentialProcessOutput{}, errors.New(msg) } credentials := output.CredentialSet[0].Credentials credentialProcessOutput := CredentialProcessOutput{ Version: 1, AccessKeyId: *credentials.AccessKeyId, SecretAccessKey: *credentials.SecretAccessKey, SessionToken: *credentials.SessionToken, Expiration: *credentials.Expiration, } return credentialProcessOutput, nil }