github.com/versent/saml2aws@v2.17.0+incompatible/pkg/provider/pingfed/pingfed.go (about) 1 package pingfed 2 3 import ( 4 "context" 5 "fmt" 6 "io/ioutil" 7 "net/http" 8 "net/url" 9 "time" 10 "encoding/base64" 11 12 "github.com/PuerkitoBio/goquery" 13 "github.com/pkg/errors" 14 "github.com/sirupsen/logrus" 15 "github.com/tidwall/gjson" 16 "github.com/versent/saml2aws/pkg/cfg" 17 "github.com/versent/saml2aws/pkg/creds" 18 "github.com/versent/saml2aws/pkg/page" 19 "github.com/versent/saml2aws/pkg/prompter" 20 "github.com/versent/saml2aws/pkg/provider" 21 ) 22 23 var logger = logrus.WithField("provider", "pingfed") 24 25 // Client wrapper around PingFed + PingId enabling authentication and retrieval of assertions 26 type Client struct { 27 client *provider.HTTPClient 28 idpAccount *cfg.IDPAccount 29 } 30 31 // New create a new PingFed client 32 func New(idpAccount *cfg.IDPAccount) (*Client, error) { 33 34 tr := provider.NewDefaultTransport(idpAccount.SkipVerify) 35 36 client, err := provider.NewHTTPClient(tr) 37 if err != nil { 38 return nil, errors.Wrap(err, "error building http client") 39 } 40 41 // assign a response validator to ensure all responses are either success or a redirect 42 // this is to avoid have explicit checks for every single response 43 client.CheckResponseStatus = provider.SuccessOrRedirectResponseValidator 44 45 return &Client{ 46 client: client, 47 idpAccount: idpAccount, 48 }, nil 49 } 50 51 type ctxKey string 52 53 // Authenticate Authenticate to PingFed and return the data from the body of the SAML assertion. 54 func (ac *Client) Authenticate(loginDetails *creds.LoginDetails) (string, error) { 55 url := fmt.Sprintf("%s/idp/startSSO.ping?PartnerSpId=%s", loginDetails.URL, ac.idpAccount.AmazonWebservicesURN) 56 req, err := http.NewRequest("GET", url, nil) 57 if err != nil { 58 return "", errors.Wrap(err, "error building request") 59 } 60 ctx := context.WithValue(context.Background(), ctxKey("login"), loginDetails) 61 return ac.follow(ctx, req) 62 } 63 64 func (ac *Client) follow(ctx context.Context, req *http.Request) (string, error) { 65 res, err := ac.client.Do(req) 66 if err != nil { 67 return "", errors.Wrap(err, "error following") 68 } 69 doc, err := goquery.NewDocumentFromResponse(res) 70 if err != nil { 71 return "", errors.Wrap(err, "failed to build document from response") 72 } 73 74 var handler func(context.Context, *goquery.Document) (context.Context, *http.Request, error) 75 76 if docIsFormRedirectToAWS(doc) { 77 logger.WithField("type", "saml-response-to-aws").Debug("doc detect") 78 if samlResponse, ok := extractSAMLResponse(doc); ok { 79 decodedSamlResponse, err := base64.StdEncoding.DecodeString(samlResponse) 80 if err != nil { 81 return "", errors.Wrap(err, "failed to decode saml-response") 82 } 83 logger.WithField("type", "saml-response").WithField("saml-response", string(decodedSamlResponse)).Debug("doc detect") 84 return samlResponse, nil 85 } 86 } else if docIsFormSamlRequest(doc) { 87 logger.WithField("type", "saml-request").Debug("doc detect") 88 handler = ac.handleFormRedirect 89 } else if docIsFormResume(doc) { 90 logger.WithField("type", "resume").Debug("doc detect") 91 handler = ac.handleFormRedirect 92 } else if docIsFormSamlResponse(doc) { 93 logger.WithField("type", "saml-response").Debug("doc detect") 94 handler = ac.handleFormRedirect 95 } else if docIsLogin(doc) { 96 logger.WithField("type", "login").Debug("doc detect") 97 handler = ac.handleLogin 98 } else if docIsOTP(doc) { 99 logger.WithField("type", "otp").Debug("doc detect") 100 handler = ac.handleOTP 101 } else if docIsSwipe(doc) { 102 logger.WithField("type", "swipe").Debug("doc detect") 103 handler = ac.handleSwipe 104 } else if docIsFormRedirect(doc) { 105 logger.WithField("type", "form-redirect").Debug("doc detect") 106 handler = ac.handleFormRedirect 107 } else if docIsWebAuthn(doc) { 108 logger.WithField("type", "webauthn").Debug("doc detect") 109 handler = ac.handleWebAuthn 110 } 111 if handler == nil { 112 html, _ := doc.Selection.Html() 113 logger.WithField("doc", html).Debug("Unknown document type") 114 return "", fmt.Errorf("Unknown document type") 115 } 116 117 ctx, req, err = handler(ctx, doc) 118 if err != nil { 119 return "", err 120 } 121 return ac.follow(ctx, req) 122 } 123 124 func (ac *Client) handleLogin(ctx context.Context, doc *goquery.Document) (context.Context, *http.Request, error) { 125 loginDetails, ok := ctx.Value(ctxKey("login")).(*creds.LoginDetails) 126 if !ok { 127 return ctx, nil, fmt.Errorf("no context value for 'login'") 128 } 129 130 form, err := page.NewFormFromDocument(doc, "form") 131 if err != nil { 132 return ctx, nil, errors.Wrap(err, "error extracting login form") 133 } 134 135 form.Values.Set("pf.username", loginDetails.Username) 136 form.Values.Set("pf.pass", loginDetails.Password) 137 form.URL = makeAbsoluteURL(form.URL, loginDetails.URL) 138 139 req, err := form.BuildRequest() 140 return ctx, req, err 141 } 142 143 func (ac *Client) handleOTP(ctx context.Context, doc *goquery.Document) (context.Context, *http.Request, error) { 144 form, err := page.NewFormFromDocument(doc, "#otp-form") 145 if err != nil { 146 return ctx, nil, errors.Wrap(err, "error extracting OTP form") 147 } 148 149 token := prompter.StringRequired("Enter passcode") 150 form.Values.Set("otp", token) 151 req, err := form.BuildRequest() 152 return ctx, req, err 153 } 154 155 func (ac *Client) handleSwipe(ctx context.Context, doc *goquery.Document) (context.Context, *http.Request, error) { 156 form, err := page.NewFormFromDocument(doc, "#form1") 157 if err != nil { 158 return ctx, nil, errors.Wrap(err, "error extracting swipe status form") 159 } 160 161 // poll status. request must specifically be a GET 162 form.Method = "GET" 163 req, err := form.BuildRequest() 164 if err != nil { 165 return ctx, nil, err 166 } 167 168 for { 169 time.Sleep(3 * time.Second) 170 171 res, err := ac.client.Do(req) 172 if err != nil { 173 return ctx, nil, errors.Wrap(err, "error polling swipe status") 174 } 175 176 body, err := ioutil.ReadAll(res.Body) 177 if err != nil { 178 return ctx, nil, errors.Wrap(err, "error parsing body from swipe status response") 179 } 180 181 resp := string(body) 182 183 pingfedMFAStatusResponse := gjson.Get(resp, "status").String() 184 185 //ASYNC_AUTH_WAIT indicates we keep going 186 //OK indicates someone swiped 187 //DEVICE_CLAIM_TIMEOUT indicates nobody swiped 188 //otherwise loop forever? 189 190 if pingfedMFAStatusResponse == "OK" || pingfedMFAStatusResponse == "DEVICE_CLAIM_TIMEOUT" || pingfedMFAStatusResponse == "TIMEOUT" { 191 break 192 } 193 } 194 195 // now build a request for getting response of MFA 196 form, err = page.NewFormFromDocument(doc, "#reponseView") 197 if err != nil { 198 return ctx, nil, errors.Wrap(err, "error extracting swipe response form") 199 } 200 req, err = form.BuildRequest() 201 return ctx, req, err 202 } 203 204 func (ac *Client) handleFormRedirect(ctx context.Context, doc *goquery.Document) (context.Context, *http.Request, error) { 205 form, err := page.NewFormFromDocument(doc, "") 206 if err != nil { 207 return ctx, nil, errors.Wrap(err, "error extracting redirect form") 208 } 209 req, err := form.BuildRequest() 210 return ctx, req, err 211 } 212 213 func (ac *Client) handleWebAuthn(ctx context.Context, doc *goquery.Document) (context.Context, *http.Request, error) { 214 form, err := page.NewFormFromDocument(doc, "") 215 if err != nil { 216 return ctx, nil, errors.Wrap(err, "error extracting webauthn form") 217 } 218 form.Values.Set("isWebAuthnSupportedByBrowser", "false") 219 req, err := form.BuildRequest() 220 return ctx, req, err 221 } 222 223 func docIsLogin(doc *goquery.Document) bool { 224 return doc.Has("input[name=\"pf.pass\"]").Size() == 1 225 } 226 227 func docIsOTP(doc *goquery.Document) bool { 228 return doc.Has("form#otp-form").Size() == 1 229 } 230 231 func docIsSwipe(doc *goquery.Document) bool { 232 return doc.Has("form#form1").Size() == 1 && doc.Has("form#reponseView").Size() == 1 233 } 234 235 func docIsFormRedirect(doc *goquery.Document) bool { 236 return doc.Has("input[name=\"ppm_request\"]").Size() == 1 237 } 238 239 func docIsWebAuthn(doc *goquery.Document) bool { 240 return doc.Has("input[name=\"isWebAuthnSupportedByBrowser\"]").Size() == 1 241 } 242 243 func docIsFormSamlRequest(doc *goquery.Document) bool { 244 return doc.Find("input[name=\"SAMLRequest\"]").Size() == 1 245 } 246 247 func docIsFormSamlResponse(doc *goquery.Document) bool { 248 return doc.Find("input[name=\"SAMLResponse\"]").Size() == 1 249 } 250 251 func docIsFormResume(doc *goquery.Document) bool { 252 return doc.Find("input[name=\"RelayState\"]").Size() == 1 253 } 254 255 func docIsFormRedirectToAWS(doc *goquery.Document) bool { 256 return doc.Find("form[action=\"https://signin.aws.amazon.com/saml\"]").Size() == 1 257 } 258 259 func extractSAMLResponse(doc *goquery.Document) (v string, ok bool) { 260 return doc.Find("input[name=\"SAMLResponse\"]").Attr("value") 261 } 262 263 // ensures given url is an absolute URL. if not, it will be combined with the base URL 264 func makeAbsoluteURL(v string, base string) string { 265 if u, err := url.ParseRequestURI(v); err == nil && !u.IsAbs() { 266 return fmt.Sprintf("%s%s", base, v) 267 } 268 return v 269 }