github.com/versent/saml2aws@v2.17.0+incompatible/pkg/provider/pingone/pingone.go (about)

     1  package pingone
     2  
     3  import (
     4  	"context"
     5  	"fmt"
     6  	"io/ioutil"
     7  	"net/http"
     8  	"net/url"
     9  	"time"
    10  	"encoding/base64"
    11  
    12  	"github.com/PuerkitoBio/goquery"
    13  	"github.com/pkg/errors"
    14  	"github.com/sirupsen/logrus"
    15  	"github.com/tidwall/gjson"
    16  	"github.com/versent/saml2aws/pkg/cfg"
    17  	"github.com/versent/saml2aws/pkg/creds"
    18  	"github.com/versent/saml2aws/pkg/page"
    19  	"github.com/versent/saml2aws/pkg/prompter"
    20  	"github.com/versent/saml2aws/pkg/provider"
    21  )
    22  
    23  var logger = logrus.WithField("provider", "pingone")
    24  
// Client wrapper around PingOne + PingId enabling authentication and retrieval of assertions.
// Construct it with New.
type Client struct {
	// client issues every HTTP request of the login flow; New installs a
	// response validator so only success or redirect responses are accepted.
	client     *provider.HTTPClient
	// idpAccount is the IdP account configuration that was passed to New.
	idpAccount *cfg.IDPAccount
}
    30  
    31  // New create a new PingOne client
    32  func New(idpAccount *cfg.IDPAccount) (*Client, error) {
    33  
    34  	tr := provider.NewDefaultTransport(idpAccount.SkipVerify)
    35  
    36  	client, err := provider.NewHTTPClient(tr)
    37  	if err != nil {
    38  		return nil, errors.Wrap(err, "error building http client")
    39  	}
    40  
    41  	// assign a response validator to ensure all responses are either success or a redirect
    42  	// this is to avoid have explicit checks for every single response
    43  	client.CheckResponseStatus = provider.SuccessOrRedirectResponseValidator
    44  
    45  	return &Client{
    46  		client:     client,
    47  		idpAccount: idpAccount,
    48  	}, nil
    49  }
    50  
// ctxKey is the type of the keys this package stores in the context that
// follow passes to each page handler (for example ctxKey("login"), which holds
// the *creds.LoginDetails). Using a dedicated unexported type keeps these keys
// from colliding with context keys defined by other packages.
type ctxKey string
    52  
    53  // Authenticate Authenticate to PingOne and return the data from the body of the SAML assertion.
    54  func (ac *Client) Authenticate(loginDetails *creds.LoginDetails) (string, error) {
    55  	req, err := http.NewRequest("GET", loginDetails.URL, nil)
    56  	if err != nil {
    57  		return "", errors.Wrap(err, "error building request")
    58  	}
    59  	ctx := context.WithValue(context.Background(), ctxKey("login"), loginDetails)
    60  	return ac.follow(ctx, req)
    61  }
    62  
    63  func (ac *Client) follow(ctx context.Context, req *http.Request) (string, error) {
    64  	res, err := ac.client.Do(req)
    65  	if err != nil {
    66  		return "", errors.Wrap(err, "error following")
    67  	}
    68  
    69  	doc, err := goquery.NewDocumentFromResponse(res)
    70  	if err != nil {
    71  		return "", errors.Wrap(err, "failed to build document from response")
    72  	}
    73  
    74  	var handler func(context.Context, *goquery.Document, *http.Response) (context.Context, *http.Request, error)
    75  
    76  	if docIsFormRedirectToAWS(doc) {
    77  		logger.WithField("type", "saml-response-to-aws").Debug("doc detect")
    78  		if samlResponse, ok := extractSAMLResponse(doc); ok {
    79  			decodedSamlResponse, err := base64.StdEncoding.DecodeString(samlResponse)
    80  			if err != nil {
    81  				return "", errors.Wrap(err, "failed to decode saml-response")
    82  			}
    83  			logger.WithField("type", "saml-response").WithField("saml-response", string(decodedSamlResponse)).Debug("doc detect")
    84  			return samlResponse, nil
    85  		}
    86  	} else if docIsFormSamlRequest(doc) {
    87  		logger.WithField("type", "saml-request").Debug("doc detect")
    88  		handler = ac.handleFormRedirect
    89  	} else if docIsFormResume(doc) {
    90  		logger.WithField("type", "resume").Debug("doc detect")
    91  		handler = ac.handleFormRedirect
    92  	} else if docIsLogin(doc) {
    93  		logger.WithField("type", "login").Debug("doc detect")
    94  		handler = ac.handleLogin
    95  	} else if docIsCheckWebAuthn(doc) {
    96  		logger.WithField("type", "check-webauthn").Debug("doc detect")
    97  		handler = ac.handleCheckWebAuthn
    98  	} else if docIsFormSelectDevice(doc) {
    99  		logger.WithField("type", "select-device").Debug("doc detect")
   100  		handler = ac.handleFormSelectDevice
   101  	} else if docIsOTP(doc) {
   102  		logger.WithField("type", "otp").Debug("doc detect")
   103  		handler = ac.handleOTP
   104  	} else if docIsSwipe(doc) {
   105  		logger.WithField("type", "swipe").Debug("doc detect")
   106  		handler = ac.handleSwipe
   107  	} else if docIsFormRedirect(doc) {
   108  		logger.WithField("type", "form-redirect").Debug("doc detect")
   109  		handler = ac.handleFormRedirect
   110  	}
   111  	if handler == nil {
   112  		html, _ := doc.Selection.Html()
   113  		logger.WithField("doc", html).Debug("Unknown document type")
   114  		return "", fmt.Errorf("Unknown document type")
   115  	}
   116  
   117  	ctx, req, err = handler(ctx, doc, res)
   118  	if err != nil {
   119  		return "", err
   120  	}
   121  	return ac.follow(ctx, req)
   122  }
   123  
   124  func (ac *Client) handleLogin(ctx context.Context, doc *goquery.Document, res *http.Response) (context.Context, *http.Request, error) {
   125  	loginDetails, ok := ctx.Value(ctxKey("login")).(*creds.LoginDetails)
   126  	if !ok {
   127  		return ctx, nil, fmt.Errorf("no context value for 'login'")
   128  	}
   129  
   130  	form, err := page.NewFormFromDocument(doc, "form")
   131  	if err != nil {
   132  		return ctx, nil, errors.Wrap(err, "error extracting login form")
   133  	}
   134  
   135  	baseURL := makeBaseURL(res.Request.URL)
   136  	logger.WithField("baseURL", baseURL).Debug("base url")
   137  
   138  	form.Values.Set("pf.username", loginDetails.Username)
   139  	form.Values.Set("pf.pass", loginDetails.Password)
   140  	form.URL, err = makeAbsoluteURL(form.URL, baseURL)
   141  	if err != nil {
   142  		return ctx, nil, err
   143  	}
   144  
   145  	req, err := form.BuildRequest()
   146  	return ctx, req, err
   147  }
   148  
   149  func (ac *Client) handleCheckWebAuthn(ctx context.Context, doc *goquery.Document, res *http.Response) (context.Context, *http.Request, error) {
   150  	form, err := page.NewFormFromDocument(doc, "form")
   151  	if err != nil {
   152  		return ctx, nil, errors.Wrap(err, "error extracting login form")
   153  	}
   154  
   155  	form.Values.Set("isWebAuthnSupportedByBrowser", "false")
   156  
   157  	req, err := form.BuildRequest()
   158  	return ctx, req, err
   159  }
   160  
   161  func (ac *Client) handleOTP(ctx context.Context, doc *goquery.Document, _ *http.Response) (context.Context, *http.Request, error) {
   162  	form, err := page.NewFormFromDocument(doc, "#otp-form")
   163  	if err != nil {
   164  		return ctx, nil, errors.Wrap(err, "error extracting OTP form")
   165  	}
   166  
   167  	token := prompter.StringRequired("Enter passcode")
   168  	form.Values.Set("otp", token)
   169  	req, err := form.BuildRequest()
   170  	return ctx, req, err
   171  }
   172  
   173  func (ac *Client) handleSwipe(ctx context.Context, doc *goquery.Document, _ *http.Response) (context.Context, *http.Request, error) {
   174  	form, err := page.NewFormFromDocument(doc, "#form1")
   175  	if err != nil {
   176  		return ctx, nil, errors.Wrap(err, "error extracting swipe status form")
   177  	}
   178  
   179  	// poll status. request must specifically be a GET
   180  	form.Method = "GET"
   181  	req, err := form.BuildRequest()
   182  	if err != nil {
   183  		return ctx, nil, err
   184  	}
   185  
   186  	for {
   187  		time.Sleep(3 * time.Second)
   188  
   189  		res, err := ac.client.Do(req)
   190  		if err != nil {
   191  			return ctx, nil, errors.Wrap(err, "error polling swipe status")
   192  		}
   193  
   194  		body, err := ioutil.ReadAll(res.Body)
   195  		if err != nil {
   196  			return ctx, nil, errors.Wrap(err, "error parsing body from swipe status response")
   197  		}
   198  
   199  		resp := string(body)
   200  
   201  		pingfedMFAStatusResponse := gjson.Get(resp, "status").String()
   202  
   203  		//ASYNC_AUTH_WAIT indicates we keep going
   204  		//OK indicates someone swiped
   205  		//DEVICE_CLAIM_TIMEOUT indicates nobody swiped
   206  		//otherwise loop forever?
   207  
   208  		if pingfedMFAStatusResponse == "OK" || pingfedMFAStatusResponse == "DEVICE_CLAIM_TIMEOUT" || pingfedMFAStatusResponse == "TIMEOUT" {
   209  			break
   210  		}
   211  	}
   212  
   213  	// now build a request for getting response of MFA
   214  	form, err = page.NewFormFromDocument(doc, "#reponseView")
   215  	if err != nil {
   216  		return ctx, nil, errors.Wrap(err, "error extracting swipe response form")
   217  	}
   218  	req, err = form.BuildRequest()
   219  	return ctx, req, err
   220  }
   221  
   222  func (ac *Client) handleFormRedirect(ctx context.Context, doc *goquery.Document, _ *http.Response) (context.Context, *http.Request, error) {
   223  	form, err := page.NewFormFromDocument(doc, "")
   224  	if err != nil {
   225  		return ctx, nil, errors.Wrap(err, "error extracting redirect form")
   226  	}
   227  	req, err := form.BuildRequest()
   228  	return ctx, req, err
   229  }
   230  
   231  func (ac *Client) handleFormSamlRequest(ctx context.Context, doc *goquery.Document, _ *http.Response) (context.Context, *http.Request, error) {
   232  	form, err := page.NewFormFromDocument(doc, "")
   233  	if err != nil {
   234  		return ctx, nil, errors.Wrap(err, "error extracting samlrequest form")
   235  	}
   236  	req, err := form.BuildRequest()
   237  	return ctx, req, err
   238  }
   239  
   240  func (ac *Client) handleFormSelectDevice(ctx context.Context, doc *goquery.Document, res *http.Response) (context.Context, *http.Request, error) {
   241  	deviceList := make(map[string]string)
   242  	var deviceNameList []string
   243  
   244  	doc.Find("ul.device-list > li").Each(func(_ int, s *goquery.Selection) {
   245  		deviceId, _ := s.Attr("data-id")
   246  		deviceName, _ := s.Find("a > div.device-name").Html()
   247  
   248  		logger.WithField("device name", deviceName).WithField("device id", deviceId).Debug("Select Device")
   249  		deviceList[deviceName] = deviceId
   250  		deviceNameList = append(deviceNameList, deviceName)
   251  	})
   252  
   253  	var chooseDevice = prompter.Choose("Select which MFA Device to use", deviceNameList)
   254  
   255  	form, err := page.NewFormFromDocument(doc, "")
   256  	if err != nil {
   257  		return ctx, nil, errors.Wrap(err, "error extracting select device form")
   258  	}
   259  
   260  	form.Values.Set("deviceId", deviceList[deviceNameList[chooseDevice]])
   261  	form.URL, err = makeAbsoluteURL(form.URL, makeBaseURL(res.Request.URL))
   262  	if err != nil {
   263  		return ctx, nil, err
   264  	}
   265  
   266  	logger.WithField("value", form.Values.Encode()).Debug("Select Device")
   267  	req, err := form.BuildRequest()
   268  	return ctx, req, err
   269  }
   270  
   271  func docIsLogin(doc *goquery.Document) bool {
   272  	return doc.Has("input[name=\"pf.pass\"]").Size() == 1
   273  }
   274  
   275  func docIsOTP(doc *goquery.Document) bool {
   276  	return doc.Has("form#otp-form").Size() == 1
   277  }
   278  
   279  func docIsCheckWebAuthn(doc *goquery.Document) bool {
   280  	return doc.Has("input[name=\"isWebAuthnSupportedByBrowser\"]").Size() == 1
   281  }
   282  
   283  func docIsSwipe(doc *goquery.Document) bool {
   284  	return doc.Has("form#form1").Size() == 1 && doc.Has("form#reponseView").Size() == 1
   285  }
   286  
   287  func docIsFormRedirect(doc *goquery.Document) bool {
   288  	return doc.Has("input[name=\"ppm_request\"]").Size() == 1 || doc.Find("form[action=\"https://authenticator.pingone.com/pingid/ppm/auth\"]").Size() == 1
   289  }
   290  
   291  func docIsFormSamlRequest(doc *goquery.Document) bool {
   292  	return doc.Find("input[name=\"SAMLRequest\"]").Size() == 1
   293  }
   294  
   295  func docIsFormResume(doc *goquery.Document) bool {
   296  	return doc.Find("input[name=\"RelayState\"]").Size() == 1 || doc.Find("input[name=\"Resume\"]").Size() == 1
   297  }
   298  
   299  func docIsFormRedirectToAWS(doc *goquery.Document) bool {
   300  	return doc.Find("form[action=\"https://signin.aws.amazon.com/saml\"]").Size() == 1
   301  }
   302  
   303  func docIsFormSelectDevice(doc *goquery.Document) bool {
   304  	return doc.Has("form[name=\"device-form\"]").Size() == 1
   305  }
   306  
   307  func extractSAMLResponse(doc *goquery.Document) (v string, ok bool) {
   308  
   309  	return doc.Find("input[name=\"SAMLResponse\"]").Attr("value")
   310  }
   311  
   312  func makeBaseURL(url *url.URL) string {
   313  	return url.Scheme + "://" + url.Hostname()
   314  }
   315  
   316  // ensures given url is an absolute URL. if not, it will be combined with the base URL
   317  func makeAbsoluteURL(v string, base string) (string, error) {
   318  	logger.WithField("base", base).WithField("v", v).Debug("make absolute url")
   319  	baseURL, err := url.Parse(base)
   320  	if err != nil {
   321  		return "", err
   322  	}
   323  	pathURL, err := url.ParseRequestURI(v)
   324  	if err != nil {
   325  		return "", err
   326  	}
   327  	if pathURL.IsAbs() {
   328  		return pathURL.String(), nil
   329  	}
   330  	return baseURL.ResolveReference(pathURL).String(), nil
   331  }