github.com/snowflakedb/gosnowflake@v1.9.0/connection_util.go (about) 1 // Copyright (c) 2021-2022 Snowflake Computing Inc. All rights reserved. 2 3 package gosnowflake 4 5 import ( 6 "bytes" 7 "context" 8 "fmt" 9 "io" 10 "os" 11 "strconv" 12 "strings" 13 "time" 14 ) 15 16 func (sc *snowflakeConn) isClientSessionKeepAliveEnabled() bool { 17 paramsMutex.Lock() 18 v, ok := sc.cfg.Params[sessionClientSessionKeepAlive] 19 paramsMutex.Unlock() 20 if !ok { 21 return false 22 } 23 return strings.Compare(*v, "true") == 0 24 } 25 26 func (sc *snowflakeConn) startHeartBeat() { 27 if sc.cfg != nil && !sc.isClientSessionKeepAliveEnabled() { 28 return 29 } 30 if sc.rest != nil { 31 sc.rest.HeartBeat = &heartbeat{ 32 restful: sc.rest, 33 } 34 sc.rest.HeartBeat.start() 35 } 36 } 37 38 func (sc *snowflakeConn) stopHeartBeat() { 39 if sc.cfg != nil && !sc.isClientSessionKeepAliveEnabled() { 40 return 41 } 42 if sc.rest != nil && sc.rest.HeartBeat != nil { 43 sc.rest.HeartBeat.stop() 44 } 45 } 46 47 func (sc *snowflakeConn) getArrayBindStageThreshold() int { 48 paramsMutex.Lock() 49 v, ok := sc.cfg.Params[sessionArrayBindStageThreshold] 50 paramsMutex.Unlock() 51 if !ok { 52 return 0 53 } 54 num, err := strconv.Atoi(*v) 55 if err != nil { 56 return 0 57 } 58 return num 59 } 60 61 func (sc *snowflakeConn) connectionTelemetry(cfg *Config) { 62 data := &telemetryData{ 63 Message: map[string]string{ 64 typeKey: connectionParameters, 65 sourceKey: telemetrySource, 66 driverTypeKey: "Go", 67 driverVersionKey: SnowflakeGoDriverVersion, 68 }, 69 Timestamp: time.Now().UnixNano() / int64(time.Millisecond), 70 } 71 paramsMutex.Lock() 72 for k, v := range cfg.Params { 73 data.Message[k] = *v 74 } 75 paramsMutex.Unlock() 76 sc.telemetry.addLog(data) 77 sc.telemetry.sendBatch() 78 } 79 80 // processFileTransfer creates a snowflakeFileTransferAgent object to process 81 // any PUT/GET commands with their specified options 82 func (sc *snowflakeConn) processFileTransfer( 83 ctx context.Context, 84 data *execResponse, 85 query string, 86 isInternal bool) ( 87 *execResponse, error) { 88 sfa := snowflakeFileTransferAgent{ 89 sc: sc, 90 data: &data.Data, 91 command: query, 92 options: new(SnowflakeFileTransferOptions), 93 } 94 if fs := getFileStream(ctx); fs != nil { 95 sfa.sourceStream = fs 96 if isInternal { 97 sfa.data.AutoCompress = false 98 } 99 } 100 if op := getFileTransferOptions(ctx); op != nil { 101 sfa.options = op 102 } 103 if sfa.options.MultiPartThreshold == 0 { 104 sfa.options.MultiPartThreshold = dataSizeThreshold 105 } 106 if err := sfa.execute(); err != nil { 107 return nil, err 108 } 109 data, err := sfa.result() 110 if err != nil { 111 return nil, err 112 } 113 return data, nil 114 } 115 116 func getFileStream(ctx context.Context) *bytes.Buffer { 117 s := ctx.Value(fileStreamFile) 118 r, ok := s.(io.Reader) 119 if !ok { 120 return nil 121 } 122 buf := new(bytes.Buffer) 123 buf.ReadFrom(r) 124 return buf 125 } 126 127 func getFileTransferOptions(ctx context.Context) *SnowflakeFileTransferOptions { 128 v := ctx.Value(fileTransferOptions) 129 if v == nil { 130 return nil 131 } 132 o, ok := v.(*SnowflakeFileTransferOptions) 133 if !ok { 134 return nil 135 } 136 return o 137 } 138 139 func (sc *snowflakeConn) populateSessionParameters(parameters []nameValueParameter) { 140 // other session parameters (not all) 141 logger.WithContext(sc.ctx).Infof("params: %#v", parameters) 142 for _, param := range parameters { 143 v := "" 144 switch param.Value.(type) { 145 case int64: 146 if vv, ok := param.Value.(int64); ok { 147 v = strconv.FormatInt(vv, 10) 148 } 149 case float64: 150 if vv, ok := param.Value.(float64); ok { 151 v = strconv.FormatFloat(vv, 'g', -1, 64) 152 } 153 case bool: 154 if vv, ok := param.Value.(bool); ok { 155 v = strconv.FormatBool(vv) 156 } 157 default: 158 if vv, ok := param.Value.(string); ok { 159 v = vv 160 } 161 } 162 logger.Debugf("parameter. name: %v, value: %v", param.Name, v) 163 paramsMutex.Lock() 164 sc.cfg.Params[strings.ToLower(param.Name)] = &v 165 paramsMutex.Unlock() 166 } 167 } 168 169 func isAsyncMode(ctx context.Context) bool { 170 val := ctx.Value(asyncMode) 171 if val == nil { 172 return false 173 } 174 a, ok := val.(bool) 175 return ok && a 176 } 177 178 func isDescribeOnly(ctx context.Context) bool { 179 v := ctx.Value(describeOnly) 180 if v == nil { 181 return false 182 } 183 d, ok := v.(bool) 184 return ok && d 185 } 186 187 func setResultType(ctx context.Context, resType resultType) context.Context { 188 return context.WithValue(ctx, snowflakeResultType, resType) 189 } 190 191 func getResultType(ctx context.Context) resultType { 192 return ctx.Value(snowflakeResultType).(resultType) 193 } 194 195 // isDml returns true if the statement type code is in the range of DML. 196 func isDml(v int64) bool { 197 return statementTypeIDDml <= v && v <= statementTypeIDMultiTableInsert 198 } 199 200 func isDql(data *execResponseData) bool { 201 return data.StatementTypeID == statementTypeIDSelect && !isMultiStmt(data) 202 } 203 204 func updateRows(data execResponseData) (int64, error) { 205 if data.RowSet == nil { 206 return 0, nil 207 } 208 var count int64 209 for i, n := 0, len(data.RowType); i < n; i++ { 210 v, err := strconv.ParseInt(*data.RowSet[0][i], 10, 64) 211 if err != nil { 212 return -1, err 213 } 214 count += v 215 } 216 return count, nil 217 } 218 219 // isMultiStmt returns true if the statement code is of type multistatement 220 // Note that the statement type code is also equivalent to type INSERT, so an 221 // additional check of the name is required 222 func isMultiStmt(data *execResponseData) bool { 223 var isMultistatementByReturningSelect = data.StatementTypeID == statementTypeIDSelect && data.RowType[0].Name == "multiple statement execution" 224 return isMultistatementByReturningSelect || data.StatementTypeID == statementTypeIDMultistatement 225 } 226 227 func getResumeQueryID(ctx context.Context) (string, error) { 228 val := ctx.Value(fetchResultByID) 229 if val == nil { 230 return "", nil 231 } 232 strVal, ok := val.(string) 233 if !ok { 234 return "", fmt.Errorf("failed to cast val %+v to string", val) 235 } 236 // so there is a queryID in context for which we want to fetch the result 237 if !queryIDRegexp.MatchString(strVal) { 238 return strVal, &SnowflakeError{ 239 Number: ErrQueryIDFormat, 240 Message: "Invalid QID", 241 QueryID: strVal, 242 } 243 } 244 return strVal, nil 245 } 246 247 // returns snowflake chunk downloader by default or stream based chunk 248 // downloader if option provided through context 249 func populateChunkDownloader( 250 ctx context.Context, 251 sc *snowflakeConn, 252 data execResponseData) chunkDownloader { 253 if useStreamDownloader(ctx) && resultFormat(data.QueryResultFormat) == jsonFormat { 254 // stream chunk downloading only works for row based data formats, i.e. json 255 fetcher := &httpStreamChunkFetcher{ 256 ctx: ctx, 257 client: sc.rest.Client, 258 clientIP: sc.cfg.ClientIP, 259 headers: data.ChunkHeaders, 260 qrmk: data.Qrmk, 261 } 262 return newStreamChunkDownloader(ctx, fetcher, data.Total, data.RowType, 263 data.RowSet, data.Chunks) 264 } 265 266 return &snowflakeChunkDownloader{ 267 sc: sc, 268 ctx: ctx, 269 pool: getAllocator(ctx), 270 CurrentChunk: make([]chunkRowType, len(data.RowSet)), 271 ChunkMetas: data.Chunks, 272 Total: data.Total, 273 TotalRowIndex: int64(-1), 274 CellCount: len(data.RowType), 275 Qrmk: data.Qrmk, 276 QueryResultFormat: data.QueryResultFormat, 277 ChunkHeader: data.ChunkHeaders, 278 FuncDownload: downloadChunk, 279 FuncDownloadHelper: downloadChunkHelper, 280 FuncGet: getChunk, 281 RowSet: rowSetType{ 282 RowType: data.RowType, 283 JSON: data.RowSet, 284 RowSetBase64: data.RowSetBase64, 285 }, 286 } 287 } 288 289 func (sc *snowflakeConn) setupOCSPPrivatelink(app string, host string) error { 290 ocspCacheServer := fmt.Sprintf("http://ocsp.%v/ocsp_response_cache.json", host) 291 logger.Debugf("OCSP Cache Server for Privatelink: %v\n", ocspCacheServer) 292 if err := os.Setenv(cacheServerURLEnv, ocspCacheServer); err != nil { 293 return err 294 } 295 ocspRetryHostTemplate := fmt.Sprintf("http://ocsp.%v/retry/", host) + "%v/%v" 296 logger.Debugf("OCSP Retry URL for Privatelink: %v\n", ocspRetryHostTemplate) 297 if err := os.Setenv(ocspRetryURLEnv, ocspRetryHostTemplate); err != nil { 298 return err 299 } 300 return nil 301 } 302 303 func isStatementContext(ctx context.Context) bool { 304 v := ctx.Value(executionType) 305 return v == executionTypeStatement 306 }