github.com/snowflakedb/gosnowflake@v1.9.0/restful.go (about) 1 // Copyright (c) 2017-2022 Snowflake Computing Inc. All rights reserved. 2 3 package gosnowflake 4 5 import ( 6 "context" 7 "encoding/json" 8 "fmt" 9 "io" 10 "net/http" 11 "net/url" 12 "strconv" 13 "time" 14 ) 15 16 // HTTP headers 17 const ( 18 headerSnowflakeToken = "Snowflake Token=\"%v\"" 19 headerAuthorizationKey = "Authorization" 20 21 headerContentTypeApplicationJSON = "application/json" 22 headerAcceptTypeApplicationSnowflake = "application/snowflake" 23 ) 24 25 // Snowflake Server Error code 26 const ( 27 queryInProgressCode = "333333" 28 queryInProgressAsyncCode = "333334" 29 sessionExpiredCode = "390112" 30 queryNotExecuting = "000605" 31 ) 32 33 // Snowflake Server Endpoints 34 const ( 35 loginRequestPath = "/session/v1/login-request" 36 queryRequestPath = "/queries/v1/query-request" 37 tokenRequestPath = "/session/token-request" 38 abortRequestPath = "/queries/v1/abort-request" 39 authenticatorRequestPath = "/session/authenticator-request" 40 monitoringQueriesPath = "/monitoring/queries" 41 sessionRequestPath = "/session" 42 heartBeatPath = "/session/heartbeat" 43 consoleLoginRequestPath = "/console/login" 44 ) 45 46 type ( 47 funcGetType func(context.Context, *snowflakeRestful, *url.URL, map[string]string, time.Duration) (*http.Response, error) 48 funcPostType func(context.Context, *snowflakeRestful, *url.URL, map[string]string, []byte, time.Duration, currentTimeProvider, *Config) (*http.Response, error) 49 funcAuthPostType func(context.Context, *http.Client, *url.URL, map[string]string, bodyCreatorType, time.Duration, int) (*http.Response, error) 50 bodyCreatorType func() ([]byte, error) 51 ) 52 53 var emptyBodyCreator = func() ([]byte, error) { 54 return []byte{}, nil 55 } 56 57 type snowflakeRestful struct { 58 Host string 59 Port int 60 Protocol string 61 LoginTimeout time.Duration // Login timeout 62 RequestTimeout time.Duration // request timeout 63 MaxRetryCount int 64 65 Client *http.Client 66 JWTClient *http.Client 67 TokenAccessor TokenAccessor 68 HeartBeat *heartbeat 69 70 Connection *snowflakeConn 71 72 FuncPostQuery func(context.Context, *snowflakeRestful, *url.Values, map[string]string, []byte, time.Duration, UUID, *Config) (*execResponse, error) 73 FuncPostQueryHelper func(context.Context, *snowflakeRestful, *url.Values, map[string]string, []byte, time.Duration, UUID, *Config) (*execResponse, error) 74 FuncPost funcPostType 75 FuncGet funcGetType 76 FuncAuthPost funcAuthPostType 77 FuncRenewSession func(context.Context, *snowflakeRestful, time.Duration) error 78 FuncCloseSession func(context.Context, *snowflakeRestful, time.Duration) error 79 FuncCancelQuery func(context.Context, *snowflakeRestful, UUID, time.Duration) error 80 81 FuncPostAuth func(context.Context, *snowflakeRestful, *http.Client, *url.Values, map[string]string, bodyCreatorType, time.Duration) (*authResponse, error) 82 FuncPostAuthSAML func(context.Context, *snowflakeRestful, map[string]string, []byte, time.Duration) (*authResponse, error) 83 FuncPostAuthOKTA func(context.Context, *snowflakeRestful, map[string]string, []byte, string, time.Duration) (*authOKTAResponse, error) 84 FuncGetSSO func(context.Context, *snowflakeRestful, *url.Values, map[string]string, string, time.Duration) ([]byte, error) 85 } 86 87 func (sr *snowflakeRestful) getURL() *url.URL { 88 return &url.URL{ 89 Scheme: sr.Protocol, 90 Host: sr.Host + ":" + strconv.Itoa(sr.Port), 91 } 92 } 93 94 func (sr *snowflakeRestful) getFullURL(path string, params *url.Values) *url.URL { 95 ret := &url.URL{ 96 Scheme: sr.Protocol, 97 Host: sr.Host + ":" + strconv.Itoa(sr.Port), 98 Path: path, 99 } 100 if params != nil { 101 ret.RawQuery = params.Encode() 102 } 103 return ret 104 } 105 106 // We need separate client for JWT, because if token processing takes too long, token may be already expired. 107 func (sr *snowflakeRestful) getClientFor(authType AuthType) *http.Client { 108 switch authType { 109 case AuthTypeJwt: 110 return sr.JWTClient 111 default: 112 return sr.Client 113 } 114 } 115 116 // Renew the snowflake session if the current token is still the stale token specified 117 func (sr *snowflakeRestful) renewExpiredSessionToken(ctx context.Context, timeout time.Duration, expiredToken string) error { 118 err := sr.TokenAccessor.Lock() 119 if err != nil { 120 return err 121 } 122 defer sr.TokenAccessor.Unlock() 123 currentToken, _, _ := sr.TokenAccessor.GetTokens() 124 if expiredToken == currentToken || currentToken == "" { 125 // Only renew the session if the current token is still the expired token or current token is empty 126 return sr.FuncRenewSession(ctx, sr, timeout) 127 } 128 return nil 129 } 130 131 type renewSessionResponse struct { 132 Data renewSessionResponseMain `json:"data"` 133 Message string `json:"message"` 134 Code string `json:"code"` 135 Success bool `json:"success"` 136 } 137 138 type renewSessionResponseMain struct { 139 SessionToken string `json:"sessionToken"` 140 ValidityInSecondsST time.Duration `json:"validityInSecondsST"` 141 MasterToken string `json:"masterToken"` 142 ValidityInSecondsMT time.Duration `json:"validityInSecondsMT"` 143 SessionID int64 `json:"sessionId"` 144 } 145 146 type cancelQueryResponse struct { 147 Data interface{} `json:"data"` 148 Message string `json:"message"` 149 Code string `json:"code"` 150 Success bool `json:"success"` 151 } 152 153 type telemetryResponse struct { 154 Data interface{} `json:"data,omitempty"` 155 Message string `json:"message"` 156 Code string `json:"code"` 157 Success bool `json:"success"` 158 Headers map[string]string `json:"headers,omitempty"` 159 } 160 161 func postRestful( 162 ctx context.Context, 163 sr *snowflakeRestful, 164 fullURL *url.URL, 165 headers map[string]string, 166 body []byte, 167 timeout time.Duration, 168 currentTimeProvider currentTimeProvider, 169 cfg *Config) ( 170 *http.Response, error) { 171 return newRetryHTTP(ctx, sr.Client, http.NewRequest, fullURL, headers, timeout, sr.MaxRetryCount, currentTimeProvider, cfg). 172 doPost(). 173 setBody(body). 174 execute() 175 } 176 177 func getRestful( 178 ctx context.Context, 179 sr *snowflakeRestful, 180 fullURL *url.URL, 181 headers map[string]string, 182 timeout time.Duration) ( 183 *http.Response, error) { 184 return newRetryHTTP(ctx, sr.Client, http.NewRequest, fullURL, headers, timeout, sr.MaxRetryCount, defaultTimeProvider, nil).execute() 185 } 186 187 func postAuthRestful( 188 ctx context.Context, 189 client *http.Client, 190 fullURL *url.URL, 191 headers map[string]string, 192 bodyCreator bodyCreatorType, 193 timeout time.Duration, 194 maxRetryCount int) ( 195 *http.Response, error) { 196 return newRetryHTTP(ctx, client, http.NewRequest, fullURL, headers, timeout, maxRetryCount, defaultTimeProvider, nil). 197 doPost(). 198 setBodyCreator(bodyCreator). 199 execute() 200 } 201 202 func postRestfulQuery( 203 ctx context.Context, 204 sr *snowflakeRestful, 205 params *url.Values, 206 headers map[string]string, 207 body []byte, 208 timeout time.Duration, 209 requestID UUID, 210 cfg *Config) ( 211 data *execResponse, err error) { 212 213 data, err = sr.FuncPostQueryHelper(ctx, sr, params, headers, body, timeout, requestID, cfg) 214 215 // errors other than context timeout and cancel would be returned to upper layers 216 if err != context.Canceled && err != context.DeadlineExceeded { 217 return data, err 218 } 219 220 if err = sr.FuncCancelQuery(context.Background(), sr, requestID, timeout); err != nil { 221 return nil, err 222 } 223 return nil, ctx.Err() 224 } 225 226 func postRestfulQueryHelper( 227 ctx context.Context, 228 sr *snowflakeRestful, 229 params *url.Values, 230 headers map[string]string, 231 body []byte, 232 timeout time.Duration, 233 requestID UUID, 234 cfg *Config) ( 235 data *execResponse, err error) { 236 logger.Infof("params: %v", params) 237 params.Add(requestIDKey, requestID.String()) 238 params.Add(requestGUIDKey, NewUUID().String()) 239 token, _, _ := sr.TokenAccessor.GetTokens() 240 if token != "" { 241 headers[headerAuthorizationKey] = fmt.Sprintf(headerSnowflakeToken, token) 242 } 243 244 var resp *http.Response 245 fullURL := sr.getFullURL(queryRequestPath, params) 246 resp, err = sr.FuncPost(ctx, sr, fullURL, headers, body, timeout, defaultTimeProvider, cfg) 247 if err != nil { 248 return nil, err 249 } 250 defer resp.Body.Close() 251 252 if resp.StatusCode == http.StatusOK { 253 logger.WithContext(ctx).Infof("postQuery: resp: %v", resp) 254 var respd execResponse 255 if err = json.NewDecoder(resp.Body).Decode(&respd); err != nil { 256 logger.WithContext(ctx).Errorf("failed to decode JSON. err: %v", err) 257 return nil, err 258 } 259 if respd.Code == sessionExpiredCode { 260 if err = sr.renewExpiredSessionToken(ctx, timeout, token); err != nil { 261 return nil, err 262 } 263 return sr.FuncPostQuery(ctx, sr, params, headers, body, timeout, requestID, cfg) 264 } 265 266 if queryIDChan := getQueryIDChan(ctx); queryIDChan != nil { 267 queryIDChan <- respd.Data.QueryID 268 close(queryIDChan) 269 ctx = WithQueryIDChan(ctx, nil) 270 } 271 272 isSessionRenewed := false 273 274 // if asynchronous query in progress, kick off retrieval but return object 275 if respd.Code == queryInProgressAsyncCode && isAsyncMode(ctx) { 276 return sr.processAsync(ctx, &respd, headers, timeout, cfg) 277 } 278 for isSessionRenewed || respd.Code == queryInProgressCode || 279 respd.Code == queryInProgressAsyncCode { 280 if !isSessionRenewed { 281 fullURL = sr.getFullURL(respd.Data.GetResultURL, nil) 282 } 283 284 logger.Info("ping pong") 285 token, _, _ = sr.TokenAccessor.GetTokens() 286 headers[headerAuthorizationKey] = fmt.Sprintf(headerSnowflakeToken, token) 287 288 resp, err = sr.FuncGet(ctx, sr, fullURL, headers, timeout) 289 if err != nil { 290 logger.WithContext(ctx).Errorf("failed to get response. err: %v", err) 291 return nil, err 292 } 293 respd = execResponse{} // reset the response 294 err = json.NewDecoder(resp.Body).Decode(&respd) 295 resp.Body.Close() 296 if err != nil { 297 logger.WithContext(ctx).Errorf("failed to decode JSON. err: %v", err) 298 return nil, err 299 } 300 if respd.Code == sessionExpiredCode { 301 if err = sr.renewExpiredSessionToken(ctx, timeout, token); err != nil { 302 return nil, err 303 } 304 isSessionRenewed = true 305 } else { 306 isSessionRenewed = false 307 } 308 } 309 return &respd, nil 310 } 311 b, err := io.ReadAll(resp.Body) 312 if err != nil { 313 logger.WithContext(ctx).Errorf("failed to extract HTTP response body. err: %v", err) 314 return nil, err 315 } 316 logger.WithContext(ctx).Infof("HTTP: %v, URL: %v, Body: %v", resp.StatusCode, fullURL, b) 317 logger.WithContext(ctx).Infof("Header: %v", resp.Header) 318 return nil, &SnowflakeError{ 319 Number: ErrFailedToPostQuery, 320 SQLState: SQLStateConnectionFailure, 321 Message: errMsgFailedToPostQuery, 322 MessageArgs: []interface{}{resp.StatusCode, fullURL}, 323 } 324 } 325 326 func closeSession(ctx context.Context, sr *snowflakeRestful, timeout time.Duration) error { 327 logger.WithContext(ctx).Info("close session") 328 params := &url.Values{} 329 params.Add("delete", "true") 330 params.Add(requestIDKey, getOrGenerateRequestIDFromContext(ctx).String()) 331 params.Add(requestGUIDKey, NewUUID().String()) 332 fullURL := sr.getFullURL(sessionRequestPath, params) 333 334 headers := getHeaders() 335 token, _, _ := sr.TokenAccessor.GetTokens() 336 headers[headerAuthorizationKey] = fmt.Sprintf(headerSnowflakeToken, token) 337 338 resp, err := sr.FuncPost(ctx, sr, fullURL, headers, nil, 5*time.Second, defaultTimeProvider, nil) 339 if err != nil { 340 return err 341 } 342 defer resp.Body.Close() 343 if resp.StatusCode == http.StatusOK { 344 var respd renewSessionResponse 345 if err = json.NewDecoder(resp.Body).Decode(&respd); err != nil { 346 logger.WithContext(ctx).Errorf("failed to decode JSON. err: %v", err) 347 return err 348 } 349 if !respd.Success && respd.Code != sessionExpiredCode { 350 c, err := strconv.Atoi(respd.Code) 351 if err != nil { 352 return err 353 } 354 return &SnowflakeError{ 355 Number: c, 356 Message: respd.Message, 357 } 358 } 359 return nil 360 } 361 b, err := io.ReadAll(resp.Body) 362 if err != nil { 363 logger.WithContext(ctx).Errorf("failed to extract HTTP response body. err: %v", err) 364 return err 365 } 366 logger.WithContext(ctx).Infof("HTTP: %v, URL: %v, Body: %v", resp.StatusCode, fullURL, b) 367 logger.WithContext(ctx).Infof("Header: %v", resp.Header) 368 return &SnowflakeError{ 369 Number: ErrFailedToCloseSession, 370 SQLState: SQLStateConnectionFailure, 371 Message: errMsgFailedToCloseSession, 372 MessageArgs: []interface{}{resp.StatusCode, fullURL}, 373 } 374 } 375 376 func renewRestfulSession(ctx context.Context, sr *snowflakeRestful, timeout time.Duration) error { 377 logger.WithContext(ctx).Info("start renew session") 378 params := &url.Values{} 379 params.Add(requestIDKey, getOrGenerateRequestIDFromContext(ctx).String()) 380 params.Add(requestGUIDKey, NewUUID().String()) 381 fullURL := sr.getFullURL(tokenRequestPath, params) 382 383 token, masterToken, _ := sr.TokenAccessor.GetTokens() 384 headers := getHeaders() 385 headers[headerAuthorizationKey] = fmt.Sprintf(headerSnowflakeToken, masterToken) 386 387 body := make(map[string]string) 388 body["oldSessionToken"] = token 389 body["requestType"] = "RENEW" 390 391 var reqBody []byte 392 reqBody, err := json.Marshal(body) 393 if err != nil { 394 return err 395 } 396 397 resp, err := sr.FuncPost(ctx, sr, fullURL, headers, reqBody, timeout, defaultTimeProvider, nil) 398 if err != nil { 399 return err 400 } 401 defer resp.Body.Close() 402 if resp.StatusCode == http.StatusOK { 403 var respd renewSessionResponse 404 err = json.NewDecoder(resp.Body).Decode(&respd) 405 if err != nil { 406 logger.WithContext(ctx).Errorf("failed to decode JSON. err: %v", err) 407 return err 408 } 409 if !respd.Success { 410 c, err := strconv.Atoi(respd.Code) 411 if err != nil { 412 return err 413 } 414 return &SnowflakeError{ 415 Number: c, 416 Message: respd.Message, 417 } 418 } 419 sr.TokenAccessor.SetTokens(respd.Data.SessionToken, respd.Data.MasterToken, respd.Data.SessionID) 420 return nil 421 } 422 b, err := io.ReadAll(resp.Body) 423 if err != nil { 424 logger.WithContext(ctx).Errorf("failed to extract HTTP response body. err: %v", err) 425 return err 426 } 427 logger.WithContext(ctx).Infof("HTTP: %v, URL: %v, Body: %v", resp.StatusCode, fullURL, b) 428 logger.WithContext(ctx).Infof("Header: %v", resp.Header) 429 return &SnowflakeError{ 430 Number: ErrFailedToRenewSession, 431 SQLState: SQLStateConnectionFailure, 432 Message: errMsgFailedToRenew, 433 MessageArgs: []interface{}{resp.StatusCode, fullURL}, 434 } 435 } 436 437 func getCancelRetry(ctx context.Context) int { 438 val := ctx.Value(cancelRetry) 439 if val == nil { 440 return 5 441 } 442 cnt, ok := val.(int) 443 if !ok { 444 return -1 445 } 446 return cnt 447 } 448 449 func cancelQuery(ctx context.Context, sr *snowflakeRestful, requestID UUID, timeout time.Duration) error { 450 logger.WithContext(ctx).Info("cancel query") 451 params := &url.Values{} 452 params.Add(requestIDKey, getOrGenerateRequestIDFromContext(ctx).String()) 453 params.Add(requestGUIDKey, NewUUID().String()) 454 455 fullURL := sr.getFullURL(abortRequestPath, params) 456 457 headers := getHeaders() 458 token, _, _ := sr.TokenAccessor.GetTokens() 459 headers[headerAuthorizationKey] = fmt.Sprintf(headerSnowflakeToken, token) 460 461 req := make(map[string]string) 462 req[requestIDKey] = requestID.String() 463 464 reqByte, err := json.Marshal(req) 465 if err != nil { 466 return err 467 } 468 469 resp, err := sr.FuncPost(ctx, sr, fullURL, headers, reqByte, timeout, defaultTimeProvider, nil) 470 if err != nil { 471 return err 472 } 473 defer resp.Body.Close() 474 if resp.StatusCode == http.StatusOK { 475 var respd cancelQueryResponse 476 if err = json.NewDecoder(resp.Body).Decode(&respd); err != nil { 477 logger.WithContext(ctx).Errorf("failed to decode JSON. err: %v", err) 478 return err 479 } 480 ctxRetry := getCancelRetry(ctx) 481 if !respd.Success && respd.Code == sessionExpiredCode { 482 if err = sr.FuncRenewSession(ctx, sr, timeout); err != nil { 483 return err 484 } 485 return sr.FuncCancelQuery(ctx, sr, requestID, timeout) 486 } else if !respd.Success && respd.Code == queryNotExecuting && ctxRetry != 0 { 487 return sr.FuncCancelQuery(context.WithValue(ctx, cancelRetry, ctxRetry-1), sr, requestID, timeout) 488 } else if respd.Success { 489 return nil 490 } else { 491 c, err := strconv.Atoi(respd.Code) 492 if err != nil { 493 return err 494 } 495 return &SnowflakeError{ 496 Number: c, 497 Message: respd.Message, 498 } 499 } 500 } 501 b, err := io.ReadAll(resp.Body) 502 if err != nil { 503 logger.WithContext(ctx).Errorf("failed to extract HTTP response body. err: %v", err) 504 return err 505 } 506 logger.WithContext(ctx).Infof("HTTP: %v, URL: %v, Body: %v", resp.StatusCode, fullURL, b) 507 logger.WithContext(ctx).Infof("Header: %v", resp.Header) 508 return &SnowflakeError{ 509 Number: ErrFailedToCancelQuery, 510 SQLState: SQLStateConnectionFailure, 511 Message: errMsgFailedToCancelQuery, 512 MessageArgs: []interface{}{resp.StatusCode, fullURL}, 513 } 514 } 515 516 func getQueryIDChan(ctx context.Context) chan<- string { 517 v := ctx.Value(queryIDChannel) 518 if v == nil { 519 return nil 520 } 521 c, ok := v.(chan<- string) 522 if !ok { 523 return nil 524 } 525 return c 526 }