pkg/provider/pingfed/pingfed.go (213 lines of code) (raw):
package pingfed
import (
"context"
"encoding/base64"
"fmt"
"io/ioutil"
"net/http"
"net/url"
"time"
"github.com/PuerkitoBio/goquery"
"github.com/aliyun/saml2alibabacloud/pkg/cfg"
"github.com/aliyun/saml2alibabacloud/pkg/creds"
"github.com/aliyun/saml2alibabacloud/pkg/page"
"github.com/aliyun/saml2alibabacloud/pkg/prompter"
"github.com/aliyun/saml2alibabacloud/pkg/provider"
"github.com/pkg/errors"
"github.com/sirupsen/logrus"
"github.com/tidwall/gjson"
)
// logger is the package-level logrus entry; every record it emits carries the
// field provider=pingfed.
var logger = logrus.WithField("provider", "pingfed")
// Client is a wrapper around PingFed + PingId that enables authentication and
// retrieval of SAML assertions.
type Client struct {
// client is the HTTP client used to issue the requests made by this provider.
client *provider.HTTPClient
// idpAccount is the IDP account configuration that was passed to New.
idpAccount *cfg.IDPAccount
}
// New creates a PingFed client for the given IDP account.
//
// It returns an error when the underlying HTTP client cannot be built.
func New(idpAccount *cfg.IDPAccount) (*Client, error) {
	transport := provider.NewDefaultTransport(idpAccount.SkipVerify)
	opts := provider.BuildHttpClientOpts(idpAccount)

	httpClient, err := provider.NewHTTPClient(transport, opts)
	if err != nil {
		return nil, errors.Wrap(err, "error building http client")
	}

	// Install a response validator that accepts only success or redirect
	// responses, so individual call sites do not need their own status checks.
	httpClient.CheckResponseStatus = provider.SuccessOrRedirectResponseValidator

	ac := &Client{client: httpClient, idpAccount: idpAccount}
	return ac, nil
}
// ctxKey is a dedicated string type for the context value keys used in this
// package (for example ctxKey("login")).
type ctxKey string
// Authenticate logs in to PingFed and returns the SAMLResponse value found in
// the final sign-in form.
//
// The flow starts at <loginDetails.URL>/idp/startSSO.ping with the configured
// AlibabaCloudURN as PartnerSpId; every subsequent page is handled by follow.
func (ac *Client) Authenticate(loginDetails *creds.LoginDetails) (string, error) {
	startURL := fmt.Sprintf("%s/idp/startSSO.ping?PartnerSpId=%s", loginDetails.URL, ac.idpAccount.AlibabaCloudURN)

	req, err := http.NewRequest(http.MethodGet, startURL, nil)
	if err != nil {
		return "", errors.Wrap(err, "error building request")
	}

	// The login details travel in the context so that handleLogin can read
	// them back when a login page shows up.
	loginCtx := context.WithValue(context.Background(), ctxKey("login"), loginDetails)
	return ac.follow(loginCtx, req)
}
// follow executes req, inspects the HTML document that comes back and hands it
// to the matching handler, repeating with the request that handler produces
// until a form posting to the Alibaba Cloud sign-in endpoint yields a
// SAMLResponse value.
//
// The SAMLResponse is returned exactly as found in the page (still base64
// encoded); it is decoded only to validate it and for debug logging. An error
// is returned when a request fails, the response cannot be turned into a
// document, a handler fails, or no detector recognises the document.
func (ac *Client) follow(ctx context.Context, req *http.Request) (string, error) {
	for {
		res, err := ac.client.Do(req)
		if err != nil {
			return "", errors.Wrap(err, "error following")
		}

		doc, err := goquery.NewDocumentFromResponse(res)
		if err != nil {
			return "", errors.Wrap(err, "failed to build document from response")
		}

		// next stays nil when no detector claims the document.
		var next func(context.Context, *goquery.Document) (context.Context, *http.Request, error)

		// Cases are tried top to bottom; the first match wins.
		switch {
		case docIsFormRedirectToAlibabaCloud(doc):
			logger.WithField("type", "saml-response").Debug("doc detect")
			encoded, found := extractSAMLResponse(doc)
			if found {
				decoded, err := base64.StdEncoding.DecodeString(encoded)
				if err != nil {
					return "", errors.Wrap(err, "failed to decode saml-response")
				}
				logger.WithField("type", "saml-response").WithField("saml-response", string(decoded)).Debug("doc detect")
				return encoded, nil
			}
			// Without a SAMLResponse value next remains nil, so the document
			// is reported as unknown below.
		case docIsFormSamlRequest(doc):
			logger.WithField("type", "saml-request").Debug("doc detect")
			next = ac.handleFormRedirect
		case docIsFormResume(doc):
			logger.WithField("type", "resume").Debug("doc detect")
			next = ac.handleFormRedirect
		case docIsFormSamlResponse(doc):
			logger.WithField("type", "saml-response").Debug("doc detect")
			next = ac.handleFormRedirect
		case docIsLogin(doc):
			logger.WithField("type", "login").Debug("doc detect")
			next = ac.handleLogin
		case docIsOTP(doc):
			logger.WithField("type", "otp").Debug("doc detect")
			next = ac.handleOTP
		case docIsSwipe(doc):
			logger.WithField("type", "swipe").Debug("doc detect")
			next = ac.handleSwipe
		case docIsFormRedirect(doc):
			logger.WithField("type", "form-redirect").Debug("doc detect")
			next = ac.handleFormRedirect
		case docIsWebAuthn(doc):
			logger.WithField("type", "webauthn").Debug("doc detect")
			next = ac.handleWebAuthn
		}

		if next == nil {
			html, _ := doc.Selection.Html()
			logger.WithField("doc", html).Debug("Unknown document type")
			return "", fmt.Errorf("Unknown document type")
		}

		// The handler supplies the context and request for the next round.
		ctx, req, err = next(ctx, doc)
		if err != nil {
			return "", err
		}
	}
}
// handleLogin fills in the login form found in doc and returns the request
// that submits it.
//
// The credentials are read from the ctxKey("login") context value; an error is
// returned when that value is missing or is not a *creds.LoginDetails, or when
// the form cannot be extracted. ctx is passed through unchanged.
func (ac *Client) handleLogin(ctx context.Context, doc *goquery.Document) (context.Context, *http.Request, error) {
	details, found := ctx.Value(ctxKey("login")).(*creds.LoginDetails)
	if !found {
		return ctx, nil, fmt.Errorf("no context value for 'login'")
	}

	loginForm, err := page.NewFormFromDocument(doc, "form")
	if err != nil {
		return ctx, nil, errors.Wrap(err, "error extracting login form")
	}

	// A non-absolute form target is combined with the configured login URL.
	loginForm.URL = makeAbsoluteURL(loginForm.URL, details.URL)
	loginForm.Values.Set("pf.username", details.Username)
	loginForm.Values.Set("pf.pass", details.Password)

	request, err := loginForm.BuildRequest()
	return ctx, request, err
}
// handleOTP prompts the user for a passcode, stores it in the "otp" field of
// the "#otp-form" form found in doc and returns the request that submits the
// form. ctx is passed through unchanged.
func (ac *Client) handleOTP(ctx context.Context, doc *goquery.Document) (context.Context, *http.Request, error) {
	otpForm, err := page.NewFormFromDocument(doc, "#otp-form")
	if err != nil {
		return ctx, nil, errors.Wrap(err, "error extracting OTP form")
	}

	// Prompt only after the form has been found.
	otpForm.Values.Set("otp", prompter.StringRequired("Enter passcode"))

	request, err := otpForm.BuildRequest()
	return ctx, request, err
}
// handleSwipe waits for a swipe authentication to finish and then returns the
// request that submits the MFA result.
//
// It extracts the "#form1" status form from doc, forces its method to GET and
// polls it every 3 seconds until the JSON "status" field of the reply is "OK",
// "DEVICE_CLAIM_TIMEOUT" or "TIMEOUT"; any other status keeps the loop going.
// Once polling stops, the request built from the "#reponseView" form is
// returned. ctx is passed through unchanged.
//
// An error is returned when either form cannot be extracted, the polling
// request cannot be built or sent, or a polling response cannot be read.
func (ac *Client) handleSwipe(ctx context.Context, doc *goquery.Document) (context.Context, *http.Request, error) {
	form, err := page.NewFormFromDocument(doc, "#form1")
	if err != nil {
		return ctx, nil, errors.Wrap(err, "error extracting swipe status form")
	}

	// poll status. request must specifically be a GET
	form.Method = "GET"
	req, err := form.BuildRequest()
	if err != nil {
		return ctx, nil, err
	}

	for {
		time.Sleep(3 * time.Second)

		res, err := ac.client.Do(req)
		if err != nil {
			return ctx, nil, errors.Wrap(err, "error polling swipe status")
		}

		body, err := ioutil.ReadAll(res.Body)
		// Close the body on every iteration so polling does not leak one
		// response per round. It is closed explicitly rather than deferred
		// because a defer would only run when the function returns. The close
		// error is ignored: the body has already been consumed.
		_ = res.Body.Close()
		if err != nil {
			return ctx, nil, errors.Wrap(err, "error parsing body from swipe status response")
		}

		status := gjson.Get(string(body), "status").String()

		// ASYNC_AUTH_WAIT indicates we keep going
		// OK indicates someone swiped
		// DEVICE_CLAIM_TIMEOUT indicates nobody swiped
		// any other status also keeps the loop polling
		if status == "OK" || status == "DEVICE_CLAIM_TIMEOUT" || status == "TIMEOUT" {
			break
		}
	}

	// now build a request for getting response of MFA
	form, err = page.NewFormFromDocument(doc, "#reponseView")
	if err != nil {
		return ctx, nil, errors.Wrap(err, "error extracting swipe response form")
	}

	req, err = form.BuildRequest()
	return ctx, req, err
}
// handleFormRedirect builds the request that submits the form extracted from
// doc with an empty selector, leaving its values untouched. ctx is passed
// through unchanged.
func (ac *Client) handleFormRedirect(ctx context.Context, doc *goquery.Document) (context.Context, *http.Request, error) {
	redirectForm, formErr := page.NewFormFromDocument(doc, "")
	if formErr != nil {
		return ctx, nil, errors.Wrap(formErr, "error extracting redirect form")
	}

	request, buildErr := redirectForm.BuildRequest()
	return ctx, request, buildErr
}
// handleWebAuthn answers the WebAuthn capability form by setting
// "isWebAuthnSupportedByBrowser" to "false" and returns the request that
// submits it. ctx is passed through unchanged.
func (ac *Client) handleWebAuthn(ctx context.Context, doc *goquery.Document) (context.Context, *http.Request, error) {
	webAuthnForm, formErr := page.NewFormFromDocument(doc, "")
	if formErr != nil {
		return ctx, nil, errors.Wrap(formErr, "error extracting webauthn form")
	}

	webAuthnForm.Values.Set("isWebAuthnSupportedByBrowser", "false")

	request, buildErr := webAuthnForm.BuildRequest()
	return ctx, request, buildErr
}
// docIsLogin reports whether doc is a login page, detected through an input
// named "pf.pass".
func docIsLogin(doc *goquery.Document) bool {
	matched := doc.Has(`input[name="pf.pass"]`)
	return matched.Size() == 1
}
// docIsOTP reports whether doc is a passcode page, detected through a form
// with id "otp-form".
func docIsOTP(doc *goquery.Document) bool {
	matched := doc.Has("form#otp-form")
	return matched.Size() == 1
}
// docIsSwipe reports whether doc is a swipe page, which requires both the
// "form1" status form and the "reponseView" form to be present.
func docIsSwipe(doc *goquery.Document) bool {
	if doc.Has("form#form1").Size() != 1 {
		return false
	}
	return doc.Has("form#reponseView").Size() == 1
}
// docIsFormRedirect reports whether doc is a redirect form, detected through
// an input named "ppm_request".
func docIsFormRedirect(doc *goquery.Document) bool {
	matched := doc.Has(`input[name="ppm_request"]`)
	return matched.Size() == 1
}
// docIsWebAuthn reports whether doc is the WebAuthn capability form, detected
// through an input named "isWebAuthnSupportedByBrowser".
func docIsWebAuthn(doc *goquery.Document) bool {
	matched := doc.Has(`input[name="isWebAuthnSupportedByBrowser"]`)
	return matched.Size() == 1
}
// docIsFormSamlRequest reports whether doc contains exactly one input named
// "SAMLRequest".
func docIsFormSamlRequest(doc *goquery.Document) bool {
	inputs := doc.Find(`input[name="SAMLRequest"]`)
	return inputs.Size() == 1
}
// docIsFormSamlResponse reports whether doc contains exactly one input named
// "SAMLResponse".
func docIsFormSamlResponse(doc *goquery.Document) bool {
	inputs := doc.Find(`input[name="SAMLResponse"]`)
	return inputs.Size() == 1
}
// docIsFormResume reports whether doc contains exactly one input named
// "RelayState".
func docIsFormResume(doc *goquery.Document) bool {
	inputs := doc.Find(`input[name="RelayState"]`)
	return inputs.Size() == 1
}
// docIsFormRedirectToAlibabaCloud reports whether doc contains exactly one
// form whose action is the Alibaba Cloud SAML sign-in URL.
func docIsFormRedirectToAlibabaCloud(doc *goquery.Document) bool {
	forms := doc.Find(`form[action="https://signin.aliyun.com/saml-role/sso"]`)
	return forms.Size() == 1
}
// extractSAMLResponse returns the "value" attribute of the input named
// "SAMLResponse" in doc; ok is false when that attribute cannot be found.
func extractSAMLResponse(doc *goquery.Document) (v string, ok bool) {
	input := doc.Find(`input[name="SAMLResponse"]`)
	v, ok = input.Attr("value")
	return v, ok
}
// makeAbsoluteURL ensures v is an absolute URL.
//
// When v parses with url.ParseRequestURI and is not absolute, base is
// prepended to it verbatim. In every other case, including when v fails to
// parse, v is returned unchanged.
func makeAbsoluteURL(v string, base string) string {
	parsed, err := url.ParseRequestURI(v)
	if err != nil || parsed.IsAbs() {
		return v
	}
	return fmt.Sprintf("%s%s", base, v)
}