config/configure.go (469 lines of code) (raw):

// Copyright (c) 2009-present, Alibaba Cloud All rights reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package config import ( "bufio" "fmt" "github.com/aliyun/aliyun-cli/v3/cloudsso" "github.com/aliyun/aliyun-cli/v3/util" "io" "net/http" "net/url" "os" "strconv" "strings" "time" "github.com/aliyun/aliyun-cli/v3/cli" "github.com/aliyun/aliyun-cli/v3/i18n" ) var hookLoadConfiguration = func(fn func(path string) (*Configuration, error)) func(path string) (*Configuration, error) { return fn } var hookSaveConfiguration = func(fn func(config *Configuration) error) func(config *Configuration) error { return fn } var stdin io.Reader = os.Stdin // 为了方便 mock 的函数变量 var cloudssoGetAccessToken = func(ssoLogin *cloudsso.SsoLogin) (*cloudsso.AccessTokenResponse, error) { return ssoLogin.GetAccessToken() } var cloudssoListAllUsers = func(userParam *cloudsso.ListUserParameter) ([]cloudsso.AccountDetailResponse, error) { return userParam.ListAllUsers() } var cloudssoListAllAccessConfigurations = func(accessParam *cloudsso.AccessConfigurationsParameter, req cloudsso.AccessConfigurationsRequest) ([]cloudsso.AccessConfiguration, error) { return accessParam.ListAllAccessConfigurations(req) } var cloudssoTryRefreshStsToken = func(signInUrl, accessToken, accessConfig, accountId *string, httpClient *http.Client) (*cloudsso.CloudCredentialResponse, error) { return cloudsso.TryRefreshStsToken(signInUrl, accessToken, accessConfig, accountId, httpClient) } var doConfigureProxy = func(ctx *cli.Context, profileName string, mode string) error { return doConfigure(ctx, profileName, mode) } func loadConfiguration() (*Configuration, error) { return hookLoadConfiguration(LoadConfiguration)(GetConfigPath() + "/" + configFile) } func NewConfigureCommand() *cli.Command { c := &cli.Command{ Name: "configure", Short: i18n.T( "configure credential and settings", "配置身份认证和其他信息"), Usage: "configure --mode {AK|RamRoleArn|EcsRamRole|OIDC|External|CredentialsURI|ChainableRamRoleArn|CloudSSO} --profile <profileName>", Run: func(ctx *cli.Context, args []string) error { if len(args) > 0 { return cli.NewInvalidCommandError(args[0], ctx) } profileName, _ := ProfileFlag(ctx.Flags()).GetValue() mode, _ := ModeFlag(ctx.Flags()).GetValue() if mode == "" { // 检查 profileName 是否存在 conf, err := loadConfiguration() if err == nil { if profileName == "" { profileName = conf.CurrentProfile } if profileName != "" { p, ok := conf.GetProfile(profileName) if ok { mode = string(p.Mode) } } } } return doConfigureProxy(ctx, profileName, mode) }, } c.AddSubCommand(NewConfigureGetCommand()) c.AddSubCommand(NewConfigureSetCommand()) c.AddSubCommand(NewConfigureListCommand()) c.AddSubCommand(NewConfigureDeleteCommand()) c.AddSubCommand(NewConfigureSwitchCommand()) return c } func doConfigure(ctx *cli.Context, profileName string, mode string) error { w := ctx.Stdout() conf, err := loadConfiguration() if err != nil { return err } if profileName == "" { if conf.CurrentProfile == "" { profileName = "default" } else { profileName = conf.CurrentProfile originMode := string(conf.GetCurrentProfile(ctx).Mode) if mode == "" { mode = originMode } else if mode != originMode { cli.Printf(w, "Warning: You are changing the authentication type of profile '%s' from '%s' to '%s'\n", profileName, originMode, mode) } } } if mode == "" { mode = "AK" } cp, ok := conf.GetProfile(profileName) if !ok { cp = conf.NewProfile(profileName) } cli.Printf(w, "Configuring profile '%s' in '%s' authenticate mode...\n", profileName, mode) if mode != "" { switch AuthenticateMode(mode) { case AK: cp.Mode = AK configureAK(w, &cp) case StsToken: cp.Mode = StsToken configureStsToken(w, &cp) case RamRoleArn: cp.Mode = RamRoleArn configureRamRoleArn(w, &cp) case EcsRamRole: cp.Mode = EcsRamRole configureEcsRamRole(w, &cp) case RamRoleArnWithEcs: cp.Mode = RamRoleArnWithEcs configureRamRoleArnWithEcs(w, &cp) case ChainableRamRoleArn: cp.Mode = ChainableRamRoleArn configureChainableRamRoleArn(w, &cp) case RsaKeyPair: cp.Mode = RsaKeyPair configureRsaKeyPair(w, &cp) case External: cp.Mode = External configureExternal(w, &cp) case CredentialsURI: cp.Mode = CredentialsURI configureCredentialsURI(w, &cp) case OIDC: cp.Mode = OIDC configureOIDC(w, &cp) case CloudSSO: cp.Mode = CloudSSO // parameter from command has higher priority, use it directly if CloudSSOSignInUrlFlag(ctx.Flags()).IsAssigned() { cp.CloudSSOSignInUrl, _ = CloudSSOSignInUrlFlag(ctx.Flags()).GetValue() } if CloudSSOAccountIdFlag(ctx.Flags()).IsAssigned() { cp.CloudSSOAccountId, _ = CloudSSOAccountIdFlag(ctx.Flags()).GetValue() } if CloudSSOAccessConfigFlag(ctx.Flags()).IsAssigned() { cp.CloudSSOAccessConfig, _ = CloudSSOAccessConfigFlag(ctx.Flags()).GetValue() } err := configureCloudSSO(w, &cp) if err != nil { return err } default: return fmt.Errorf("unexcepted authenticate mode: %s", mode) } } else { configureAK(w, &cp) } // configure common if cp.Mode != CloudSSO || cp.RegionId == "" { cli.Printf(w, "Default Region Id [%s]: ", cp.RegionId) cp.RegionId = ReadInput(cp.RegionId) } if cp.Mode != CloudSSO || cp.OutputFormat == "" { cli.Printf(w, "Default Output Format [%s]: json (Only support json)\n", cp.OutputFormat) // cp.OutputFormat = ReadInput(cp.OutputFormat) cp.OutputFormat = "json" } if cp.Mode != CloudSSO || cp.Language == "" { cli.Printf(w, "Default Language [zh|en] %s: ", cp.Language) cp.Language = ReadInput(cp.Language) if cp.Language != "zh" && cp.Language != "en" { cp.Language = i18n.GetLanguage() } } //fmt.Printf("User site: [china|international|japan] %s", cp.Site) //cp.Site = ReadInput(cp.Site) cli.Printf(w, "Saving profile[%s] ...", profileName) conf.PutProfile(cp) conf.CurrentProfile = cp.Name err = hookSaveConfiguration(SaveConfiguration)(conf) // cp 要在下文的 DoHello 中使用,所以 需要建立 parent 的关系 cp.parent = conf if err != nil { return err } cli.Printf(w, "Done.\n") DoHello(ctx, &cp) return nil } func configureAK(w io.Writer, cp *Profile) error { cli.Printf(w, "Access Key Id [%s]: ", MosaicString(cp.AccessKeyId, 3)) cp.AccessKeyId = ReadInput(cp.AccessKeyId) cli.Printf(w, "Access Key Secret [%s]: ", MosaicString(cp.AccessKeySecret, 3)) cp.AccessKeySecret = ReadInput(cp.AccessKeySecret) return nil } func configureStsToken(w io.Writer, cp *Profile) error { err := configureAK(w, cp) if err != nil { return err } cli.Printf(w, "Sts Token [%s]: ", cp.StsToken) cp.StsToken = ReadInput(cp.StsToken) return nil } func configureRamRoleArn(w io.Writer, cp *Profile) error { err := configureAK(w, cp) if err != nil { return err } cli.Printf(w, "Sts Region [%s]: ", cp.StsRegion) cp.StsRegion = ReadInput(cp.StsRegion) cli.Printf(w, "Ram Role Arn [%s]: ", cp.RamRoleArn) cp.RamRoleArn = ReadInput(cp.RamRoleArn) cli.Printf(w, "Role Session Name [%s]: ", cp.RoleSessionName) cp.RoleSessionName = ReadInput(cp.RoleSessionName) if cp.ExpiredSeconds == 0 { cp.ExpiredSeconds = 900 } cli.Printf(w, "External ID [%s]: ", cp.ExternalId) cp.ExternalId = ReadInput(cp.ExternalId) cli.Printf(w, "Expired Seconds [%v]: ", cp.ExpiredSeconds) cp.ExpiredSeconds, _ = strconv.Atoi(ReadInput(strconv.Itoa(cp.ExpiredSeconds))) return nil } func configureEcsRamRole(w io.Writer, cp *Profile) error { cli.Printf(w, "Ecs Ram Role [%s]: ", cp.RamRoleName) cp.RamRoleName = ReadInput(cp.RamRoleName) return nil } func configureRamRoleArnWithEcs(w io.Writer, cp *Profile) error { cli.Printf(w, "Ecs Ram Role [%s]: ", cp.RamRoleName) cp.RamRoleName = ReadInput(cp.RamRoleName) cli.Printf(w, "Sts Region [%s]: ", cp.StsRegion) cp.StsRegion = ReadInput(cp.StsRegion) cli.Printf(w, "Ram Role Arn [%s]: ", cp.RamRoleArn) cp.RamRoleArn = ReadInput(cp.RamRoleArn) cli.Printf(w, "Role Session Name [%s]: ", cp.RoleSessionName) cp.RoleSessionName = ReadInput(cp.RoleSessionName) if cp.ExpiredSeconds == 0 { cp.ExpiredSeconds = 900 } cli.Printf(w, "Expired Seconds [%v]: ", cp.ExpiredSeconds) cp.ExpiredSeconds, _ = strconv.Atoi(ReadInput(strconv.Itoa(cp.ExpiredSeconds))) return nil } func configureChainableRamRoleArn(w io.Writer, cp *Profile) error { cli.Printf(w, "Source Profile [%s]: ", cp.SourceProfile) cp.SourceProfile = ReadInput(cp.SourceProfile) cli.Printf(w, "Sts Region [%s]: ", cp.StsRegion) cp.StsRegion = ReadInput(cp.StsRegion) cli.Printf(w, "Ram Role Arn [%s]: ", cp.RamRoleArn) cp.RamRoleArn = ReadInput(cp.RamRoleArn) cli.Printf(w, "Role Session Name [%s]: ", cp.RoleSessionName) cp.RoleSessionName = ReadInput(cp.RoleSessionName) if cp.ExpiredSeconds == 0 { cp.ExpiredSeconds = 900 } cli.Printf(w, "External ID [%s]: ", cp.ExternalId) cp.ExternalId = ReadInput(cp.ExternalId) cli.Printf(w, "Expired Seconds [%v]: ", cp.ExpiredSeconds) cp.ExpiredSeconds, _ = strconv.Atoi(ReadInput(strconv.Itoa(cp.ExpiredSeconds))) return nil } func configureRsaKeyPair(w io.Writer, cp *Profile) error { cli.Printf(w, "Rsa Private Key File: ") keyFile := ReadInput("") buf, err := os.ReadFile(keyFile) if err != nil { return fmt.Errorf("read key file %s failed %v", keyFile, err) } cp.PrivateKey = string(buf) cli.Printf(w, "Rsa Key Pair Name: ") cp.KeyPairName = ReadInput("") cp.ExpiredSeconds = 900 return nil } func configureExternal(w io.Writer, cp *Profile) error { cli.Printf(w, "Process Command [%s]: ", cp.ProcessCommand) cp.ProcessCommand = ReadInput(cp.ProcessCommand) return nil } func configureCredentialsURI(w io.Writer, cp *Profile) error { cli.Printf(w, "Credentials URI [%s]: ", cp.CredentialsURI) cp.CredentialsURI = ReadInput(cp.CredentialsURI) return nil } func configureOIDC(w io.Writer, cp *Profile) error { cli.Printf(w, "OIDC Provider ARN [%s]: ", cp.OIDCProviderARN) cp.OIDCProviderARN = ReadInput(cp.OIDCProviderARN) cli.Printf(w, "OIDC Token File [%s]: ", cp.OIDCTokenFile) cp.OIDCTokenFile = ReadInput(cp.OIDCTokenFile) cli.Printf(w, "RAM Role ARN [%s]: ", cp.RamRoleArn) cp.RamRoleArn = ReadInput(cp.RamRoleArn) cli.Printf(w, "Role Session Name [%s]: ", cp.RoleSessionName) cp.RoleSessionName = ReadInput(cp.RoleSessionName) cp.ExpiredSeconds = 3600 return nil } func configureCloudSSO(w io.Writer, cp *Profile) error { cli.Printf(w, "CloudSSO Sign In Url [%s]: ", cp.CloudSSOSignInUrl) userInputCloudSSOSignInUrl := ReadInput(cp.CloudSSOSignInUrl) if userInputCloudSSOSignInUrl != cp.CloudSSOSignInUrl && cp.CloudSSOSignInUrl != "" { // 需要清空其他的字段,完整的走登录 cp.AccessKeyId = "" cp.AccessKeySecret = "" cp.StsToken = "" cp.CloudSSOAccessConfig = "" cp.CloudSSOAccountId = "" cp.CloudSSOSignInUrl = userInputCloudSSOSignInUrl cp.AccessToken = "" cp.StsExpiration = 0 cp.CloudSSOAccessTokenExpire = 0 } else { cp.CloudSSOSignInUrl = userInputCloudSSOSignInUrl } if cp.CloudSSOSignInUrl == "" { return fmt.Errorf("CloudSSOSignInUrl is required") } // start login in, get access token, then list account for choose httpClient := util.NewHttpClient() ssoLogin := cloudsso.SsoLogin{ SignInUrl: cp.CloudSSOSignInUrl, // force login ExpireTime: 0, HttpClient: httpClient, } accessToken, err := cloudssoGetAccessToken(&ssoLogin) if err != nil { return fmt.Errorf("get access token failed: %s", err) } cp.AccessToken = accessToken.AccessToken cp.CloudSSOAccessTokenExpire = util.GetCurrentUnixTime() + int64(accessToken.ExpiresIn) // parse base url baseUrl, err := url.Parse(ssoLogin.SignInUrl) // list account for choose userParameter := cloudsso.ListUserParameter{ AccessToken: cp.AccessToken, BaseUrl: baseUrl.Scheme + "://" + baseUrl.Host, HttpClient: httpClient, } allUser, err := cloudssoListAllUsers(&userParameter) if err != nil { return fmt.Errorf("list account failed: %s", err) } // if allUser is empty, return error if len(allUser) == 0 { return fmt.Errorf("no account found") } accountIdHistory := cp.CloudSSOAccountId if accountIdHistory != "" { // 已经指定了账号,检查是否存在,如果不存在需要继续指定 var exist = false for _, user := range allUser { if user.AccountId == accountIdHistory { exist = true break } } if !exist { cli.Printf(w, "Account %s not found, please choose again\n", accountIdHistory) // clear history cp.CloudSSOAccountId = "" } } if cp.CloudSSOAccountId == "" { // 只有当账户不存在时才需要重新选择 // if allUser has only one account, use it directly if len(allUser) == 1 { cp.CloudSSOAccountId = allUser[0].AccountId cli.Printf(w, "Account: %s\n", allUser[0].DisplayName) } else { // print all user id cli.Println(w, "Please choose an account:") for i, user := range allUser { fmt.Printf("%d. %s\n", i+1, user.DisplayName) } cli.Printf(w, "Please input the account number: ") var accountNumber int // read input input := ReadInput("1") // parse input to int accountNumber, err = strconv.Atoi(input) if err != nil { return fmt.Errorf("invalid account number: %s", err) } if accountNumber < 1 || accountNumber > len(allUser) { return fmt.Errorf("invalid account number") } cp.CloudSSOAccountId = allUser[accountNumber-1].AccountId } } // get access configuration accessConfigurationParameter := cloudsso.AccessConfigurationsParameter{ AccessToken: cp.AccessToken, UrlPrefix: baseUrl.Scheme + "://" + baseUrl.Host, HttpClient: httpClient, AccountId: cp.CloudSSOAccountId, } accessConfigurations, err := cloudssoListAllAccessConfigurations(&accessConfigurationParameter, cloudsso.AccessConfigurationsRequest{ AccountId: cp.CloudSSOAccountId, }) if err != nil { return fmt.Errorf("list access configuration failed: %s", err) } if len(accessConfigurations) == 0 { return fmt.Errorf("no access configuration found") } acHistory := cp.CloudSSOAccessConfig if acHistory != "" { // 判断是否存在 var exist = false for _, accessConfiguration := range accessConfigurations { if accessConfiguration.AccessConfigurationId == acHistory { exist = true break } } if !exist { cli.Printf(w, "Access Configuration %s not found, please choose again\n", acHistory) // clear history cp.CloudSSOAccessConfig = "" } } if cp.CloudSSOAccessConfig == "" { // if accessConfigurations has only one access configuration, use it directly if len(accessConfigurations) == 1 { cp.CloudSSOAccessConfig = accessConfigurations[0].AccessConfigurationId cli.Printf(w, "Access Configuration: %s\n", accessConfigurations[0].AccessConfigurationId) } else { // print all access configuration id cli.Println(w, "Please choose an access configuration:") for i, accessConfiguration := range accessConfigurations { cli.Printf(w, "%d. %s\n", i+1, accessConfiguration.AccessConfigurationName) } cli.Printf(w, "Please input the access configuration number: ") var accessConfigurationNumber int // read input input := ReadInput("1") // parse input to int accessConfigurationNumber, err = strconv.Atoi(input) if err != nil { return fmt.Errorf("invalid access configuration number: %s", err) } if accessConfigurationNumber < 1 || accessConfigurationNumber > len(accessConfigurations) { return fmt.Errorf("invalid access configuration number") } cp.CloudSSOAccessConfig = accessConfigurations[accessConfigurationNumber-1].AccessConfigurationId } } // create sts token stsInfo, err := cloudssoTryRefreshStsToken(&cp.CloudSSOSignInUrl, &cp.AccessToken, &cp.CloudSSOAccessConfig, &cp.CloudSSOAccountId, httpClient) if err != nil { return fmt.Errorf("create sts token failed: %s", err) } cp.AccessKeyId = stsInfo.AccessKeyId cp.AccessKeySecret = stsInfo.AccessKeySecret cp.StsToken = stsInfo.SecurityToken // Expiration is UTC time, 2015-04-09T11:52:19Z, convert to int // Parse the time string parsedTime, err := time.Parse(time.RFC3339, stsInfo.Expiration) if err != nil { return fmt.Errorf("parse expiration time failed: %s", err) } // Convert to Unix time (int64) unixTime := parsedTime.Unix() cp.StsExpiration = unixTime - 5 return nil } func ReadInput(defaultValue string) string { var s string scanner := bufio.NewScanner(stdin) if scanner.Scan() { s = scanner.Text() } if s == "" { return defaultValue } return strings.TrimSpace(s) } func MosaicString(s string, lastChars int) string { r := len(s) - lastChars if r > 0 { return strings.Repeat("*", r) + s[r:] } else { return strings.Repeat("*", len(s)) } } func GetLastChars(s string, lastChars int) string { r := len(s) - lastChars if r > 0 { return s[r:] } else { return strings.Repeat("*", len(s)) } }