github.com/versent/saml2aws@v2.17.0+incompatible/pkg/provider/adfs/adfs.go (about)

     1  package adfs
     2  
import (
	"bytes"
	"crypto/tls"
	"fmt"
	"io/ioutil"
	"log"
	"net/http"
	"net/url"
	"strings"

	"github.com/PuerkitoBio/goquery"
	"github.com/pkg/errors"
	"github.com/sirupsen/logrus"
	"github.com/versent/saml2aws/pkg/cfg"
	"github.com/versent/saml2aws/pkg/creds"
	"github.com/versent/saml2aws/pkg/prompter"
	"github.com/versent/saml2aws/pkg/provider"
)
    19  
    20  var logger = logrus.WithField("provider", "adfs")
    21  
    22  // Client wrapper around ADFS enabling authentication and retrieval of assertions
    23  type Client struct {
    24  	client     *provider.HTTPClient
    25  	idpAccount *cfg.IDPAccount
    26  }
    27  
    28  // New create a new ADFS client
    29  func New(idpAccount *cfg.IDPAccount) (*Client, error) {
    30  
    31  	tr := &http.Transport{
    32  		Proxy:           http.ProxyFromEnvironment,
    33  		TLSClientConfig: &tls.Config{InsecureSkipVerify: idpAccount.SkipVerify, Renegotiation: tls.RenegotiateFreelyAsClient},
    34  	}
    35  
    36  	client, err := provider.NewHTTPClient(tr)
    37  	if err != nil {
    38  		return nil, errors.Wrap(err, "error building http client")
    39  	}
    40  
    41  	return &Client{
    42  		client:     client,
    43  		idpAccount: idpAccount,
    44  	}, nil
    45  }
    46  
    47  // Authenticate to ADFS and return the data from the body of the SAML assertion.
    48  func (ac *Client) Authenticate(loginDetails *creds.LoginDetails) (string, error) {
    49  
    50  	var authSubmitURL string
    51  	var samlAssertion string
    52  
    53  	awsURN := url.QueryEscape(ac.idpAccount.AmazonWebservicesURN)
    54  
    55  	adfsURL := fmt.Sprintf("%s/adfs/ls/IdpInitiatedSignOn.aspx?loginToRp=%s", loginDetails.URL, awsURN)
    56  
    57  	res, err := ac.client.Get(adfsURL)
    58  	if err != nil {
    59  		return samlAssertion, errors.Wrap(err, "error retrieving form")
    60  	}
    61  
    62  	doc, err := goquery.NewDocumentFromResponse(res)
    63  	if err != nil {
    64  		return samlAssertion, errors.Wrap(err, "failed to build document from response")
    65  	}
    66  
    67  	authForm := url.Values{}
    68  
    69  	doc.Find("input").Each(func(i int, s *goquery.Selection) {
    70  		updateFormData(authForm, s, loginDetails)
    71  	})
    72  
    73  	doc.Find("form").Each(func(i int, s *goquery.Selection) {
    74  		action, ok := s.Attr("action")
    75  		if !ok {
    76  			return
    77  		}
    78  		authSubmitURL = action
    79  	})
    80  
    81  	if authSubmitURL == "" {
    82  		return samlAssertion, fmt.Errorf("unable to locate IDP authentication form submit URL")
    83  	}
    84  
    85  	//log.Printf("id authentication url: %s", authSubmitURL)
    86  
    87  	req, err := http.NewRequest("POST", authSubmitURL, strings.NewReader(authForm.Encode()))
    88  	if err != nil {
    89  		return samlAssertion, errors.Wrap(err, "error building authentication request")
    90  	}
    91  
    92  	req.Header.Add("Content-Type", "application/x-www-form-urlencoded")
    93  
    94  	res, err = ac.client.Do(req)
    95  	if err != nil {
    96  		return samlAssertion, errors.Wrap(err, "error retrieving login form results")
    97  	}
    98  
    99  	switch ac.idpAccount.MFA {
   100  	case "VIP":
   101  		res, err = ac.vipMFA(authSubmitURL, loginDetails.MFAToken, res)
   102  		if err != nil {
   103  			return samlAssertion, errors.Wrap(err, "error retrieving mfa form results")
   104  		}
   105  	}
   106  
   107  	// just parse the response whether res is from the login form or MFA form
   108  	doc, err = goquery.NewDocumentFromResponse(res)
   109  	if err != nil {
   110  		return samlAssertion, errors.Wrap(err, "error retrieving login response body")
   111  	}
   112  
   113  	doc.Find("input").Each(func(i int, s *goquery.Selection) {
   114  		name, ok := s.Attr("name")
   115  		if !ok {
   116  			log.Fatalf("unable to locate IDP authentication form submit URL")
   117  		}
   118  		if name == "SAMLResponse" {
   119  			val, ok := s.Attr("value")
   120  			if !ok {
   121  				log.Fatalf("unable to locate saml assertion value")
   122  			}
   123  			samlAssertion = val
   124  		}
   125  	})
   126  
   127  	return samlAssertion, nil
   128  }
   129  
   130  // vipMFA when supplied with the the form response document attempt to extract the VIP mfa related field
   131  // then use that to trigger a submit of the MFA security token
   132  func (ac *Client) vipMFA(authSubmitURL string, mfaToken string, res *http.Response) (*http.Response, error) {
   133  
   134  	doc, err := goquery.NewDocumentFromResponse(res)
   135  	if err != nil {
   136  		return nil, errors.Wrap(err, "error retrieving saml response body")
   137  	}
   138  
   139  	otpForm := url.Values{}
   140  
   141  	vipIndex := doc.Find("input#authMethod[value=VIPAuthenticationProviderWindowsAccountName]").Index()
   142  
   143  	if vipIndex == -1 {
   144  		return res, nil // if we didn't find the MFA flag then just continue
   145  	}
   146  
   147  	if mfaToken == "" {
   148  		mfaToken = prompter.RequestSecurityCode("000000")
   149  	}
   150  
   151  	doc.Find("input").Each(func(i int, s *goquery.Selection) {
   152  		updateOTPFormData(otpForm, s, mfaToken)
   153  	})
   154  
   155  	doc.Find("form").Each(func(i int, s *goquery.Selection) {
   156  		action, ok := s.Attr("action")
   157  		if !ok {
   158  			return
   159  		}
   160  		authSubmitURL = action
   161  	})
   162  
   163  	if authSubmitURL == "" {
   164  		return nil, fmt.Errorf("unable to locate IDP MFA form submit URL")
   165  	}
   166  
   167  	req, err := http.NewRequest("POST", authSubmitURL, strings.NewReader(otpForm.Encode()))
   168  	if err != nil {
   169  		return nil, errors.Wrap(err, "error building MFA request")
   170  	}
   171  
   172  	req.Header.Add("Content-Type", "application/x-www-form-urlencoded")
   173  
   174  	res, err = ac.client.Do(req)
   175  	if err != nil {
   176  		return nil, errors.Wrap(err, "error retrieving content")
   177  	}
   178  
   179  	return res, nil
   180  }
   181  
   182  func updateFormData(authForm url.Values, s *goquery.Selection, user *creds.LoginDetails) {
   183  	name, ok := s.Attr("name")
   184  	//	log.Printf("name = %s ok = %v", name, ok)
   185  	if !ok {
   186  		return
   187  	}
   188  
   189  	typeValue, typeFound := s.Attr("type")
   190  	hiddenAttr := typeFound && typeValue == "hidden"
   191  
   192  	lname := strings.ToLower(name)
   193  	if strings.Contains(lname, "user") {
   194  		if !hiddenAttr {
   195  			authForm.Add(name, user.Username)
   196  		}
   197  	} else if strings.Contains(lname, "email") {
   198  		if !hiddenAttr {
   199  			authForm.Add(name, user.Username)
   200  		}
   201  	} else if strings.Contains(lname, "pass") {
   202  		if !hiddenAttr {
   203  			authForm.Add(name, user.Password)
   204  		}
   205  	} else {
   206  		// pass through any hidden fields
   207  		val, ok := s.Attr("value")
   208  		if !ok {
   209  			return
   210  		}
   211  		authForm.Add(name, val)
   212  	}
   213  }
   214  
   215  func updateOTPFormData(otpForm url.Values, s *goquery.Selection, token string) {
   216  	name, ok := s.Attr("name")
   217  	//	log.Printf("name = %s ok = %v", name, ok)
   218  	if !ok {
   219  		return
   220  	}
   221  	lname := strings.ToLower(name)
   222  	if strings.Contains(lname, "security_code") {
   223  		otpForm.Add(name, token)
   224  	} else {
   225  		// pass through any hidden fields
   226  		val, ok := s.Attr("value")
   227  		if !ok {
   228  			return
   229  		}
   230  		otpForm.Add(name, val)
   231  	}
   232  
   233  }