github.com/snowflakedb/gosnowflake@v1.9.0/async.go (about)

     1  // Copyright (c) 2021-2022 Snowflake Computing Inc. All rights reserved.
     2  
     3  package gosnowflake
     4  
     5  import (
     6  	"context"
     7  	"encoding/json"
     8  	"fmt"
     9  	"net/url"
    10  	"strconv"
    11  	"time"
    12  )
    13  
    14  func (sr *snowflakeRestful) processAsync(
    15  	ctx context.Context,
    16  	respd *execResponse,
    17  	headers map[string]string,
    18  	timeout time.Duration,
    19  	cfg *Config) (*execResponse, error) {
    20  	// placeholder object to return to user while retrieving results
    21  	rows := new(snowflakeRows)
    22  	res := new(snowflakeResult)
    23  	switch resType := getResultType(ctx); resType {
    24  	case execResultType:
    25  		res.queryID = respd.Data.QueryID
    26  		res.status = QueryStatusInProgress
    27  		res.errChannel = make(chan error)
    28  		respd.Data.AsyncResult = res
    29  	case queryResultType:
    30  		rows.queryID = respd.Data.QueryID
    31  		rows.status = QueryStatusInProgress
    32  		rows.errChannel = make(chan error)
    33  		respd.Data.AsyncRows = rows
    34  	default:
    35  		return respd, nil
    36  	}
    37  
    38  	// spawn goroutine to retrieve asynchronous results
    39  	go sr.getAsync(ctx, headers, sr.getFullURL(respd.Data.GetResultURL, nil), timeout, res, rows, cfg)
    40  	return respd, nil
    41  }
    42  
    43  func (sr *snowflakeRestful) getAsync(
    44  	ctx context.Context,
    45  	headers map[string]string,
    46  	URL *url.URL,
    47  	timeout time.Duration,
    48  	res *snowflakeResult,
    49  	rows *snowflakeRows,
    50  	cfg *Config) error {
    51  	resType := getResultType(ctx)
    52  	var errChannel chan error
    53  	sfError := &SnowflakeError{
    54  		Number: ErrAsync,
    55  	}
    56  	if resType == execResultType {
    57  		errChannel = res.errChannel
    58  		sfError.QueryID = res.queryID
    59  	} else {
    60  		errChannel = rows.errChannel
    61  		sfError.QueryID = rows.queryID
    62  	}
    63  	defer close(errChannel)
    64  	token, _, _ := sr.TokenAccessor.GetTokens()
    65  	headers[headerAuthorizationKey] = fmt.Sprintf(headerSnowflakeToken, token)
    66  
    67  	respd, err := getQueryResultWithRetriesForAsyncMode(ctx, sr, URL, headers, timeout)
    68  	if err != nil {
    69  		logger.WithContext(ctx).Errorf("error: %v", err)
    70  		sfError.Message = err.Error()
    71  		errChannel <- sfError
    72  		return err
    73  	}
    74  
    75  	sc := &snowflakeConn{rest: sr, cfg: cfg, queryContextCache: (&queryContextCache{}).init(), currentTimeProvider: defaultTimeProvider}
    76  	if respd.Success {
    77  		if resType == execResultType {
    78  			res.insertID = -1
    79  			if isDml(respd.Data.StatementTypeID) {
    80  				res.affectedRows, err = updateRows(respd.Data)
    81  				if err != nil {
    82  					return err
    83  				}
    84  			} else if isMultiStmt(&respd.Data) {
    85  				r, err := sc.handleMultiExec(ctx, respd.Data)
    86  				if err != nil {
    87  					res.errChannel <- err
    88  					return err
    89  				}
    90  				res.affectedRows, err = r.RowsAffected()
    91  				if err != nil {
    92  					res.errChannel <- err
    93  					return err
    94  				}
    95  			}
    96  			res.queryID = respd.Data.QueryID
    97  			res.errChannel <- nil // mark exec status complete
    98  		} else {
    99  			rows.sc = sc
   100  			rows.queryID = respd.Data.QueryID
   101  			if isMultiStmt(&respd.Data) {
   102  				if err = sc.handleMultiQuery(ctx, respd.Data, rows); err != nil {
   103  					rows.errChannel <- err
   104  					close(rows.errChannel)
   105  					return err
   106  				}
   107  			} else {
   108  				rows.addDownloader(populateChunkDownloader(ctx, sc, respd.Data))
   109  			}
   110  			if err = rows.ChunkDownloader.start(); err != nil {
   111  				rows.errChannel <- err
   112  				close(rows.errChannel)
   113  				return err
   114  			}
   115  			rows.errChannel <- nil // mark query status complete
   116  		}
   117  	} else {
   118  		var code int
   119  		if respd.Code != "" {
   120  			code, err = strconv.Atoi(respd.Code)
   121  			if err != nil {
   122  				code = -1
   123  			}
   124  		} else {
   125  			code = -1
   126  		}
   127  		errChannel <- &SnowflakeError{
   128  			Number:   code,
   129  			SQLState: respd.Data.SQLState,
   130  			Message:  respd.Message,
   131  			QueryID:  respd.Data.QueryID,
   132  		}
   133  	}
   134  	return nil
   135  }
   136  
   137  func getQueryResultWithRetriesForAsyncMode(
   138  	ctx context.Context,
   139  	sr *snowflakeRestful,
   140  	URL *url.URL,
   141  	headers map[string]string,
   142  	timeout time.Duration) (*execResponse, error) {
   143  	var respd *execResponse
   144  	retry := 0
   145  	retryPattern := []int32{1, 1, 2, 3, 4, 8, 10}
   146  	retryPatternIndex := 0
   147  
   148  	for {
   149  		logger.WithContext(ctx).Debugf("Retry count for get query result request in async mode: %v", retry)
   150  
   151  		resp, err := sr.FuncGet(ctx, sr, URL, headers, timeout)
   152  		if err != nil {
   153  			logger.WithContext(ctx).Errorf("failed to get response. err: %v", err)
   154  			return respd, err
   155  		}
   156  		defer resp.Body.Close()
   157  
   158  		respd = &execResponse{} // reset the response
   159  		err = json.NewDecoder(resp.Body).Decode(&respd)
   160  		if err != nil {
   161  			logger.WithContext(ctx).Errorf("failed to decode JSON. err: %v", err)
   162  			return respd, err
   163  		}
   164  		if respd.Code != queryInProgressAsyncCode {
   165  			// If the query takes longer than 45 seconds to complete the results are not returned.
   166  			// If the query is still in progress after 45 seconds, retry the request to the /results endpoint.
   167  			// For all other scenarios continue processing results response
   168  			break
   169  		} else {
   170  			// Sleep before retrying get result request. Exponential backoff up to 5 seconds.
   171  			// Once 5 second backoff is reached it will keep retrying with this sleeptime.
   172  			sleepTime := time.Millisecond * time.Duration(500*retryPattern[retryPatternIndex])
   173  			logger.WithContext(ctx).Infof("Query execution still in progress. Response code: %v, message: %v Sleep for %v ms", respd.Code, respd.Message, sleepTime)
   174  			time.Sleep(sleepTime)
   175  			retry++
   176  
   177  			if retryPatternIndex < len(retryPattern)-1 {
   178  				retryPatternIndex++
   179  			}
   180  		}
   181  	}
   182  	return respd, nil
   183  }