github.com/snowflakedb/gosnowflake@v1.9.0/authokta.go (about) 1 // Copyright (c) 2017-2022 Snowflake Computing Inc. All rights reserved. 2 3 package gosnowflake 4 5 import ( 6 "bytes" 7 "context" 8 "encoding/json" 9 "fmt" 10 "html" 11 "io" 12 "net/http" 13 "net/url" 14 "strconv" 15 "time" 16 ) 17 18 type authOKTARequest struct { 19 Username string `json:"username"` 20 Password string `json:"password"` 21 } 22 23 type authOKTAResponse struct { 24 CookieToken string `json:"cookieToken"` 25 SessionToken string `json:"sessionToken"` 26 } 27 28 /* 29 authenticateBySAML authenticates a user by SAML 30 SAML Authentication 31 1. query GS to obtain IDP token and SSO url 32 2. IMPORTANT Client side validation: 33 validate both token url and sso url contains same prefix 34 (protocol + host + port) as the given authenticator url. 35 Explanation: 36 This provides a way for the user to 'authenticate' the IDP it is 37 sending his/her credentials to. Without such a check, the user could 38 be coerced to provide credentials to an IDP impersonator. 39 3. query IDP token url to authenticate and retrieve access token 40 4. given access token, query IDP URL snowflake app to get SAML response 41 5. IMPORTANT Client side validation: 42 validate the post back url come back with the SAML response 43 contains the same prefix as the Snowflake's server url, which is the 44 intended destination url to Snowflake. 45 46 Explanation: 47 48 This emulates the behavior of IDP initiated login flow in the user 49 browser where the IDP instructs the browser to POST the SAML 50 assertion to the specific SP endpoint. This is critical in 51 preventing a SAML assertion issued to one SP from being sent to 52 another SP. 53 */ 54 func authenticateBySAML( 55 ctx context.Context, 56 sr *snowflakeRestful, 57 oktaURL *url.URL, 58 application string, 59 account string, 60 user string, 61 password string, 62 ) (samlResponse []byte, err error) { 63 logger.WithContext(ctx).Info("step 1: query GS to obtain IDP token and SSO url") 64 headers := make(map[string]string) 65 headers[httpHeaderContentType] = headerContentTypeApplicationJSON 66 headers[httpHeaderAccept] = headerContentTypeApplicationJSON 67 headers[httpHeaderUserAgent] = userAgent 68 69 clientEnvironment := authRequestClientEnvironment{ 70 Application: application, 71 Os: operatingSystem, 72 OsVersion: platform, 73 } 74 requestMain := authRequestData{ 75 ClientAppID: clientType, 76 ClientAppVersion: SnowflakeGoDriverVersion, 77 AccountName: account, 78 ClientEnvironment: clientEnvironment, 79 Authenticator: oktaURL.String(), 80 } 81 authRequest := authRequest{ 82 Data: requestMain, 83 } 84 params := &url.Values{} 85 jsonBody, err := json.Marshal(authRequest) 86 if err != nil { 87 return nil, err 88 } 89 logger.WithContext(ctx).Infof("PARAMS for Auth: %v, %v", params, sr) 90 respd, err := sr.FuncPostAuthSAML(ctx, sr, headers, jsonBody, sr.LoginTimeout) 91 if err != nil { 92 return nil, err 93 } 94 if !respd.Success { 95 logger.Errorln("Authentication FAILED") 96 sr.TokenAccessor.SetTokens("", "", -1) 97 code, err := strconv.Atoi(respd.Code) 98 if err != nil { 99 code = -1 100 return nil, err 101 } 102 return nil, &SnowflakeError{ 103 Number: code, 104 SQLState: SQLStateConnectionRejected, 105 Message: respd.Message, 106 } 107 } 108 logger.WithContext(ctx).Info("step 2: validate Token and SSO URL has the same prefix as oktaURL") 109 var tokenURL *url.URL 110 var ssoURL *url.URL 111 if tokenURL, err = url.Parse(respd.Data.TokenURL); err != nil { 112 return nil, fmt.Errorf("failed to parse token URL. %v", respd.Data.TokenURL) 113 } 114 if ssoURL, err = url.Parse(respd.Data.SSOURL); err != nil { 115 return nil, fmt.Errorf("failed to parse SSO URL. %v", respd.Data.SSOURL) 116 } 117 if !isPrefixEqual(oktaURL, ssoURL) || !isPrefixEqual(oktaURL, tokenURL) { 118 return nil, &SnowflakeError{ 119 Number: ErrCodeIdpConnectionError, 120 SQLState: SQLStateConnectionRejected, 121 Message: errMsgIdpConnectionError, 122 MessageArgs: []interface{}{oktaURL, respd.Data.TokenURL, respd.Data.SSOURL}, 123 } 124 } 125 logger.WithContext(ctx).Info("step 3: query IDP token url to authenticate and retrieve access token") 126 jsonBody, err = json.Marshal(authOKTARequest{ 127 Username: user, 128 Password: password, 129 }) 130 if err != nil { 131 return nil, err 132 } 133 respa, err := sr.FuncPostAuthOKTA(ctx, sr, headers, jsonBody, respd.Data.TokenURL, sr.LoginTimeout) 134 if err != nil { 135 return nil, err 136 } 137 138 logger.WithContext(ctx).Info("step 4: query IDP URL snowflake app to get SAML response") 139 params = &url.Values{} 140 params.Add("RelayState", "/some/deep/link") 141 var oneTimeToken string 142 if respa.SessionToken != "" { 143 oneTimeToken = respa.SessionToken 144 } else { 145 oneTimeToken = respa.CookieToken 146 } 147 params.Add("onetimetoken", oneTimeToken) 148 149 headers = make(map[string]string) 150 headers[httpHeaderAccept] = "*/*" 151 bd, err := sr.FuncGetSSO(ctx, sr, params, headers, respd.Data.SSOURL, sr.LoginTimeout) 152 if err != nil { 153 return nil, err 154 } 155 logger.WithContext(ctx).Info("step 5: validate post_back_url matches Snowflake URL") 156 tgtURL, err := postBackURL(bd) 157 if err != nil { 158 return nil, err 159 } 160 161 fullURL := sr.getURL() 162 logger.WithContext(ctx).Infof("tgtURL: %v, origURL: %v", tgtURL, fullURL) 163 if !isPrefixEqual(tgtURL, fullURL) { 164 return nil, &SnowflakeError{ 165 Number: ErrCodeSSOURLNotMatch, 166 SQLState: SQLStateConnectionRejected, 167 Message: errMsgSSOURLNotMatch, 168 MessageArgs: []interface{}{tgtURL, fullURL}, 169 } 170 } 171 return bd, nil 172 } 173 174 func postBackURL(htmlData []byte) (url *url.URL, err error) { 175 idx0 := bytes.Index(htmlData, []byte("<form")) 176 if idx0 < 0 { 177 return nil, fmt.Errorf("failed to find a form tag in HTML response: %v", htmlData) 178 } 179 idx := bytes.Index(htmlData[idx0:], []byte("action=\"")) 180 if idx < 0 { 181 return nil, fmt.Errorf("failed to find action field in HTML response: %v", htmlData[idx0:]) 182 } 183 idx += idx0 184 endIdx := bytes.Index(htmlData[idx+8:], []byte("\"")) 185 if endIdx < 0 { 186 return nil, fmt.Errorf("failed to find the end of action field: %v", htmlData[idx+8:]) 187 } 188 r := html.UnescapeString(string(htmlData[idx+8 : idx+8+endIdx])) 189 return url.Parse(r) 190 } 191 192 func isPrefixEqual(u1 *url.URL, u2 *url.URL) bool { 193 p1 := u1.Port() 194 if p1 == "" && u1.Scheme == "https" { 195 p1 = "443" 196 } 197 p2 := u1.Port() 198 if p2 == "" && u1.Scheme == "https" { 199 p2 = "443" 200 } 201 return u1.Hostname() == u2.Hostname() && p1 == p2 && u1.Scheme == u2.Scheme 202 } 203 204 // Makes a request to /session/authenticator-request to get SAML Information, 205 // such as the IDP Url and Proof Key, depending on the authenticator 206 func postAuthSAML( 207 ctx context.Context, 208 sr *snowflakeRestful, 209 headers map[string]string, 210 body []byte, 211 timeout time.Duration) ( 212 data *authResponse, err error) { 213 214 params := &url.Values{} 215 params.Add(requestIDKey, getOrGenerateRequestIDFromContext(ctx).String()) 216 fullURL := sr.getFullURL(authenticatorRequestPath, params) 217 218 logger.Infof("fullURL: %v", fullURL) 219 resp, err := sr.FuncPost(ctx, sr, fullURL, headers, body, timeout, defaultTimeProvider, nil) 220 if err != nil { 221 return nil, err 222 } 223 defer resp.Body.Close() 224 if resp.StatusCode == http.StatusOK { 225 var respd authResponse 226 err = json.NewDecoder(resp.Body).Decode(&respd) 227 if err != nil { 228 logger.WithContext(ctx).Errorf("failed to decode JSON. err: %v", err) 229 return nil, err 230 } 231 return &respd, nil 232 } 233 switch resp.StatusCode { 234 case http.StatusBadGateway, http.StatusServiceUnavailable, http.StatusGatewayTimeout: 235 // service availability or connectivity issue. Most likely server side issue. 236 return nil, &SnowflakeError{ 237 Number: ErrCodeServiceUnavailable, 238 SQLState: SQLStateConnectionWasNotEstablished, 239 Message: errMsgServiceUnavailable, 240 MessageArgs: []interface{}{resp.StatusCode, fullURL}, 241 } 242 case http.StatusUnauthorized, http.StatusForbidden: 243 // failed to connect to db. account name may be wrong 244 return nil, &SnowflakeError{ 245 Number: ErrCodeFailedToConnect, 246 SQLState: SQLStateConnectionRejected, 247 Message: errMsgFailedToConnect, 248 MessageArgs: []interface{}{resp.StatusCode, fullURL}, 249 } 250 } 251 _, err = io.ReadAll(resp.Body) 252 if err != nil { 253 logger.WithContext(ctx).Errorf("failed to extract HTTP response body. err: %v", err) 254 return nil, err 255 } 256 return nil, &SnowflakeError{ 257 Number: ErrFailedToAuthSAML, 258 SQLState: SQLStateConnectionRejected, 259 Message: errMsgFailedToAuthSAML, 260 MessageArgs: []interface{}{resp.StatusCode, fullURL}, 261 } 262 } 263 264 func postAuthOKTA( 265 ctx context.Context, 266 sr *snowflakeRestful, 267 headers map[string]string, 268 body []byte, 269 fullURL string, 270 timeout time.Duration) ( 271 data *authOKTAResponse, err error) { 272 logger.Infof("fullURL: %v", fullURL) 273 targetURL, err := url.Parse(fullURL) 274 if err != nil { 275 return nil, err 276 } 277 resp, err := sr.FuncPost(ctx, sr, targetURL, headers, body, timeout, defaultTimeProvider, nil) 278 if err != nil { 279 return nil, err 280 } 281 defer resp.Body.Close() 282 if resp.StatusCode == http.StatusOK { 283 var respd authOKTAResponse 284 err = json.NewDecoder(resp.Body).Decode(&respd) 285 if err != nil { 286 logger.WithContext(ctx).Errorf("failed to decode JSON. err: %v", err) 287 return nil, err 288 } 289 return &respd, nil 290 } 291 _, err = io.ReadAll(resp.Body) 292 if err != nil { 293 logger.Errorf("failed to extract HTTP response body. err: %v", err) 294 return nil, err 295 } 296 logger.WithContext(ctx).Infof("HTTP: %v, URL: %v", resp.StatusCode, fullURL) 297 logger.WithContext(ctx).Infof("Header: %v", resp.Header) 298 return nil, &SnowflakeError{ 299 Number: ErrFailedToAuthOKTA, 300 SQLState: SQLStateConnectionRejected, 301 Message: errMsgFailedToAuthOKTA, 302 MessageArgs: []interface{}{resp.StatusCode, fullURL}, 303 } 304 } 305 306 func getSSO( 307 ctx context.Context, 308 sr *snowflakeRestful, 309 params *url.Values, 310 headers map[string]string, 311 ssoURL string, 312 timeout time.Duration) ( 313 bd []byte, err error) { 314 fullURL, err := url.Parse(ssoURL) 315 if err != nil { 316 return nil, err 317 } 318 fullURL.RawQuery = params.Encode() 319 logger.WithContext(ctx).Infof("fullURL: %v", fullURL) 320 resp, err := sr.FuncGet(ctx, sr, fullURL, headers, timeout) 321 if err != nil { 322 return nil, err 323 } 324 defer resp.Body.Close() 325 b, err := io.ReadAll(resp.Body) 326 if err != nil { 327 logger.WithContext(ctx).Errorf("failed to extract HTTP response body. err: %v", err) 328 return nil, err 329 } 330 if resp.StatusCode == http.StatusOK { 331 return b, nil 332 } 333 logger.WithContext(ctx).Infof("HTTP: %v, URL: %v ", resp.StatusCode, fullURL) 334 logger.WithContext(ctx).Infof("Header: %v", resp.Header) 335 return nil, &SnowflakeError{ 336 Number: ErrFailedToGetSSO, 337 SQLState: SQLStateConnectionRejected, 338 Message: errMsgFailedToGetSSO, 339 MessageArgs: []interface{}{resp.StatusCode, fullURL}, 340 } 341 }