pkg/provider/custom/custom.go (69 lines of code) (raw):
package custom
import (
"encoding/base64"
"io/ioutil"
"net/http"
"net/url"
"strconv"
"strings"
"github.com/aliyun/saml2alibabacloud/pkg/cfg"
"github.com/aliyun/saml2alibabacloud/pkg/creds"
"github.com/aliyun/saml2alibabacloud/pkg/provider"
"github.com/pkg/errors"
"github.com/sirupsen/logrus"
"github.com/tidwall/gjson"
)
// logger is the package-level log entry for this provider; every record it
// emits carries the field provider=custom.
var logger = logrus.WithField("provider", "custom")
// Client is a wrapper representing a custom SAML client.
// It is constructed by New and performs the login exchange in Authenticate.
type Client struct {
// client is the HTTP client used to send the authentication request.
client *provider.HTTPClient
// mfa holds the MFA value copied from the IDP account in New. No code
// visible in this file reads it (New marks MFA as currently not supported).
mfa string
}
// New builds a Client for the custom provider from the given IDP account.
//
// The underlying HTTP client is created from the account's transport and
// client options. An error is returned, wrapped with context, when that
// HTTP client cannot be built.
func New(idpAccount *cfg.IDPAccount) (*Client, error) {
	transport := provider.NewDefaultTransport(idpAccount.SkipVerify)
	opts := provider.BuildHttpClientOpts(idpAccount)

	httpClient, err := provider.NewHTTPClient(transport, opts)
	if err != nil {
		return nil, errors.Wrap(err, "error building http client")
	}

	// Install a response validator so that every response must be either a
	// success or a redirect; this avoids having explicit status checks after
	// each individual request.
	httpClient.CheckResponseStatus = provider.SuccessOrRedirectResponseValidator

	custom := &Client{client: httpClient}
	// TODO currently not supported: the MFA value is stored but not used yet.
	custom.mfa = idpAccount.MFA

	return custom, nil
}
// Authenticate posts the username and password from loginDetails to
// loginDetails.URL as an x-www-form-urlencoded body and returns the SAML
// response contained in the JSON reply.
//
// The reply is read as JSON: the "success" field must render as "true" and
// the "data" field must hold the base64-encoded SAML response. The value is
// returned exactly as received (still base64-encoded); it is decoded only to
// validate it and to log the decoded document at debug level.
//
// An error is returned when the URL cannot be parsed, the request cannot be
// built or sent, the body cannot be read, the reply does not report success,
// or the "data" field is empty or not valid base64.
func (oc *Client) Authenticate(loginDetails *creds.LoginDetails) (string, error) {
	// The parsed URL itself is not needed; this only rejects malformed URLs early.
	if _, err := url.Parse(loginDetails.URL); err != nil {
		return "", errors.Wrap(err, "error building login request URL")
	}

	// authenticate using x-www-form-urlencoded
	authReq := url.Values{}
	authReq.Set("username", loginDetails.Username)
	authReq.Set("password", loginDetails.Password)
	// Encode once and reuse the result for both the body and its length.
	encodedAuth := authReq.Encode()

	req, err := http.NewRequest("POST", loginDetails.URL, strings.NewReader(encodedAuth))
	if err != nil {
		return "", errors.Wrap(err, "error building authentication request")
	}
	req.Header.Add("Content-Type", "application/x-www-form-urlencoded")
	req.Header.Add("Content-Length", strconv.Itoa(len(encodedAuth)))

	res, err := oc.client.Do(req)
	if err != nil {
		return "", errors.Wrap(err, "error retrieving auth response")
	}
	defer res.Body.Close()

	body, err := ioutil.ReadAll(res.Body)
	if err != nil {
		return "", errors.Wrap(err, "error retrieving body from response")
	}

	resp := string(body)
	successResponse := gjson.Get(resp, "success").String()
	samlResponse := gjson.Get(resp, "data").String()

	// error response
	if successResponse != "true" {
		// err is nil at this point and errors.Wrap(nil, ...) yields nil, which
		// would report a failed login as success; create a fresh error instead.
		return "", errors.New("error retrieving SAML response")
	}

	// A successful reply without a payload is unusable; base64-decoding an
	// empty string succeeds, so it has to be rejected explicitly.
	if samlResponse == "" {
		return "", errors.New("SAML response is empty")
	}

	decodedSamlResponse, err := base64.StdEncoding.DecodeString(samlResponse)
	if err != nil {
		return "", errors.Wrap(err, "failed to decode SAML response")
	}

	logger.WithField("type", "saml-response").WithField("saml-response", string(decodedSamlResponse)).Debug("custom auth response")

	return samlResponse, nil
}