github.com/versent/saml2aws@v2.17.0+incompatible/pkg/provider/pingfed/pingfed.go (about)

     1  package pingfed
     2  
import (
	"context"
	"encoding/base64"
	"fmt"
	"io/ioutil"
	"net/http"
	"net/url"
	"strings"
	"time"

	"github.com/PuerkitoBio/goquery"
	"github.com/pkg/errors"
	"github.com/sirupsen/logrus"
	"github.com/tidwall/gjson"
	"github.com/versent/saml2aws/pkg/cfg"
	"github.com/versent/saml2aws/pkg/creds"
	"github.com/versent/saml2aws/pkg/page"
	"github.com/versent/saml2aws/pkg/prompter"
	"github.com/versent/saml2aws/pkg/provider"
)
    22  
    23  var logger = logrus.WithField("provider", "pingfed")
    24  
    25  // Client wrapper around PingFed + PingId enabling authentication and retrieval of assertions
    26  type Client struct {
    27  	client     *provider.HTTPClient
    28  	idpAccount *cfg.IDPAccount
    29  }
    30  
    31  // New create a new PingFed client
    32  func New(idpAccount *cfg.IDPAccount) (*Client, error) {
    33  
    34  	tr := provider.NewDefaultTransport(idpAccount.SkipVerify)
    35  
    36  	client, err := provider.NewHTTPClient(tr)
    37  	if err != nil {
    38  		return nil, errors.Wrap(err, "error building http client")
    39  	}
    40  
    41  	// assign a response validator to ensure all responses are either success or a redirect
    42  	// this is to avoid have explicit checks for every single response
    43  	client.CheckResponseStatus = provider.SuccessOrRedirectResponseValidator
    44  
    45  	return &Client{
    46  		client:     client,
    47  		idpAccount: idpAccount,
    48  	}, nil
    49  }
    50  
    51  type ctxKey string
    52  
    53  // Authenticate Authenticate to PingFed and return the data from the body of the SAML assertion.
    54  func (ac *Client) Authenticate(loginDetails *creds.LoginDetails) (string, error) {
    55  	url := fmt.Sprintf("%s/idp/startSSO.ping?PartnerSpId=%s", loginDetails.URL, ac.idpAccount.AmazonWebservicesURN)
    56  	req, err := http.NewRequest("GET", url, nil)
    57  	if err != nil {
    58  		return "", errors.Wrap(err, "error building request")
    59  	}
    60  	ctx := context.WithValue(context.Background(), ctxKey("login"), loginDetails)
    61  	return ac.follow(ctx, req)
    62  }
    63  
    64  func (ac *Client) follow(ctx context.Context, req *http.Request) (string, error) {
    65  	res, err := ac.client.Do(req)
    66  	if err != nil {
    67  		return "", errors.Wrap(err, "error following")
    68  	}
    69  	doc, err := goquery.NewDocumentFromResponse(res)
    70  	if err != nil {
    71  		return "", errors.Wrap(err, "failed to build document from response")
    72  	}
    73  
    74  	var handler func(context.Context, *goquery.Document) (context.Context, *http.Request, error)
    75  	
    76  	if docIsFormRedirectToAWS(doc) {
    77  		logger.WithField("type", "saml-response-to-aws").Debug("doc detect")
    78  		if samlResponse, ok := extractSAMLResponse(doc); ok {
    79  			decodedSamlResponse, err := base64.StdEncoding.DecodeString(samlResponse)
    80  			if err != nil {
    81  				return "", errors.Wrap(err, "failed to decode saml-response")
    82  			}
    83  			logger.WithField("type", "saml-response").WithField("saml-response", string(decodedSamlResponse)).Debug("doc detect")
    84  			return samlResponse, nil
    85  		}
    86  	} else if docIsFormSamlRequest(doc) {
    87  		logger.WithField("type", "saml-request").Debug("doc detect")
    88  		handler = ac.handleFormRedirect
    89  	} else if docIsFormResume(doc) {
    90  		logger.WithField("type", "resume").Debug("doc detect")
    91  		handler = ac.handleFormRedirect
    92  	} else if docIsFormSamlResponse(doc) {
    93  		logger.WithField("type", "saml-response").Debug("doc detect")
    94  		handler = ac.handleFormRedirect
    95  	} else if docIsLogin(doc) {
    96  		logger.WithField("type", "login").Debug("doc detect")
    97  		handler = ac.handleLogin
    98  	} else if docIsOTP(doc) {
    99  		logger.WithField("type", "otp").Debug("doc detect")
   100  		handler = ac.handleOTP
   101  	} else if docIsSwipe(doc) {
   102  		logger.WithField("type", "swipe").Debug("doc detect")
   103  		handler = ac.handleSwipe
   104  	} else if docIsFormRedirect(doc) {
   105  		logger.WithField("type", "form-redirect").Debug("doc detect")
   106  		handler = ac.handleFormRedirect
   107  	} else if docIsWebAuthn(doc) {
   108  		logger.WithField("type", "webauthn").Debug("doc detect")
   109  		handler = ac.handleWebAuthn
   110  	}
   111  	if handler == nil {
   112  		html, _ := doc.Selection.Html()
   113  		logger.WithField("doc", html).Debug("Unknown document type")
   114  		return "", fmt.Errorf("Unknown document type")
   115  	}
   116  
   117  	ctx, req, err = handler(ctx, doc)
   118  	if err != nil {
   119  		return "", err
   120  	}
   121  	return ac.follow(ctx, req)
   122  }
   123  
   124  func (ac *Client) handleLogin(ctx context.Context, doc *goquery.Document) (context.Context, *http.Request, error) {
   125  	loginDetails, ok := ctx.Value(ctxKey("login")).(*creds.LoginDetails)
   126  	if !ok {
   127  		return ctx, nil, fmt.Errorf("no context value for 'login'")
   128  	}
   129  
   130  	form, err := page.NewFormFromDocument(doc, "form")
   131  	if err != nil {
   132  		return ctx, nil, errors.Wrap(err, "error extracting login form")
   133  	}
   134  
   135  	form.Values.Set("pf.username", loginDetails.Username)
   136  	form.Values.Set("pf.pass", loginDetails.Password)
   137  	form.URL = makeAbsoluteURL(form.URL, loginDetails.URL)
   138  
   139  	req, err := form.BuildRequest()
   140  	return ctx, req, err
   141  }
   142  
   143  func (ac *Client) handleOTP(ctx context.Context, doc *goquery.Document) (context.Context, *http.Request, error) {
   144  	form, err := page.NewFormFromDocument(doc, "#otp-form")
   145  	if err != nil {
   146  		return ctx, nil, errors.Wrap(err, "error extracting OTP form")
   147  	}
   148  
   149  	token := prompter.StringRequired("Enter passcode")
   150  	form.Values.Set("otp", token)
   151  	req, err := form.BuildRequest()
   152  	return ctx, req, err
   153  }
   154  
   155  func (ac *Client) handleSwipe(ctx context.Context, doc *goquery.Document) (context.Context, *http.Request, error) {
   156  	form, err := page.NewFormFromDocument(doc, "#form1")
   157  	if err != nil {
   158  		return ctx, nil, errors.Wrap(err, "error extracting swipe status form")
   159  	}
   160  
   161  	// poll status. request must specifically be a GET
   162  	form.Method = "GET"
   163  	req, err := form.BuildRequest()
   164  	if err != nil {
   165  		return ctx, nil, err
   166  	}
   167  
   168  	for {
   169  		time.Sleep(3 * time.Second)
   170  
   171  		res, err := ac.client.Do(req)
   172  		if err != nil {
   173  			return ctx, nil, errors.Wrap(err, "error polling swipe status")
   174  		}
   175  
   176  		body, err := ioutil.ReadAll(res.Body)
   177  		if err != nil {
   178  			return ctx, nil, errors.Wrap(err, "error parsing body from swipe status response")
   179  		}
   180  
   181  		resp := string(body)
   182  
   183  		pingfedMFAStatusResponse := gjson.Get(resp, "status").String()
   184  
   185  		//ASYNC_AUTH_WAIT indicates we keep going
   186  		//OK indicates someone swiped
   187  		//DEVICE_CLAIM_TIMEOUT indicates nobody swiped
   188  		//otherwise loop forever?
   189  
   190  		if pingfedMFAStatusResponse == "OK" || pingfedMFAStatusResponse == "DEVICE_CLAIM_TIMEOUT" || pingfedMFAStatusResponse == "TIMEOUT" {
   191  			break
   192  		}
   193  	}
   194  
   195  	// now build a request for getting response of MFA
   196  	form, err = page.NewFormFromDocument(doc, "#reponseView")
   197  	if err != nil {
   198  		return ctx, nil, errors.Wrap(err, "error extracting swipe response form")
   199  	}
   200  	req, err = form.BuildRequest()
   201  	return ctx, req, err
   202  }
   203  
   204  func (ac *Client) handleFormRedirect(ctx context.Context, doc *goquery.Document) (context.Context, *http.Request, error) {
   205  	form, err := page.NewFormFromDocument(doc, "")
   206  	if err != nil {
   207  		return ctx, nil, errors.Wrap(err, "error extracting redirect form")
   208  	}
   209  	req, err := form.BuildRequest()
   210  	return ctx, req, err
   211  }
   212  
   213  func (ac *Client) handleWebAuthn(ctx context.Context, doc *goquery.Document) (context.Context, *http.Request, error) {
   214  	form, err := page.NewFormFromDocument(doc, "")
   215  	if err != nil {
   216  		return ctx, nil, errors.Wrap(err, "error extracting webauthn form")
   217  	}
   218  	form.Values.Set("isWebAuthnSupportedByBrowser", "false")
   219  	req, err := form.BuildRequest()
   220  	return ctx, req, err
   221  }
   222  
   223  func docIsLogin(doc *goquery.Document) bool {
   224  	return doc.Has("input[name=\"pf.pass\"]").Size() == 1
   225  }
   226  
   227  func docIsOTP(doc *goquery.Document) bool {
   228  	return doc.Has("form#otp-form").Size() == 1
   229  }
   230  
   231  func docIsSwipe(doc *goquery.Document) bool {
   232  	return doc.Has("form#form1").Size() == 1 && doc.Has("form#reponseView").Size() == 1
   233  }
   234  
   235  func docIsFormRedirect(doc *goquery.Document) bool {
   236  	return doc.Has("input[name=\"ppm_request\"]").Size() == 1
   237  }
   238  
   239  func docIsWebAuthn(doc *goquery.Document) bool {
   240  	return doc.Has("input[name=\"isWebAuthnSupportedByBrowser\"]").Size() == 1
   241  }
   242  
   243  func docIsFormSamlRequest(doc *goquery.Document) bool {
   244  	return doc.Find("input[name=\"SAMLRequest\"]").Size() == 1
   245  }
   246  
   247  func docIsFormSamlResponse(doc *goquery.Document) bool {
   248  	return doc.Find("input[name=\"SAMLResponse\"]").Size() == 1
   249  }
   250  
   251  func docIsFormResume(doc *goquery.Document) bool {
   252  	return doc.Find("input[name=\"RelayState\"]").Size() == 1
   253  }
   254  
   255  func docIsFormRedirectToAWS(doc *goquery.Document) bool {
   256  	return doc.Find("form[action=\"https://signin.aws.amazon.com/saml\"]").Size() == 1
   257  }
   258  
   259  func extractSAMLResponse(doc *goquery.Document) (v string, ok bool) {
   260  	return doc.Find("input[name=\"SAMLResponse\"]").Attr("value")
   261  }
   262  
   263  // ensures given url is an absolute URL. if not, it will be combined with the base URL
   264  func makeAbsoluteURL(v string, base string) string {
   265  	if u, err := url.ParseRequestURI(v); err == nil && !u.IsAbs() {
   266  		return fmt.Sprintf("%s%s", base, v)
   267  	}
   268  	return v
   269  }