github.com/snowflakedb/gosnowflake@v1.9.0/bind_uploader.go (about) 1 // Copyright (c) 2021-2022 Snowflake Computing Inc. All rights reserved. 2 3 package gosnowflake 4 5 import ( 6 "bytes" 7 "context" 8 "database/sql" 9 "database/sql/driver" 10 "fmt" 11 "reflect" 12 "strconv" 13 "strings" 14 ) 15 16 const ( 17 bindStageName = "SYSTEM$BIND" 18 createTemporaryStageStmt = "CREATE OR REPLACE TEMPORARY STAGE " + bindStageName + 19 " file_format=" + "(type=csv field_optionally_enclosed_by='\"')" 20 21 // size (in bytes) of max input stream (10MB default) as per JDBC specs 22 inputStreamBufferSize = 1024 * 1024 * 10 23 ) 24 25 type bindUploader struct { 26 ctx context.Context 27 sc *snowflakeConn 28 stagePath string 29 fileCount int 30 arrayBindStage string 31 } 32 33 func (bu *bindUploader) upload(bindings []driver.NamedValue) (*execResponse, error) { 34 bindingRows, err := bu.buildRowsAsBytes(bindings) 35 if err != nil { 36 return nil, err 37 } 38 startIdx, numBytes, rowNum := 0, 0, 0 39 bu.fileCount = 0 40 var data *execResponse 41 for rowNum < len(bindingRows) { 42 for numBytes < inputStreamBufferSize && rowNum < len(bindingRows) { 43 numBytes += len(bindingRows[rowNum]) 44 rowNum++ 45 } 46 // concatenate all byte arrays into 1 and put into input stream 47 var b bytes.Buffer 48 b.Grow(numBytes) 49 for i := startIdx; i < rowNum; i++ { 50 b.Write(bindingRows[i]) 51 } 52 53 bu.fileCount++ 54 data, err = bu.uploadStreamInternal(&b, bu.fileCount, true) 55 if err != nil { 56 return nil, err 57 } 58 startIdx = rowNum 59 numBytes = 0 60 } 61 return data, nil 62 } 63 64 func (bu *bindUploader) uploadStreamInternal( 65 inputStream *bytes.Buffer, 66 dstFileName int, 67 compressData bool) ( 68 *execResponse, error) { 69 if err := bu.createStageIfNeeded(); err != nil { 70 return nil, err 71 } 72 stageName := bu.stagePath 73 if stageName == "" { 74 return nil, (&SnowflakeError{ 75 Number: ErrBindUpload, 76 Message: "stage name is null", 77 }).exceptionTelemetry(bu.sc) 78 } 79 80 // use a placeholder for source file 81 putCommand := fmt.Sprintf("put 'file:///tmp/placeholder/%v' '%v' overwrite=true", dstFileName, stageName) 82 // for Windows queries 83 putCommand = strings.ReplaceAll(putCommand, "\\", "\\\\") 84 // prepare context for PUT command 85 ctx := WithFileStream(bu.ctx, inputStream) 86 ctx = WithFileTransferOptions(ctx, &SnowflakeFileTransferOptions{ 87 compressSourceFromStream: compressData}) 88 return bu.sc.exec(ctx, putCommand, false, true, false, []driver.NamedValue{}) 89 } 90 91 func (bu *bindUploader) createStageIfNeeded() error { 92 if bu.arrayBindStage != "" { 93 return nil 94 } 95 data, err := bu.sc.exec(bu.ctx, createTemporaryStageStmt, false, false, false, []driver.NamedValue{}) 96 if err != nil { 97 newThreshold := "0" 98 bu.sc.cfg.Params[sessionArrayBindStageThreshold] = &newThreshold 99 return err 100 } 101 if !data.Success { 102 code, err := strconv.Atoi(data.Code) 103 if err != nil { 104 return err 105 } 106 return (&SnowflakeError{ 107 Number: code, 108 SQLState: data.Data.SQLState, 109 Message: err.Error(), 110 QueryID: data.Data.QueryID, 111 }).exceptionTelemetry(bu.sc) 112 } 113 bu.arrayBindStage = bindStageName 114 return nil 115 } 116 117 // transpose the columns to rows and write them to a list of bytes 118 func (bu *bindUploader) buildRowsAsBytes(columns []driver.NamedValue) ([][]byte, error) { 119 numColumns := len(columns) 120 if columns[0].Value == nil { 121 return nil, (&SnowflakeError{ 122 Number: ErrBindSerialization, 123 Message: "no binds found in the first column", 124 }).exceptionTelemetry(bu.sc) 125 } 126 127 _, column := snowflakeArrayToString(&columns[0], true) 128 numRows := len(column) 129 csvRows := make([][]byte, 0) 130 rows := make([][]interface{}, 0) 131 for rowIdx := 0; rowIdx < numRows; rowIdx++ { 132 rows = append(rows, make([]interface{}, numColumns)) 133 } 134 135 for rowIdx := 0; rowIdx < numRows; rowIdx++ { 136 if column[rowIdx] == nil { 137 rows[rowIdx][0] = column[rowIdx] 138 } else { 139 rows[rowIdx][0] = *column[rowIdx] 140 } 141 } 142 for colIdx := 1; colIdx < numColumns; colIdx++ { 143 _, column = snowflakeArrayToString(&columns[colIdx], true) 144 iNumRows := len(column) 145 if iNumRows != numRows { 146 return nil, (&SnowflakeError{ 147 Number: ErrBindSerialization, 148 Message: errMsgBindColumnMismatch, 149 MessageArgs: []interface{}{colIdx, iNumRows, numRows}, 150 }).exceptionTelemetry(bu.sc) 151 } 152 for rowIdx := 0; rowIdx < numRows; rowIdx++ { 153 // length of column = number of rows 154 if column[rowIdx] == nil { 155 rows[rowIdx][colIdx] = column[rowIdx] 156 } else { 157 rows[rowIdx][colIdx] = *column[rowIdx] 158 } 159 } 160 } 161 for _, row := range rows { 162 csvRows = append(csvRows, bu.createCSVRecord(row)) 163 } 164 return csvRows, nil 165 } 166 167 func (bu *bindUploader) createCSVRecord(data []interface{}) []byte { 168 var b strings.Builder 169 b.Grow(1024) 170 for i := 0; i < len(data); i++ { 171 if i > 0 { 172 b.WriteString(",") 173 } 174 value, ok := data[i].(string) 175 if ok { 176 b.WriteString(escapeForCSV(value)) 177 } else if !reflect.ValueOf(data[i]).IsNil() { 178 logger.Debugf("Cannot convert value to string in createCSVRecord. value: %v", data[i]) 179 } 180 } 181 b.WriteString("\n") 182 return []byte(b.String()) 183 } 184 185 func (sc *snowflakeConn) processBindings( 186 ctx context.Context, 187 bindings []driver.NamedValue, 188 describeOnly bool, 189 requestID UUID, 190 req *execRequest) error { 191 arrayBindThreshold := sc.getArrayBindStageThreshold() 192 numBinds := arrayBindValueCount(bindings) 193 if 0 < arrayBindThreshold && arrayBindThreshold <= numBinds && !describeOnly && isArrayBind(bindings) { 194 uploader := bindUploader{ 195 sc: sc, 196 ctx: ctx, 197 stagePath: "@" + bindStageName + "/" + requestID.String(), 198 } 199 _, err := uploader.upload(bindings) 200 if err != nil { 201 return err 202 } 203 req.Bindings = nil 204 req.BindStage = uploader.stagePath 205 } else { 206 var err error 207 req.Bindings, err = getBindValues(bindings) 208 if err != nil { 209 return err 210 } 211 req.BindStage = "" 212 } 213 return nil 214 } 215 216 func getBindValues(bindings []driver.NamedValue) (map[string]execBindParameter, error) { 217 tsmode := timestampNtzType 218 idx := 1 219 var err error 220 bindValues := make(map[string]execBindParameter, len(bindings)) 221 for _, binding := range bindings { 222 if tnt, ok := binding.Value.(TypedNullTime); ok { 223 tsmode = convertTzTypeToSnowflakeType(tnt.TzType) 224 binding.Value = tnt.Time 225 } 226 t := goTypeToSnowflake(binding.Value, tsmode) 227 if t == changeType { 228 tsmode, err = dataTypeMode(binding.Value) 229 if err != nil { 230 return nil, err 231 } 232 } else { 233 var val interface{} 234 if t == sliceType { 235 // retrieve array binding data 236 t, val = snowflakeArrayToString(&binding, false) 237 } else { 238 val, err = valueToString(binding.Value, tsmode) 239 if err != nil { 240 return nil, err 241 } 242 } 243 if t == nullType || t == unSupportedType { 244 t = textType // if null or not supported, pass to GS as text 245 } 246 bindValues[bindingName(binding, idx)] = execBindParameter{ 247 Type: t.String(), 248 Value: val, 249 } 250 idx++ 251 } 252 } 253 return bindValues, nil 254 } 255 256 func bindingName(nv driver.NamedValue, idx int) string { 257 if nv.Name != "" { 258 return nv.Name 259 } 260 return strconv.Itoa(idx) 261 } 262 263 func arrayBindValueCount(bindValues []driver.NamedValue) int { 264 if !isArrayBind(bindValues) { 265 return 0 266 } 267 _, arr := snowflakeArrayToString(&bindValues[0], false) 268 return len(bindValues) * len(arr) 269 } 270 271 func isArrayBind(bindings []driver.NamedValue) bool { 272 if len(bindings) == 0 { 273 return false 274 } 275 for _, binding := range bindings { 276 if supported := supportedArrayBind(&binding); !supported { 277 return false 278 } 279 } 280 return true 281 } 282 283 func supportedArrayBind(nv *driver.NamedValue) bool { 284 switch reflect.TypeOf(nv.Value) { 285 case reflect.TypeOf(&intArray{}), reflect.TypeOf(&int32Array{}), 286 reflect.TypeOf(&int64Array{}), reflect.TypeOf(&float64Array{}), 287 reflect.TypeOf(&float32Array{}), reflect.TypeOf(&boolArray{}), 288 reflect.TypeOf(&stringArray{}), reflect.TypeOf(&byteArray{}), 289 reflect.TypeOf(×tampNtzArray{}), reflect.TypeOf(×tampLtzArray{}), 290 reflect.TypeOf(×tampTzArray{}), reflect.TypeOf(&dateArray{}), 291 reflect.TypeOf(&timeArray{}): 292 return true 293 case reflect.TypeOf([]uint8{}): 294 // internal binding ts mode 295 val, ok := nv.Value.([]uint8) 296 if !ok { 297 return ok 298 } 299 if len(val) == 0 { 300 return true // for null binds 301 } 302 if fixedType <= snowflakeType(val[0]) && snowflakeType(val[0]) <= unSupportedType { 303 return true 304 } 305 return false 306 default: 307 // TODO SNOW-176486 variant, object, array 308 309 // Support for bulk array binding insertion using []interface{} 310 if isInterfaceArrayBinding(nv.Value) { 311 return true 312 } 313 return false 314 } 315 } 316 317 func supportedNullBind(nv *driver.NamedValue) bool { 318 switch reflect.TypeOf(nv.Value) { 319 case reflect.TypeOf(sql.NullString{}), reflect.TypeOf(sql.NullInt64{}), 320 reflect.TypeOf(sql.NullBool{}), reflect.TypeOf(sql.NullFloat64{}), reflect.TypeOf(TypedNullTime{}): 321 return true 322 } 323 return false 324 }