github.com/snowflakedb/gosnowflake@v1.9.0/restful.go (about)

     1  // Copyright (c) 2017-2022 Snowflake Computing Inc. All rights reserved.
     2  
     3  package gosnowflake
     4  
import (
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"net/http"
	"net/url"
	"strconv"
	"time"
)
    15  
    16  // HTTP headers
    17  const (
    18  	headerSnowflakeToken   = "Snowflake Token=\"%v\""
    19  	headerAuthorizationKey = "Authorization"
    20  
    21  	headerContentTypeApplicationJSON     = "application/json"
    22  	headerAcceptTypeApplicationSnowflake = "application/snowflake"
    23  )
    24  
    25  // Snowflake Server Error code
    26  const (
    27  	queryInProgressCode      = "333333"
    28  	queryInProgressAsyncCode = "333334"
    29  	sessionExpiredCode       = "390112"
    30  	queryNotExecuting        = "000605"
    31  )
    32  
    33  // Snowflake Server Endpoints
    34  const (
    35  	loginRequestPath         = "/session/v1/login-request"
    36  	queryRequestPath         = "/queries/v1/query-request"
    37  	tokenRequestPath         = "/session/token-request"
    38  	abortRequestPath         = "/queries/v1/abort-request"
    39  	authenticatorRequestPath = "/session/authenticator-request"
    40  	monitoringQueriesPath    = "/monitoring/queries"
    41  	sessionRequestPath       = "/session"
    42  	heartBeatPath            = "/session/heartbeat"
    43  	consoleLoginRequestPath  = "/console/login"
    44  )
    45  
    46  type (
    47  	funcGetType      func(context.Context, *snowflakeRestful, *url.URL, map[string]string, time.Duration) (*http.Response, error)
    48  	funcPostType     func(context.Context, *snowflakeRestful, *url.URL, map[string]string, []byte, time.Duration, currentTimeProvider, *Config) (*http.Response, error)
    49  	funcAuthPostType func(context.Context, *http.Client, *url.URL, map[string]string, bodyCreatorType, time.Duration, int) (*http.Response, error)
    50  	bodyCreatorType  func() ([]byte, error)
    51  )
    52  
    53  var emptyBodyCreator = func() ([]byte, error) {
    54  	return []byte{}, nil
    55  }
    56  
    57  type snowflakeRestful struct {
    58  	Host           string
    59  	Port           int
    60  	Protocol       string
    61  	LoginTimeout   time.Duration // Login timeout
    62  	RequestTimeout time.Duration // request timeout
    63  	MaxRetryCount  int
    64  
    65  	Client        *http.Client
    66  	JWTClient     *http.Client
    67  	TokenAccessor TokenAccessor
    68  	HeartBeat     *heartbeat
    69  
    70  	Connection *snowflakeConn
    71  
    72  	FuncPostQuery       func(context.Context, *snowflakeRestful, *url.Values, map[string]string, []byte, time.Duration, UUID, *Config) (*execResponse, error)
    73  	FuncPostQueryHelper func(context.Context, *snowflakeRestful, *url.Values, map[string]string, []byte, time.Duration, UUID, *Config) (*execResponse, error)
    74  	FuncPost            funcPostType
    75  	FuncGet             funcGetType
    76  	FuncAuthPost        funcAuthPostType
    77  	FuncRenewSession    func(context.Context, *snowflakeRestful, time.Duration) error
    78  	FuncCloseSession    func(context.Context, *snowflakeRestful, time.Duration) error
    79  	FuncCancelQuery     func(context.Context, *snowflakeRestful, UUID, time.Duration) error
    80  
    81  	FuncPostAuth     func(context.Context, *snowflakeRestful, *http.Client, *url.Values, map[string]string, bodyCreatorType, time.Duration) (*authResponse, error)
    82  	FuncPostAuthSAML func(context.Context, *snowflakeRestful, map[string]string, []byte, time.Duration) (*authResponse, error)
    83  	FuncPostAuthOKTA func(context.Context, *snowflakeRestful, map[string]string, []byte, string, time.Duration) (*authOKTAResponse, error)
    84  	FuncGetSSO       func(context.Context, *snowflakeRestful, *url.Values, map[string]string, string, time.Duration) ([]byte, error)
    85  }
    86  
    87  func (sr *snowflakeRestful) getURL() *url.URL {
    88  	return &url.URL{
    89  		Scheme: sr.Protocol,
    90  		Host:   sr.Host + ":" + strconv.Itoa(sr.Port),
    91  	}
    92  }
    93  
    94  func (sr *snowflakeRestful) getFullURL(path string, params *url.Values) *url.URL {
    95  	ret := &url.URL{
    96  		Scheme: sr.Protocol,
    97  		Host:   sr.Host + ":" + strconv.Itoa(sr.Port),
    98  		Path:   path,
    99  	}
   100  	if params != nil {
   101  		ret.RawQuery = params.Encode()
   102  	}
   103  	return ret
   104  }
   105  
   106  // We need separate client for JWT, because if token processing takes too long, token may be already expired.
   107  func (sr *snowflakeRestful) getClientFor(authType AuthType) *http.Client {
   108  	switch authType {
   109  	case AuthTypeJwt:
   110  		return sr.JWTClient
   111  	default:
   112  		return sr.Client
   113  	}
   114  }
   115  
   116  // Renew the snowflake session if the current token is still the stale token specified
   117  func (sr *snowflakeRestful) renewExpiredSessionToken(ctx context.Context, timeout time.Duration, expiredToken string) error {
   118  	err := sr.TokenAccessor.Lock()
   119  	if err != nil {
   120  		return err
   121  	}
   122  	defer sr.TokenAccessor.Unlock()
   123  	currentToken, _, _ := sr.TokenAccessor.GetTokens()
   124  	if expiredToken == currentToken || currentToken == "" {
   125  		// Only renew the session if the current token is still the expired token or current token is empty
   126  		return sr.FuncRenewSession(ctx, sr, timeout)
   127  	}
   128  	return nil
   129  }
   130  
   131  type renewSessionResponse struct {
   132  	Data    renewSessionResponseMain `json:"data"`
   133  	Message string                   `json:"message"`
   134  	Code    string                   `json:"code"`
   135  	Success bool                     `json:"success"`
   136  }
   137  
   138  type renewSessionResponseMain struct {
   139  	SessionToken        string        `json:"sessionToken"`
   140  	ValidityInSecondsST time.Duration `json:"validityInSecondsST"`
   141  	MasterToken         string        `json:"masterToken"`
   142  	ValidityInSecondsMT time.Duration `json:"validityInSecondsMT"`
   143  	SessionID           int64         `json:"sessionId"`
   144  }
   145  
   146  type cancelQueryResponse struct {
   147  	Data    interface{} `json:"data"`
   148  	Message string      `json:"message"`
   149  	Code    string      `json:"code"`
   150  	Success bool        `json:"success"`
   151  }
   152  
   153  type telemetryResponse struct {
   154  	Data    interface{}       `json:"data,omitempty"`
   155  	Message string            `json:"message"`
   156  	Code    string            `json:"code"`
   157  	Success bool              `json:"success"`
   158  	Headers map[string]string `json:"headers,omitempty"`
   159  }
   160  
   161  func postRestful(
   162  	ctx context.Context,
   163  	sr *snowflakeRestful,
   164  	fullURL *url.URL,
   165  	headers map[string]string,
   166  	body []byte,
   167  	timeout time.Duration,
   168  	currentTimeProvider currentTimeProvider,
   169  	cfg *Config) (
   170  	*http.Response, error) {
   171  	return newRetryHTTP(ctx, sr.Client, http.NewRequest, fullURL, headers, timeout, sr.MaxRetryCount, currentTimeProvider, cfg).
   172  		doPost().
   173  		setBody(body).
   174  		execute()
   175  }
   176  
   177  func getRestful(
   178  	ctx context.Context,
   179  	sr *snowflakeRestful,
   180  	fullURL *url.URL,
   181  	headers map[string]string,
   182  	timeout time.Duration) (
   183  	*http.Response, error) {
   184  	return newRetryHTTP(ctx, sr.Client, http.NewRequest, fullURL, headers, timeout, sr.MaxRetryCount, defaultTimeProvider, nil).execute()
   185  }
   186  
   187  func postAuthRestful(
   188  	ctx context.Context,
   189  	client *http.Client,
   190  	fullURL *url.URL,
   191  	headers map[string]string,
   192  	bodyCreator bodyCreatorType,
   193  	timeout time.Duration,
   194  	maxRetryCount int) (
   195  	*http.Response, error) {
   196  	return newRetryHTTP(ctx, client, http.NewRequest, fullURL, headers, timeout, maxRetryCount, defaultTimeProvider, nil).
   197  		doPost().
   198  		setBodyCreator(bodyCreator).
   199  		execute()
   200  }
   201  
   202  func postRestfulQuery(
   203  	ctx context.Context,
   204  	sr *snowflakeRestful,
   205  	params *url.Values,
   206  	headers map[string]string,
   207  	body []byte,
   208  	timeout time.Duration,
   209  	requestID UUID,
   210  	cfg *Config) (
   211  	data *execResponse, err error) {
   212  
   213  	data, err = sr.FuncPostQueryHelper(ctx, sr, params, headers, body, timeout, requestID, cfg)
   214  
   215  	// errors other than context timeout and cancel would be returned to upper layers
   216  	if err != context.Canceled && err != context.DeadlineExceeded {
   217  		return data, err
   218  	}
   219  
   220  	if err = sr.FuncCancelQuery(context.Background(), sr, requestID, timeout); err != nil {
   221  		return nil, err
   222  	}
   223  	return nil, ctx.Err()
   224  }
   225  
   226  func postRestfulQueryHelper(
   227  	ctx context.Context,
   228  	sr *snowflakeRestful,
   229  	params *url.Values,
   230  	headers map[string]string,
   231  	body []byte,
   232  	timeout time.Duration,
   233  	requestID UUID,
   234  	cfg *Config) (
   235  	data *execResponse, err error) {
   236  	logger.Infof("params: %v", params)
   237  	params.Add(requestIDKey, requestID.String())
   238  	params.Add(requestGUIDKey, NewUUID().String())
   239  	token, _, _ := sr.TokenAccessor.GetTokens()
   240  	if token != "" {
   241  		headers[headerAuthorizationKey] = fmt.Sprintf(headerSnowflakeToken, token)
   242  	}
   243  
   244  	var resp *http.Response
   245  	fullURL := sr.getFullURL(queryRequestPath, params)
   246  	resp, err = sr.FuncPost(ctx, sr, fullURL, headers, body, timeout, defaultTimeProvider, cfg)
   247  	if err != nil {
   248  		return nil, err
   249  	}
   250  	defer resp.Body.Close()
   251  
   252  	if resp.StatusCode == http.StatusOK {
   253  		logger.WithContext(ctx).Infof("postQuery: resp: %v", resp)
   254  		var respd execResponse
   255  		if err = json.NewDecoder(resp.Body).Decode(&respd); err != nil {
   256  			logger.WithContext(ctx).Errorf("failed to decode JSON. err: %v", err)
   257  			return nil, err
   258  		}
   259  		if respd.Code == sessionExpiredCode {
   260  			if err = sr.renewExpiredSessionToken(ctx, timeout, token); err != nil {
   261  				return nil, err
   262  			}
   263  			return sr.FuncPostQuery(ctx, sr, params, headers, body, timeout, requestID, cfg)
   264  		}
   265  
   266  		if queryIDChan := getQueryIDChan(ctx); queryIDChan != nil {
   267  			queryIDChan <- respd.Data.QueryID
   268  			close(queryIDChan)
   269  			ctx = WithQueryIDChan(ctx, nil)
   270  		}
   271  
   272  		isSessionRenewed := false
   273  
   274  		// if asynchronous query in progress, kick off retrieval but return object
   275  		if respd.Code == queryInProgressAsyncCode && isAsyncMode(ctx) {
   276  			return sr.processAsync(ctx, &respd, headers, timeout, cfg)
   277  		}
   278  		for isSessionRenewed || respd.Code == queryInProgressCode ||
   279  			respd.Code == queryInProgressAsyncCode {
   280  			if !isSessionRenewed {
   281  				fullURL = sr.getFullURL(respd.Data.GetResultURL, nil)
   282  			}
   283  
   284  			logger.Info("ping pong")
   285  			token, _, _ = sr.TokenAccessor.GetTokens()
   286  			headers[headerAuthorizationKey] = fmt.Sprintf(headerSnowflakeToken, token)
   287  
   288  			resp, err = sr.FuncGet(ctx, sr, fullURL, headers, timeout)
   289  			if err != nil {
   290  				logger.WithContext(ctx).Errorf("failed to get response. err: %v", err)
   291  				return nil, err
   292  			}
   293  			respd = execResponse{} // reset the response
   294  			err = json.NewDecoder(resp.Body).Decode(&respd)
   295  			resp.Body.Close()
   296  			if err != nil {
   297  				logger.WithContext(ctx).Errorf("failed to decode JSON. err: %v", err)
   298  				return nil, err
   299  			}
   300  			if respd.Code == sessionExpiredCode {
   301  				if err = sr.renewExpiredSessionToken(ctx, timeout, token); err != nil {
   302  					return nil, err
   303  				}
   304  				isSessionRenewed = true
   305  			} else {
   306  				isSessionRenewed = false
   307  			}
   308  		}
   309  		return &respd, nil
   310  	}
   311  	b, err := io.ReadAll(resp.Body)
   312  	if err != nil {
   313  		logger.WithContext(ctx).Errorf("failed to extract HTTP response body. err: %v", err)
   314  		return nil, err
   315  	}
   316  	logger.WithContext(ctx).Infof("HTTP: %v, URL: %v, Body: %v", resp.StatusCode, fullURL, b)
   317  	logger.WithContext(ctx).Infof("Header: %v", resp.Header)
   318  	return nil, &SnowflakeError{
   319  		Number:      ErrFailedToPostQuery,
   320  		SQLState:    SQLStateConnectionFailure,
   321  		Message:     errMsgFailedToPostQuery,
   322  		MessageArgs: []interface{}{resp.StatusCode, fullURL},
   323  	}
   324  }
   325  
   326  func closeSession(ctx context.Context, sr *snowflakeRestful, timeout time.Duration) error {
   327  	logger.WithContext(ctx).Info("close session")
   328  	params := &url.Values{}
   329  	params.Add("delete", "true")
   330  	params.Add(requestIDKey, getOrGenerateRequestIDFromContext(ctx).String())
   331  	params.Add(requestGUIDKey, NewUUID().String())
   332  	fullURL := sr.getFullURL(sessionRequestPath, params)
   333  
   334  	headers := getHeaders()
   335  	token, _, _ := sr.TokenAccessor.GetTokens()
   336  	headers[headerAuthorizationKey] = fmt.Sprintf(headerSnowflakeToken, token)
   337  
   338  	resp, err := sr.FuncPost(ctx, sr, fullURL, headers, nil, 5*time.Second, defaultTimeProvider, nil)
   339  	if err != nil {
   340  		return err
   341  	}
   342  	defer resp.Body.Close()
   343  	if resp.StatusCode == http.StatusOK {
   344  		var respd renewSessionResponse
   345  		if err = json.NewDecoder(resp.Body).Decode(&respd); err != nil {
   346  			logger.WithContext(ctx).Errorf("failed to decode JSON. err: %v", err)
   347  			return err
   348  		}
   349  		if !respd.Success && respd.Code != sessionExpiredCode {
   350  			c, err := strconv.Atoi(respd.Code)
   351  			if err != nil {
   352  				return err
   353  			}
   354  			return &SnowflakeError{
   355  				Number:  c,
   356  				Message: respd.Message,
   357  			}
   358  		}
   359  		return nil
   360  	}
   361  	b, err := io.ReadAll(resp.Body)
   362  	if err != nil {
   363  		logger.WithContext(ctx).Errorf("failed to extract HTTP response body. err: %v", err)
   364  		return err
   365  	}
   366  	logger.WithContext(ctx).Infof("HTTP: %v, URL: %v, Body: %v", resp.StatusCode, fullURL, b)
   367  	logger.WithContext(ctx).Infof("Header: %v", resp.Header)
   368  	return &SnowflakeError{
   369  		Number:      ErrFailedToCloseSession,
   370  		SQLState:    SQLStateConnectionFailure,
   371  		Message:     errMsgFailedToCloseSession,
   372  		MessageArgs: []interface{}{resp.StatusCode, fullURL},
   373  	}
   374  }
   375  
   376  func renewRestfulSession(ctx context.Context, sr *snowflakeRestful, timeout time.Duration) error {
   377  	logger.WithContext(ctx).Info("start renew session")
   378  	params := &url.Values{}
   379  	params.Add(requestIDKey, getOrGenerateRequestIDFromContext(ctx).String())
   380  	params.Add(requestGUIDKey, NewUUID().String())
   381  	fullURL := sr.getFullURL(tokenRequestPath, params)
   382  
   383  	token, masterToken, _ := sr.TokenAccessor.GetTokens()
   384  	headers := getHeaders()
   385  	headers[headerAuthorizationKey] = fmt.Sprintf(headerSnowflakeToken, masterToken)
   386  
   387  	body := make(map[string]string)
   388  	body["oldSessionToken"] = token
   389  	body["requestType"] = "RENEW"
   390  
   391  	var reqBody []byte
   392  	reqBody, err := json.Marshal(body)
   393  	if err != nil {
   394  		return err
   395  	}
   396  
   397  	resp, err := sr.FuncPost(ctx, sr, fullURL, headers, reqBody, timeout, defaultTimeProvider, nil)
   398  	if err != nil {
   399  		return err
   400  	}
   401  	defer resp.Body.Close()
   402  	if resp.StatusCode == http.StatusOK {
   403  		var respd renewSessionResponse
   404  		err = json.NewDecoder(resp.Body).Decode(&respd)
   405  		if err != nil {
   406  			logger.WithContext(ctx).Errorf("failed to decode JSON. err: %v", err)
   407  			return err
   408  		}
   409  		if !respd.Success {
   410  			c, err := strconv.Atoi(respd.Code)
   411  			if err != nil {
   412  				return err
   413  			}
   414  			return &SnowflakeError{
   415  				Number:  c,
   416  				Message: respd.Message,
   417  			}
   418  		}
   419  		sr.TokenAccessor.SetTokens(respd.Data.SessionToken, respd.Data.MasterToken, respd.Data.SessionID)
   420  		return nil
   421  	}
   422  	b, err := io.ReadAll(resp.Body)
   423  	if err != nil {
   424  		logger.WithContext(ctx).Errorf("failed to extract HTTP response body. err: %v", err)
   425  		return err
   426  	}
   427  	logger.WithContext(ctx).Infof("HTTP: %v, URL: %v, Body: %v", resp.StatusCode, fullURL, b)
   428  	logger.WithContext(ctx).Infof("Header: %v", resp.Header)
   429  	return &SnowflakeError{
   430  		Number:      ErrFailedToRenewSession,
   431  		SQLState:    SQLStateConnectionFailure,
   432  		Message:     errMsgFailedToRenew,
   433  		MessageArgs: []interface{}{resp.StatusCode, fullURL},
   434  	}
   435  }
   436  
   437  func getCancelRetry(ctx context.Context) int {
   438  	val := ctx.Value(cancelRetry)
   439  	if val == nil {
   440  		return 5
   441  	}
   442  	cnt, ok := val.(int)
   443  	if !ok {
   444  		return -1
   445  	}
   446  	return cnt
   447  }
   448  
   449  func cancelQuery(ctx context.Context, sr *snowflakeRestful, requestID UUID, timeout time.Duration) error {
   450  	logger.WithContext(ctx).Info("cancel query")
   451  	params := &url.Values{}
   452  	params.Add(requestIDKey, getOrGenerateRequestIDFromContext(ctx).String())
   453  	params.Add(requestGUIDKey, NewUUID().String())
   454  
   455  	fullURL := sr.getFullURL(abortRequestPath, params)
   456  
   457  	headers := getHeaders()
   458  	token, _, _ := sr.TokenAccessor.GetTokens()
   459  	headers[headerAuthorizationKey] = fmt.Sprintf(headerSnowflakeToken, token)
   460  
   461  	req := make(map[string]string)
   462  	req[requestIDKey] = requestID.String()
   463  
   464  	reqByte, err := json.Marshal(req)
   465  	if err != nil {
   466  		return err
   467  	}
   468  
   469  	resp, err := sr.FuncPost(ctx, sr, fullURL, headers, reqByte, timeout, defaultTimeProvider, nil)
   470  	if err != nil {
   471  		return err
   472  	}
   473  	defer resp.Body.Close()
   474  	if resp.StatusCode == http.StatusOK {
   475  		var respd cancelQueryResponse
   476  		if err = json.NewDecoder(resp.Body).Decode(&respd); err != nil {
   477  			logger.WithContext(ctx).Errorf("failed to decode JSON. err: %v", err)
   478  			return err
   479  		}
   480  		ctxRetry := getCancelRetry(ctx)
   481  		if !respd.Success && respd.Code == sessionExpiredCode {
   482  			if err = sr.FuncRenewSession(ctx, sr, timeout); err != nil {
   483  				return err
   484  			}
   485  			return sr.FuncCancelQuery(ctx, sr, requestID, timeout)
   486  		} else if !respd.Success && respd.Code == queryNotExecuting && ctxRetry != 0 {
   487  			return sr.FuncCancelQuery(context.WithValue(ctx, cancelRetry, ctxRetry-1), sr, requestID, timeout)
   488  		} else if respd.Success {
   489  			return nil
   490  		} else {
   491  			c, err := strconv.Atoi(respd.Code)
   492  			if err != nil {
   493  				return err
   494  			}
   495  			return &SnowflakeError{
   496  				Number:  c,
   497  				Message: respd.Message,
   498  			}
   499  		}
   500  	}
   501  	b, err := io.ReadAll(resp.Body)
   502  	if err != nil {
   503  		logger.WithContext(ctx).Errorf("failed to extract HTTP response body. err: %v", err)
   504  		return err
   505  	}
   506  	logger.WithContext(ctx).Infof("HTTP: %v, URL: %v, Body: %v", resp.StatusCode, fullURL, b)
   507  	logger.WithContext(ctx).Infof("Header: %v", resp.Header)
   508  	return &SnowflakeError{
   509  		Number:      ErrFailedToCancelQuery,
   510  		SQLState:    SQLStateConnectionFailure,
   511  		Message:     errMsgFailedToCancelQuery,
   512  		MessageArgs: []interface{}{resp.StatusCode, fullURL},
   513  	}
   514  }
   515  
   516  func getQueryIDChan(ctx context.Context) chan<- string {
   517  	v := ctx.Value(queryIDChannel)
   518  	if v == nil {
   519  		return nil
   520  	}
   521  	c, ok := v.(chan<- string)
   522  	if !ok {
   523  		return nil
   524  	}
   525  	return c
   526  }