github.com/snowflakedb/gosnowflake@v1.9.0/authokta.go (about)

     1  // Copyright (c) 2017-2022 Snowflake Computing Inc. All rights reserved.
     2  
     3  package gosnowflake
     4  
     5  import (
     6  	"bytes"
     7  	"context"
     8  	"encoding/json"
     9  	"fmt"
    10  	"html"
    11  	"io"
    12  	"net/http"
    13  	"net/url"
    14  	"strconv"
    15  	"time"
    16  )
    17  
    18  type authOKTARequest struct {
    19  	Username string `json:"username"`
    20  	Password string `json:"password"`
    21  }
    22  
    23  type authOKTAResponse struct {
    24  	CookieToken  string `json:"cookieToken"`
    25  	SessionToken string `json:"sessionToken"`
    26  }
    27  
    28  /*
    29  authenticateBySAML authenticates a user by SAML
    30  SAML Authentication
    31   1. query GS to obtain IDP token and SSO url
    32   2. IMPORTANT Client side validation:
    33      validate both token url and sso url contains same prefix
    34      (protocol + host + port) as the given authenticator url.
    35      Explanation:
    36      This provides a way for the user to 'authenticate' the IDP it is
    37      sending his/her credentials to.  Without such a check, the user could
    38      be coerced to provide credentials to an IDP impersonator.
    39   3. query IDP token url to authenticate and retrieve access token
    40   4. given access token, query IDP URL snowflake app to get SAML response
    41   5. IMPORTANT Client side validation:
    42      validate the post back url come back with the SAML response
    43      contains the same prefix as the Snowflake's server url, which is the
    44      intended destination url to Snowflake.
    45  
    46  Explanation:
    47  
    48  	This emulates the behavior of IDP initiated login flow in the user
    49  	browser where the IDP instructs the browser to POST the SAML
    50  	assertion to the specific SP endpoint.  This is critical in
    51  	preventing a SAML assertion issued to one SP from being sent to
    52  	another SP.
    53  */
    54  func authenticateBySAML(
    55  	ctx context.Context,
    56  	sr *snowflakeRestful,
    57  	oktaURL *url.URL,
    58  	application string,
    59  	account string,
    60  	user string,
    61  	password string,
    62  ) (samlResponse []byte, err error) {
    63  	logger.WithContext(ctx).Info("step 1: query GS to obtain IDP token and SSO url")
    64  	headers := make(map[string]string)
    65  	headers[httpHeaderContentType] = headerContentTypeApplicationJSON
    66  	headers[httpHeaderAccept] = headerContentTypeApplicationJSON
    67  	headers[httpHeaderUserAgent] = userAgent
    68  
    69  	clientEnvironment := authRequestClientEnvironment{
    70  		Application: application,
    71  		Os:          operatingSystem,
    72  		OsVersion:   platform,
    73  	}
    74  	requestMain := authRequestData{
    75  		ClientAppID:       clientType,
    76  		ClientAppVersion:  SnowflakeGoDriverVersion,
    77  		AccountName:       account,
    78  		ClientEnvironment: clientEnvironment,
    79  		Authenticator:     oktaURL.String(),
    80  	}
    81  	authRequest := authRequest{
    82  		Data: requestMain,
    83  	}
    84  	params := &url.Values{}
    85  	jsonBody, err := json.Marshal(authRequest)
    86  	if err != nil {
    87  		return nil, err
    88  	}
    89  	logger.WithContext(ctx).Infof("PARAMS for Auth: %v, %v", params, sr)
    90  	respd, err := sr.FuncPostAuthSAML(ctx, sr, headers, jsonBody, sr.LoginTimeout)
    91  	if err != nil {
    92  		return nil, err
    93  	}
    94  	if !respd.Success {
    95  		logger.Errorln("Authentication FAILED")
    96  		sr.TokenAccessor.SetTokens("", "", -1)
    97  		code, err := strconv.Atoi(respd.Code)
    98  		if err != nil {
    99  			code = -1
   100  			return nil, err
   101  		}
   102  		return nil, &SnowflakeError{
   103  			Number:   code,
   104  			SQLState: SQLStateConnectionRejected,
   105  			Message:  respd.Message,
   106  		}
   107  	}
   108  	logger.WithContext(ctx).Info("step 2: validate Token and SSO URL has the same prefix as oktaURL")
   109  	var tokenURL *url.URL
   110  	var ssoURL *url.URL
   111  	if tokenURL, err = url.Parse(respd.Data.TokenURL); err != nil {
   112  		return nil, fmt.Errorf("failed to parse token URL. %v", respd.Data.TokenURL)
   113  	}
   114  	if ssoURL, err = url.Parse(respd.Data.SSOURL); err != nil {
   115  		return nil, fmt.Errorf("failed to parse SSO URL. %v", respd.Data.SSOURL)
   116  	}
   117  	if !isPrefixEqual(oktaURL, ssoURL) || !isPrefixEqual(oktaURL, tokenURL) {
   118  		return nil, &SnowflakeError{
   119  			Number:      ErrCodeIdpConnectionError,
   120  			SQLState:    SQLStateConnectionRejected,
   121  			Message:     errMsgIdpConnectionError,
   122  			MessageArgs: []interface{}{oktaURL, respd.Data.TokenURL, respd.Data.SSOURL},
   123  		}
   124  	}
   125  	logger.WithContext(ctx).Info("step 3: query IDP token url to authenticate and retrieve access token")
   126  	jsonBody, err = json.Marshal(authOKTARequest{
   127  		Username: user,
   128  		Password: password,
   129  	})
   130  	if err != nil {
   131  		return nil, err
   132  	}
   133  	respa, err := sr.FuncPostAuthOKTA(ctx, sr, headers, jsonBody, respd.Data.TokenURL, sr.LoginTimeout)
   134  	if err != nil {
   135  		return nil, err
   136  	}
   137  
   138  	logger.WithContext(ctx).Info("step 4: query IDP URL snowflake app to get SAML response")
   139  	params = &url.Values{}
   140  	params.Add("RelayState", "/some/deep/link")
   141  	var oneTimeToken string
   142  	if respa.SessionToken != "" {
   143  		oneTimeToken = respa.SessionToken
   144  	} else {
   145  		oneTimeToken = respa.CookieToken
   146  	}
   147  	params.Add("onetimetoken", oneTimeToken)
   148  
   149  	headers = make(map[string]string)
   150  	headers[httpHeaderAccept] = "*/*"
   151  	bd, err := sr.FuncGetSSO(ctx, sr, params, headers, respd.Data.SSOURL, sr.LoginTimeout)
   152  	if err != nil {
   153  		return nil, err
   154  	}
   155  	logger.WithContext(ctx).Info("step 5: validate post_back_url matches Snowflake URL")
   156  	tgtURL, err := postBackURL(bd)
   157  	if err != nil {
   158  		return nil, err
   159  	}
   160  
   161  	fullURL := sr.getURL()
   162  	logger.WithContext(ctx).Infof("tgtURL: %v, origURL: %v", tgtURL, fullURL)
   163  	if !isPrefixEqual(tgtURL, fullURL) {
   164  		return nil, &SnowflakeError{
   165  			Number:      ErrCodeSSOURLNotMatch,
   166  			SQLState:    SQLStateConnectionRejected,
   167  			Message:     errMsgSSOURLNotMatch,
   168  			MessageArgs: []interface{}{tgtURL, fullURL},
   169  		}
   170  	}
   171  	return bd, nil
   172  }
   173  
   174  func postBackURL(htmlData []byte) (url *url.URL, err error) {
   175  	idx0 := bytes.Index(htmlData, []byte("<form"))
   176  	if idx0 < 0 {
   177  		return nil, fmt.Errorf("failed to find a form tag in HTML response: %v", htmlData)
   178  	}
   179  	idx := bytes.Index(htmlData[idx0:], []byte("action=\""))
   180  	if idx < 0 {
   181  		return nil, fmt.Errorf("failed to find action field in HTML response: %v", htmlData[idx0:])
   182  	}
   183  	idx += idx0
   184  	endIdx := bytes.Index(htmlData[idx+8:], []byte("\""))
   185  	if endIdx < 0 {
   186  		return nil, fmt.Errorf("failed to find the end of action field: %v", htmlData[idx+8:])
   187  	}
   188  	r := html.UnescapeString(string(htmlData[idx+8 : idx+8+endIdx]))
   189  	return url.Parse(r)
   190  }
   191  
   192  func isPrefixEqual(u1 *url.URL, u2 *url.URL) bool {
   193  	p1 := u1.Port()
   194  	if p1 == "" && u1.Scheme == "https" {
   195  		p1 = "443"
   196  	}
   197  	p2 := u1.Port()
   198  	if p2 == "" && u1.Scheme == "https" {
   199  		p2 = "443"
   200  	}
   201  	return u1.Hostname() == u2.Hostname() && p1 == p2 && u1.Scheme == u2.Scheme
   202  }
   203  
   204  // Makes a request to /session/authenticator-request to get SAML Information,
   205  // such as the IDP Url and Proof Key, depending on the authenticator
   206  func postAuthSAML(
   207  	ctx context.Context,
   208  	sr *snowflakeRestful,
   209  	headers map[string]string,
   210  	body []byte,
   211  	timeout time.Duration) (
   212  	data *authResponse, err error) {
   213  
   214  	params := &url.Values{}
   215  	params.Add(requestIDKey, getOrGenerateRequestIDFromContext(ctx).String())
   216  	fullURL := sr.getFullURL(authenticatorRequestPath, params)
   217  
   218  	logger.Infof("fullURL: %v", fullURL)
   219  	resp, err := sr.FuncPost(ctx, sr, fullURL, headers, body, timeout, defaultTimeProvider, nil)
   220  	if err != nil {
   221  		return nil, err
   222  	}
   223  	defer resp.Body.Close()
   224  	if resp.StatusCode == http.StatusOK {
   225  		var respd authResponse
   226  		err = json.NewDecoder(resp.Body).Decode(&respd)
   227  		if err != nil {
   228  			logger.WithContext(ctx).Errorf("failed to decode JSON. err: %v", err)
   229  			return nil, err
   230  		}
   231  		return &respd, nil
   232  	}
   233  	switch resp.StatusCode {
   234  	case http.StatusBadGateway, http.StatusServiceUnavailable, http.StatusGatewayTimeout:
   235  		// service availability or connectivity issue. Most likely server side issue.
   236  		return nil, &SnowflakeError{
   237  			Number:      ErrCodeServiceUnavailable,
   238  			SQLState:    SQLStateConnectionWasNotEstablished,
   239  			Message:     errMsgServiceUnavailable,
   240  			MessageArgs: []interface{}{resp.StatusCode, fullURL},
   241  		}
   242  	case http.StatusUnauthorized, http.StatusForbidden:
   243  		// failed to connect to db. account name may be wrong
   244  		return nil, &SnowflakeError{
   245  			Number:      ErrCodeFailedToConnect,
   246  			SQLState:    SQLStateConnectionRejected,
   247  			Message:     errMsgFailedToConnect,
   248  			MessageArgs: []interface{}{resp.StatusCode, fullURL},
   249  		}
   250  	}
   251  	_, err = io.ReadAll(resp.Body)
   252  	if err != nil {
   253  		logger.WithContext(ctx).Errorf("failed to extract HTTP response body. err: %v", err)
   254  		return nil, err
   255  	}
   256  	return nil, &SnowflakeError{
   257  		Number:      ErrFailedToAuthSAML,
   258  		SQLState:    SQLStateConnectionRejected,
   259  		Message:     errMsgFailedToAuthSAML,
   260  		MessageArgs: []interface{}{resp.StatusCode, fullURL},
   261  	}
   262  }
   263  
   264  func postAuthOKTA(
   265  	ctx context.Context,
   266  	sr *snowflakeRestful,
   267  	headers map[string]string,
   268  	body []byte,
   269  	fullURL string,
   270  	timeout time.Duration) (
   271  	data *authOKTAResponse, err error) {
   272  	logger.Infof("fullURL: %v", fullURL)
   273  	targetURL, err := url.Parse(fullURL)
   274  	if err != nil {
   275  		return nil, err
   276  	}
   277  	resp, err := sr.FuncPost(ctx, sr, targetURL, headers, body, timeout, defaultTimeProvider, nil)
   278  	if err != nil {
   279  		return nil, err
   280  	}
   281  	defer resp.Body.Close()
   282  	if resp.StatusCode == http.StatusOK {
   283  		var respd authOKTAResponse
   284  		err = json.NewDecoder(resp.Body).Decode(&respd)
   285  		if err != nil {
   286  			logger.WithContext(ctx).Errorf("failed to decode JSON. err: %v", err)
   287  			return nil, err
   288  		}
   289  		return &respd, nil
   290  	}
   291  	_, err = io.ReadAll(resp.Body)
   292  	if err != nil {
   293  		logger.Errorf("failed to extract HTTP response body. err: %v", err)
   294  		return nil, err
   295  	}
   296  	logger.WithContext(ctx).Infof("HTTP: %v, URL: %v", resp.StatusCode, fullURL)
   297  	logger.WithContext(ctx).Infof("Header: %v", resp.Header)
   298  	return nil, &SnowflakeError{
   299  		Number:      ErrFailedToAuthOKTA,
   300  		SQLState:    SQLStateConnectionRejected,
   301  		Message:     errMsgFailedToAuthOKTA,
   302  		MessageArgs: []interface{}{resp.StatusCode, fullURL},
   303  	}
   304  }
   305  
   306  func getSSO(
   307  	ctx context.Context,
   308  	sr *snowflakeRestful,
   309  	params *url.Values,
   310  	headers map[string]string,
   311  	ssoURL string,
   312  	timeout time.Duration) (
   313  	bd []byte, err error) {
   314  	fullURL, err := url.Parse(ssoURL)
   315  	if err != nil {
   316  		return nil, err
   317  	}
   318  	fullURL.RawQuery = params.Encode()
   319  	logger.WithContext(ctx).Infof("fullURL: %v", fullURL)
   320  	resp, err := sr.FuncGet(ctx, sr, fullURL, headers, timeout)
   321  	if err != nil {
   322  		return nil, err
   323  	}
   324  	defer resp.Body.Close()
   325  	b, err := io.ReadAll(resp.Body)
   326  	if err != nil {
   327  		logger.WithContext(ctx).Errorf("failed to extract HTTP response body. err: %v", err)
   328  		return nil, err
   329  	}
   330  	if resp.StatusCode == http.StatusOK {
   331  		return b, nil
   332  	}
   333  	logger.WithContext(ctx).Infof("HTTP: %v, URL: %v ", resp.StatusCode, fullURL)
   334  	logger.WithContext(ctx).Infof("Header: %v", resp.Header)
   335  	return nil, &SnowflakeError{
   336  		Number:      ErrFailedToGetSSO,
   337  		SQLState:    SQLStateConnectionRejected,
   338  		Message:     errMsgFailedToGetSSO,
   339  		MessageArgs: []interface{}{resp.StatusCode, fullURL},
   340  	}
   341  }