pkg/provider/onelogin/onelogin.go (282 lines of code) (raw):

package onelogin import ( "bytes" "encoding/json" "fmt" "io/ioutil" "log" "net/http" "net/url" "strings" "time" "github.com/aliyun/saml2alibabacloud/pkg/cfg" "github.com/aliyun/saml2alibabacloud/pkg/creds" "github.com/aliyun/saml2alibabacloud/pkg/prompter" "github.com/aliyun/saml2alibabacloud/pkg/provider" "github.com/pkg/errors" "github.com/sirupsen/logrus" "github.com/tidwall/gjson" ) // MFA identifier constants. const ( IdentifierOneLoginProtectMfa = "OneLogin Protect" IdentifierSmsMfa = "OneLogin SMS" IdentifierTotpMfa = "Google Authenticator" IdentifierYubiKey = "Yubico YubiKey" MessageMFARequired = "MFA is required for this user" MessageSuccess = "Success" ) // ProviderName constant holds the name of the OneLogin IDP. const ProviderName = "OneLogin" var logger = logrus.WithField("provider", ProviderName) var ( supportedMfaOptions = map[string]string{ IdentifierOneLoginProtectMfa: "OLP", IdentifierSmsMfa: "SMS", IdentifierTotpMfa: "TOTP", IdentifierYubiKey: "YUBIKEY", } ) // Client is a wrapper representing a OneLogin SAML client. type Client struct { // AppID represents the OneLogin connector id. AppID string // Client is the HTTP client for accessing the IDP provider's APIs. Client *provider.HTTPClient // A predefined MFA name. MFA string // Subdomain is the organisation subdomain in OneLogin. Subdomain string } // AuthRequest represents an mfa OneLogin request. type AuthRequest struct { AppID string `json:"app_id"` Password string `json:"password"` Subdomain string `json:"subdomain"` Username string `json:"username_or_email"` IPAddress string `json:"ip_address,omitempty"` } // VerifyRequest represents an mfa verify request type VerifyRequest struct { AppID string `json:"app_id"` DeviceID string `json:"device_id"` DoNotNotify bool `json:"do_not_notify"` OTPToken string `json:"otp_token,omitempty"` StateToken string `json:"state_token"` } // New creates a new OneLogin client. func New(idpAccount *cfg.IDPAccount) (*Client, error) { tr := provider.NewDefaultTransport(idpAccount.SkipVerify) client, err := provider.NewHTTPClient(tr, provider.BuildHttpClientOpts(idpAccount)) if err != nil { return nil, errors.Wrap(err, "error building http client") } return &Client{AppID: idpAccount.AppID, Client: client, MFA: idpAccount.MFA, Subdomain: idpAccount.Subdomain}, nil } // Authenticate logs into OneLogin and returns a SAML response. func (c *Client) Authenticate(loginDetails *creds.LoginDetails) (string, error) { providerURL, err := url.Parse(loginDetails.URL) if err != nil { return "", errors.Wrap(err, "error building providerURL") } host := providerURL.Host logger.Debug("Generating OneLogin access token") // request oAuth token required for working with OneLogin APIs oauthToken, err := generateToken(c, loginDetails, host) if err != nil { return "", errors.Wrap(err, "failed to generate oauth token") } logger.Debug("Retrieved OneLogin OAuth token:", oauthToken) authReq := AuthRequest{Username: loginDetails.Username, Password: loginDetails.Password, AppID: c.AppID, Subdomain: c.Subdomain} var authBody bytes.Buffer err = json.NewEncoder(&authBody).Encode(authReq) if err != nil { return "", errors.Wrap(err, "error encoding authreq") } authSubmitURL := fmt.Sprintf("https://%s/api/2/saml_assertion", host) req, err := http.NewRequest("POST", authSubmitURL, &authBody) if err != nil { return "", errors.Wrap(err, "error building authentication request") } addContentHeaders(req) addAuthHeader(req, oauthToken) logger.Debug("Requesting SAML Assertion") // request the SAML assertion. For more details check https://developers.onelogin.com/api-docs/2/saml-assertions/generate-saml-assertion res, err := c.Client.Do(req) if err != nil { return "", errors.Wrap(err, "error retrieving auth response") } defer res.Body.Close() body, err := ioutil.ReadAll(res.Body) if err != nil { return "", errors.Wrap(err, "error retrieving body from response") } resp := string(body) logger.Debug("SAML Assertion response code:", res.StatusCode) logger.Debug("SAML Assertion response body:", resp) authMessage := gjson.Get(resp, "message").String() if res.StatusCode != 200 { return "", fmt.Errorf("HTTP %v: %s", res.StatusCode, authMessage) } authData := gjson.Get(resp, "data") var samlAssertion string switch authMessage { // MFA not required case MessageSuccess: if authData.IsArray() { return "", errors.New("invalid SAML assertion returned") } samlAssertion = authData.String() case MessageMFARequired: logger.Debug("Verifying MFA") samlAssertion, err = verifyMFA(c, oauthToken, c.AppID, resp) if err != nil { return "", errors.Wrap(err, "error verifying MFA") } default: return "", errors.New("unexpected SAML assertion response") } return samlAssertion, nil } // generateToken is used to generate access token for all OneLogin APIs. // For more infor read https://developers.onelogin.com/api-docs/2/oauth20-tokens/generate-tokens-2 func generateToken(oc *Client, loginDetails *creds.LoginDetails, host string) (string, error) { oauthTokenURL := fmt.Sprintf("https://%s/auth/oauth2/v2/token", host) req, err := http.NewRequest("POST", oauthTokenURL, strings.NewReader(`{"grant_type":"client_credentials"}`)) if err != nil { return "", errors.Wrap(err, "error building oauth token request") } addContentHeaders(req) req.SetBasicAuth(loginDetails.ClientID, loginDetails.ClientSecret) res, err := oc.Client.Do(req) if err != nil { return "", errors.Wrap(err, "error retrieving oauth token response") } body, err := ioutil.ReadAll(res.Body) if err != nil { return "", errors.Wrap(err, "error reading oauth token response") } defer res.Body.Close() return gjson.Get(string(body), "access_token").String(), nil } func addAuthHeader(r *http.Request, oauthToken string) { r.Header.Add("Authorization", "bearer: "+oauthToken) } func addContentHeaders(r *http.Request) { r.Header.Add("Content-Type", "application/json") r.Header.Add("Accept", "application/json") } // verifyMFA is used to either prompt to user for one time password or request approval using push notification. // For more details check https://developers.onelogin.com/api-docs/2/saml-assertions/verify-factor func verifyMFA(oc *Client, oauthToken, appID, resp string) (string, error) { stateToken := gjson.Get(resp, "state_token").String() // choose an mfa option if there are multiple enabled var option int var mfaOptions []string var preselected bool for n, id := range gjson.Get(resp, "devices.#.device_type").Array() { identifier := id.String() if val, ok := supportedMfaOptions[identifier]; ok { mfaOptions = append(mfaOptions, val) // If there is pre-selected MFA option (thorugh the --mfa flag), then set MFA option index and break early. if val == oc.MFA { option = n preselected = true break } } else { mfaOptions = append(mfaOptions, "UNSUPPORTED: "+identifier) } } if !preselected && len(mfaOptions) > 1 { option = prompter.Choose("Select which MFA option to use", mfaOptions) } factorID := gjson.Get(resp, fmt.Sprintf("devices.%d.device_id", option)).String() callbackURL := gjson.Get(resp, "callback_url").String() mfaIdentifer := gjson.Get(resp, fmt.Sprintf("devices.%d.device_type", option)).String() mfaDeviceID := gjson.Get(resp, fmt.Sprintf("devices.%d.device_id", option)).String() logger.WithField("factorID", factorID).WithField("callbackURL", callbackURL).WithField("mfaIdentifer", mfaIdentifer).Debug("MFA") if _, ok := supportedMfaOptions[mfaIdentifer]; !ok { return "", errors.New("unsupported mfa provider") } switch mfaIdentifer { // These MFA options doesn't need additional request (e.g. to send SMS or a push notification etc) since the user can generate the code using their MFA app of choice. case IdentifierTotpMfa, IdentifierYubiKey: break default: var verifyBody bytes.Buffer err := json.NewEncoder(&verifyBody).Encode(VerifyRequest{AppID: appID, DeviceID: mfaDeviceID, StateToken: stateToken}) if err != nil { return "", errors.Wrap(err, "error encoding verifyReq") } req, err := http.NewRequest("POST", callbackURL, &verifyBody) if err != nil { return "", errors.Wrap(err, "error building verify request") } addContentHeaders(req) addAuthHeader(req, oauthToken) res, err := oc.Client.Do(req) if err != nil { return "", errors.Wrap(err, "error retrieving verify response") } body, err := ioutil.ReadAll(res.Body) if err != nil { return "", errors.Wrap(err, "error retrieving body from response") } resp = string(body) if gjson.Get(resp, "status.error").Bool() { msg := gjson.Get(resp, "status.message").String() return "", errors.New(msg) } } switch mfaIdentifer { case IdentifierSmsMfa, IdentifierTotpMfa, IdentifierYubiKey: verifyCode := prompter.StringRequired("Enter verification code") var verifyBody bytes.Buffer json.NewEncoder(&verifyBody).Encode(VerifyRequest{AppID: appID, DeviceID: mfaDeviceID, StateToken: stateToken, OTPToken: verifyCode}) req, err := http.NewRequest("POST", callbackURL, &verifyBody) if err != nil { return "", errors.Wrap(err, "error building token post request") } addContentHeaders(req) addAuthHeader(req, oauthToken) res, err := oc.Client.Do(req) if err != nil { return "", errors.Wrap(err, "error retrieving token post response") } body, err := ioutil.ReadAll(res.Body) if err != nil { return "", errors.Wrap(err, "error retrieving body from response") } resp = string(body) message := gjson.Get(resp, "message").String() if res.StatusCode != 200 || message != MessageSuccess { return "", fmt.Errorf("HTTP %v: %s", res.StatusCode, message) } return gjson.Get(resp, "data").String(), nil case IdentifierOneLoginProtectMfa: // set the body payload to disable further push notifications (i.e. set do_not_notify to true) // https://developers.onelogin.com/api-docs/2/saml-assertions/verify-factor var verifyBody bytes.Buffer err := json.NewEncoder(&verifyBody).Encode(VerifyRequest{AppID: appID, DeviceID: mfaDeviceID, DoNotNotify: true, StateToken: stateToken}) if err != nil { return "", errors.New("error encoding verify MFA request body") } req, err := http.NewRequest("POST", callbackURL, &verifyBody) if err != nil { return "", errors.Wrap(err, "error building token post request") } addContentHeaders(req) addAuthHeader(req, oauthToken) fmt.Printf("\nWaiting for approval, please check your OneLogin Protect app ...") started := time.Now() // loop until success, error, or timeout for { if time.Since(started) > time.Minute { log.Println(" Timeout") return "", errors.New("User did not accept MFA in time") } logger.Debug("Verifying with OneLogin Protect") res, err := oc.Client.Do(req) if err != nil { return "", errors.Wrap(err, "error retrieving verify response") } body, err := ioutil.ReadAll(res.Body) if err != nil { return "", errors.Wrap(err, "error retrieving body from response") } message := gjson.Get(string(body), "message").String() // on 'error' status if res.StatusCode != 200 { return "", fmt.Errorf("HTTP %v: %s", res.StatusCode, message) } switch true { case strings.Contains(message, "Authentication pending"): time.Sleep(time.Second) fmt.Print(".") case message == MessageSuccess: log.Println(" Approved") return gjson.Get(string(body), "data").String(), nil default: log.Println(" Error:") return "", errors.New("unsupported response from OneLogin, please raise ticket with saml2alibabacloud") } } } // catch all return "", errors.New("no mfa options provided") }