pkg/provider/adfs/adfs.go (232 lines of code) (raw):
package adfs
import (
"crypto/tls"
"fmt"
"log"
"net/http"
"net/url"
"strings"
"time"
"github.com/PuerkitoBio/goquery"
"github.com/aliyun/saml2alibabacloud/pkg/cfg"
"github.com/aliyun/saml2alibabacloud/pkg/creds"
"github.com/aliyun/saml2alibabacloud/pkg/prompter"
"github.com/aliyun/saml2alibabacloud/pkg/provider"
"github.com/pkg/errors"
)
// Client wraps an HTTP client together with the IDP account configuration and
// is used to authenticate against ADFS and retrieve SAML assertions.
type Client struct {
// client issues the GET and form POST requests made during authentication.
client *provider.HTTPClient
// idpAccount supplies the SkipVerify TLS setting read by New and the
// AlibabaCloudURN that Authenticate places in the sign-on URL.
idpAccount *cfg.IDPAccount
}
// AuthResponseType classifies the page ADFS returned after a form submission,
// as determined by checkResponse from the page's input elements.
type AuthResponseType int
// Possible classifications of an ADFS response page.
const (
// UNKNOWN is the zero value; checkResponse reports it when no recognised
// input is present, and Authenticate then fails.
UNKNOWN AuthResponseType = iota
// SAML_RESPONSE means the page contains an input named "SAMLResponse".
SAML_RESPONSE
// MFA_PROMPT means the page asks for a code: either a "VerificationCode"
// input is present or "AuthMethod" names one of the VIP / Defender adapters.
MFA_PROMPT
// AZURE_MFA_WAIT means "AuthMethod" is "AzureMfaAuthentication"; Authenticate
// re-submits the page's inputs after a one second pause.
AZURE_MFA_WAIT
// AZURE_MFA_SERVER_WAIT means "AuthMethod" is "AzureMfaServerAuthentication";
// handled like AZURE_MFA_WAIT, and additionally a label#errorText element in
// the next page is returned as an error.
AZURE_MFA_SERVER_WAIT
)
// New creates an ADFS client for the given IDP account.
//
// The underlying transport honours the proxy settings from the environment,
// skips TLS certificate verification when idpAccount.SkipVerify is set, and
// allows the server to request TLS renegotiation repeatedly. An error is
// returned when the provider HTTP client cannot be built.
func New(idpAccount *cfg.IDPAccount) (*Client, error) {
	tlsConfig := &tls.Config{
		InsecureSkipVerify: idpAccount.SkipVerify,
		Renegotiation:      tls.RenegotiateFreelyAsClient,
	}
	transport := &http.Transport{
		Proxy:           http.ProxyFromEnvironment,
		TLSClientConfig: tlsConfig,
	}

	opts := provider.BuildHttpClientOpts(idpAccount)
	httpClient, err := provider.NewHTTPClient(transport, opts)
	if err != nil {
		return nil, errors.Wrap(err, "error building http client")
	}

	adfsClient := &Client{
		client:     httpClient,
		idpAccount: idpAccount,
	}
	return adfsClient, nil
}
// Authenticate signs in to ADFS and returns the value of the SAMLResponse
// input found in the final page.
//
// The flow is:
//  1. GET the IdP-initiated sign-on page for the configured AlibabaCloud URN.
//  2. Fill the login form from loginDetails and POST it to the form's action.
//  3. Repeatedly classify the returned page with checkResponse and react:
//     return the assertion, submit an MFA code (taken from
//     loginDetails.MFAToken the first time, prompted for afterwards), or poll
//     while an Azure MFA approval is pending.
//
// An error is returned when a request fails, when no form action can be
// found, when the server reports an Azure MFA Server error, or when the
// response page cannot be classified.
func (ac *Client) Authenticate(loginDetails *creds.LoginDetails) (string, error) {
	var (
		authSubmitURL    string
		lastInstructions string
	)

	escapedURN := url.QueryEscape(ac.idpAccount.AlibabaCloudURN)
	adfsURL := fmt.Sprintf(
		"%s/adfs/ls/IdpInitiatedSignOn.aspx?loginToRp=%s",
		loginDetails.URL,
		escapedURN,
	)
	mfaToken := loginDetails.MFAToken

	doc, err := ac.get(adfsURL)
	if err != nil {
		return "", errors.Wrap(err, "failed to get adfs page")
	}

	// Build the credentials form from every input on the login page.
	authForm := url.Values{}
	doc.Find("input").Each(func(_ int, input *goquery.Selection) {
		updateFormData(authForm, input, loginDetails)
	})

	// When the page holds several forms, the last one with an action wins.
	doc.Find("form").Each(func(_ int, form *goquery.Selection) {
		if action, found := form.Attr("action"); found {
			authSubmitURL = action
		}
	})
	if authSubmitURL == "" {
		return "", fmt.Errorf("unable to locate IDP authentication form submit URL")
	}

	doc, err = ac.submit(authSubmitURL, authForm)
	if err != nil {
		return "", errors.Wrap(err, "failed to submit adfs auth form")
	}

	for {
		responseType, assertion, checkErr := checkResponse(doc)

		switch responseType {
		case SAML_RESPONSE:
			return assertion, checkErr

		case MFA_PROMPT:
			if mfaToken == "" {
				mfaToken = prompter.RequestSecurityCode("000000")
			}
			otpForm := url.Values{}
			doc.Find("input").Each(func(_ int, input *goquery.Selection) {
				updateOTPFormData(otpForm, input, mfaToken)
			})

			doc, err = ac.submit(authSubmitURL, otpForm)
			if err != nil {
				return assertion, errors.Wrap(err, "error retrieving mfa form results")
			}
			// Forget the token so another MFA page triggers a fresh prompt.
			mfaToken = ""

		case AZURE_MFA_SERVER_WAIT, AZURE_MFA_WAIT:
			azureForm := url.Values{}
			doc.Find("input").Each(func(_ int, input *goquery.Selection) {
				updatePassthroughFormData(azureForm, input)
			})

			// Log the on-page instructions, but only when they change.
			if hint := doc.Find("p#instructions"); hint.Index() != -1 {
				if text := hint.Text(); text != lastInstructions {
					lastInstructions = text
					log.Println(lastInstructions)
				}
			}

			time.Sleep(time.Second)

			doc, err = ac.submit(authSubmitURL, azureForm)
			if err != nil {
				return assertion, errors.Wrap(err, "error retrieving mfa form results")
			}

			if responseType == AZURE_MFA_SERVER_WAIT {
				if failure := doc.Find("label#errorText"); failure.Index() != -1 {
					return assertion, errors.New(failure.Text())
				}
			}

		case UNKNOWN:
			return assertion, errors.New("unable to classify response from auth server")
		}
	}
}
// get issues a GET request for url through the wrapped HTTP client and parses
// the response body into a goquery document. The response body is closed
// before returning.
func (ac *Client) get(url string) (*goquery.Document, error) {
	resp, reqErr := ac.client.Get(url)
	if reqErr != nil {
		return nil, errors.Wrap(reqErr, "error retieving form")
	}
	defer resp.Body.Close()

	page, parseErr := goquery.NewDocumentFromReader(resp.Body)
	if parseErr != nil {
		return nil, errors.Wrap(parseErr, "failed to build document from response")
	}

	return page, nil
}
// submit POSTs form, URL-encoded, to url through the wrapped HTTP client and
// parses the response body into a goquery document. The response body is
// closed before returning.
func (ac *Client) submit(url string, form url.Values) (*goquery.Document, error) {
	payload := strings.NewReader(form.Encode())

	req, buildErr := http.NewRequest(http.MethodPost, url, payload)
	if buildErr != nil {
		return nil, errors.Wrap(buildErr, "error building request")
	}
	req.Header.Add("Content-Type", "application/x-www-form-urlencoded")

	resp, sendErr := ac.client.Do(req)
	if sendErr != nil {
		return nil, errors.Wrap(sendErr, "error submitting form")
	}
	defer resp.Body.Close()

	page, parseErr := goquery.NewDocumentFromReader(resp.Body)
	if parseErr != nil {
		return nil, errors.Wrap(parseErr, "failed to build document from response")
	}

	return page, nil
}
// checkResponse classifies an ADFS response page by looking at the name (and,
// where relevant, the value) of each input element it contains.
//
// It returns:
//   - SAML_RESPONSE and the assertion when an input named "SAMLResponse" is
//     present;
//   - MFA_PROMPT when a "VerificationCode" input is present or "AuthMethod"
//     names one of the VIP / Defender adapters;
//   - AZURE_MFA_WAIT or AZURE_MFA_SERVER_WAIT for the corresponding Azure
//     "AuthMethod" values;
//   - UNKNOWN when none of the above match.
//
// Inputs are processed in document order, so a later match overrides an
// earlier one. Inputs without a name attribute are ignored. If the
// "SAMLResponse" input has no value attribute, SAML_RESPONSE is returned with
// an empty assertion and a non-nil error; the process is never terminated
// from here.
func checkResponse(doc *goquery.Document) (AuthResponseType, string, error) {
	var (
		samlAssertion string
		err           error
	)
	responseType := UNKNOWN

	doc.Find("input").Each(func(i int, s *goquery.Selection) {
		if err != nil {
			// A malformed SAMLResponse was already seen; keep that verdict.
			return
		}

		name, ok := s.Attr("name")
		if !ok {
			// Unnamed inputs cannot be any of the fields of interest.
			return
		}

		switch name {
		case "SAMLResponse":
			// Classify as SAML_RESPONSE even on failure so that callers which
			// return the error for that type propagate it.
			responseType = SAML_RESPONSE
			val, ok := s.Attr("value")
			if !ok {
				samlAssertion = ""
				err = errors.New("unable to locate saml assertion value")
				return
			}
			samlAssertion = val
		case "AuthMethod":
			// A missing value yields "", which matches none of the cases.
			val, _ := s.Attr("value")
			switch val {
			case "VIPAuthenticationProviderWindowsAccountName", "VIPAuthenticationProviderUPN", "Defender AD FS Adapter":
				responseType = MFA_PROMPT
			case "AzureMfaAuthentication":
				responseType = AZURE_MFA_WAIT
			case "AzureMfaServerAuthentication":
				responseType = AZURE_MFA_SERVER_WAIT
			}
		case "VerificationCode":
			responseType = MFA_PROMPT
		}
	})

	return responseType, samlAssertion, err
}
// updateFormData adds the field represented by input element s to authForm.
//
// Inputs without a name attribute are ignored. The lower-cased name decides
// the value: names containing "user" or "email" receive user.Username, names
// containing "pass" receive user.Password, and anything else is copied
// through with its existing value. Hidden inputs whose name matches one of the
// credential patterns are dropped rather than filled in or passed through.
func updateFormData(authForm url.Values, s *goquery.Selection, user *creds.LoginDetails) {
	name, hasName := s.Attr("name")
	if !hasName {
		return
	}

	inputType, hasType := s.Attr("type")
	visible := !(hasType && inputType == "hidden")
	lowered := strings.ToLower(name)

	switch {
	case strings.Contains(lowered, "user"), strings.Contains(lowered, "email"):
		if visible {
			authForm.Add(name, user.Username)
		}
	case strings.Contains(lowered, "pass"):
		if visible {
			authForm.Add(name, user.Password)
		}
	default:
		updatePassthroughFormData(authForm, s)
	}
}
// updateOTPFormData adds the field represented by input element s to otpForm.
//
// Inputs without a name attribute are ignored. When the lower-cased name
// contains "security_code", "verificationcode" or "challengequestionanswer"
// the field is set to token; every other input is copied through with its
// existing value.
func updateOTPFormData(otpForm url.Values, s *goquery.Selection, token string) {
	name, hasName := s.Attr("name")
	if !hasName {
		return
	}

	lowered := strings.ToLower(name)
	switch {
	case strings.Contains(lowered, "security_code"),
		strings.Contains(lowered, "verificationcode"),
		strings.Contains(lowered, "challengequestionanswer"):
		otpForm.Add(name, token)
	default:
		updatePassthroughFormData(otpForm, s)
	}
}
// updatePassthroughFormData copies input element s into otpForm unchanged.
// The element is skipped unless it carries both a name and a value attribute.
func updatePassthroughFormData(otpForm url.Values, s *goquery.Selection) {
	if name, hasName := s.Attr("name"); hasName {
		if value, hasValue := s.Attr("value"); hasValue {
			otpForm.Add(name, value)
		}
	}
}