github.com/snowflakedb/gosnowflake@v1.9.0/bind_uploader.go (about)

     1  // Copyright (c) 2021-2022 Snowflake Computing Inc. All rights reserved.
     2  
     3  package gosnowflake
     4  
     5  import (
     6  	"bytes"
     7  	"context"
     8  	"database/sql"
     9  	"database/sql/driver"
    10  	"fmt"
    11  	"reflect"
    12  	"strconv"
    13  	"strings"
    14  )
    15  
    16  const (
    17  	bindStageName            = "SYSTEM$BIND"
    18  	createTemporaryStageStmt = "CREATE OR REPLACE TEMPORARY STAGE " + bindStageName +
    19  		" file_format=" + "(type=csv field_optionally_enclosed_by='\"')"
    20  
    21  	// size (in bytes) of max input stream (10MB default) as per JDBC specs
    22  	inputStreamBufferSize = 1024 * 1024 * 10
    23  )
    24  
    25  type bindUploader struct {
    26  	ctx            context.Context
    27  	sc             *snowflakeConn
    28  	stagePath      string
    29  	fileCount      int
    30  	arrayBindStage string
    31  }
    32  
    33  func (bu *bindUploader) upload(bindings []driver.NamedValue) (*execResponse, error) {
    34  	bindingRows, err := bu.buildRowsAsBytes(bindings)
    35  	if err != nil {
    36  		return nil, err
    37  	}
    38  	startIdx, numBytes, rowNum := 0, 0, 0
    39  	bu.fileCount = 0
    40  	var data *execResponse
    41  	for rowNum < len(bindingRows) {
    42  		for numBytes < inputStreamBufferSize && rowNum < len(bindingRows) {
    43  			numBytes += len(bindingRows[rowNum])
    44  			rowNum++
    45  		}
    46  		// concatenate all byte arrays into 1 and put into input stream
    47  		var b bytes.Buffer
    48  		b.Grow(numBytes)
    49  		for i := startIdx; i < rowNum; i++ {
    50  			b.Write(bindingRows[i])
    51  		}
    52  
    53  		bu.fileCount++
    54  		data, err = bu.uploadStreamInternal(&b, bu.fileCount, true)
    55  		if err != nil {
    56  			return nil, err
    57  		}
    58  		startIdx = rowNum
    59  		numBytes = 0
    60  	}
    61  	return data, nil
    62  }
    63  
    64  func (bu *bindUploader) uploadStreamInternal(
    65  	inputStream *bytes.Buffer,
    66  	dstFileName int,
    67  	compressData bool) (
    68  	*execResponse, error) {
    69  	if err := bu.createStageIfNeeded(); err != nil {
    70  		return nil, err
    71  	}
    72  	stageName := bu.stagePath
    73  	if stageName == "" {
    74  		return nil, (&SnowflakeError{
    75  			Number:  ErrBindUpload,
    76  			Message: "stage name is null",
    77  		}).exceptionTelemetry(bu.sc)
    78  	}
    79  
    80  	// use a placeholder for source file
    81  	putCommand := fmt.Sprintf("put 'file:///tmp/placeholder/%v' '%v' overwrite=true", dstFileName, stageName)
    82  	// for Windows queries
    83  	putCommand = strings.ReplaceAll(putCommand, "\\", "\\\\")
    84  	// prepare context for PUT command
    85  	ctx := WithFileStream(bu.ctx, inputStream)
    86  	ctx = WithFileTransferOptions(ctx, &SnowflakeFileTransferOptions{
    87  		compressSourceFromStream: compressData})
    88  	return bu.sc.exec(ctx, putCommand, false, true, false, []driver.NamedValue{})
    89  }
    90  
    91  func (bu *bindUploader) createStageIfNeeded() error {
    92  	if bu.arrayBindStage != "" {
    93  		return nil
    94  	}
    95  	data, err := bu.sc.exec(bu.ctx, createTemporaryStageStmt, false, false, false, []driver.NamedValue{})
    96  	if err != nil {
    97  		newThreshold := "0"
    98  		bu.sc.cfg.Params[sessionArrayBindStageThreshold] = &newThreshold
    99  		return err
   100  	}
   101  	if !data.Success {
   102  		code, err := strconv.Atoi(data.Code)
   103  		if err != nil {
   104  			return err
   105  		}
   106  		return (&SnowflakeError{
   107  			Number:   code,
   108  			SQLState: data.Data.SQLState,
   109  			Message:  err.Error(),
   110  			QueryID:  data.Data.QueryID,
   111  		}).exceptionTelemetry(bu.sc)
   112  	}
   113  	bu.arrayBindStage = bindStageName
   114  	return nil
   115  }
   116  
   117  // transpose the columns to rows and write them to a list of bytes
   118  func (bu *bindUploader) buildRowsAsBytes(columns []driver.NamedValue) ([][]byte, error) {
   119  	numColumns := len(columns)
   120  	if columns[0].Value == nil {
   121  		return nil, (&SnowflakeError{
   122  			Number:  ErrBindSerialization,
   123  			Message: "no binds found in the first column",
   124  		}).exceptionTelemetry(bu.sc)
   125  	}
   126  
   127  	_, column := snowflakeArrayToString(&columns[0], true)
   128  	numRows := len(column)
   129  	csvRows := make([][]byte, 0)
   130  	rows := make([][]interface{}, 0)
   131  	for rowIdx := 0; rowIdx < numRows; rowIdx++ {
   132  		rows = append(rows, make([]interface{}, numColumns))
   133  	}
   134  
   135  	for rowIdx := 0; rowIdx < numRows; rowIdx++ {
   136  		if column[rowIdx] == nil {
   137  			rows[rowIdx][0] = column[rowIdx]
   138  		} else {
   139  			rows[rowIdx][0] = *column[rowIdx]
   140  		}
   141  	}
   142  	for colIdx := 1; colIdx < numColumns; colIdx++ {
   143  		_, column = snowflakeArrayToString(&columns[colIdx], true)
   144  		iNumRows := len(column)
   145  		if iNumRows != numRows {
   146  			return nil, (&SnowflakeError{
   147  				Number:      ErrBindSerialization,
   148  				Message:     errMsgBindColumnMismatch,
   149  				MessageArgs: []interface{}{colIdx, iNumRows, numRows},
   150  			}).exceptionTelemetry(bu.sc)
   151  		}
   152  		for rowIdx := 0; rowIdx < numRows; rowIdx++ {
   153  			// length of column = number of rows
   154  			if column[rowIdx] == nil {
   155  				rows[rowIdx][colIdx] = column[rowIdx]
   156  			} else {
   157  				rows[rowIdx][colIdx] = *column[rowIdx]
   158  			}
   159  		}
   160  	}
   161  	for _, row := range rows {
   162  		csvRows = append(csvRows, bu.createCSVRecord(row))
   163  	}
   164  	return csvRows, nil
   165  }
   166  
   167  func (bu *bindUploader) createCSVRecord(data []interface{}) []byte {
   168  	var b strings.Builder
   169  	b.Grow(1024)
   170  	for i := 0; i < len(data); i++ {
   171  		if i > 0 {
   172  			b.WriteString(",")
   173  		}
   174  		value, ok := data[i].(string)
   175  		if ok {
   176  			b.WriteString(escapeForCSV(value))
   177  		} else if !reflect.ValueOf(data[i]).IsNil() {
   178  			logger.Debugf("Cannot convert value to string in createCSVRecord. value: %v", data[i])
   179  		}
   180  	}
   181  	b.WriteString("\n")
   182  	return []byte(b.String())
   183  }
   184  
   185  func (sc *snowflakeConn) processBindings(
   186  	ctx context.Context,
   187  	bindings []driver.NamedValue,
   188  	describeOnly bool,
   189  	requestID UUID,
   190  	req *execRequest) error {
   191  	arrayBindThreshold := sc.getArrayBindStageThreshold()
   192  	numBinds := arrayBindValueCount(bindings)
   193  	if 0 < arrayBindThreshold && arrayBindThreshold <= numBinds && !describeOnly && isArrayBind(bindings) {
   194  		uploader := bindUploader{
   195  			sc:        sc,
   196  			ctx:       ctx,
   197  			stagePath: "@" + bindStageName + "/" + requestID.String(),
   198  		}
   199  		_, err := uploader.upload(bindings)
   200  		if err != nil {
   201  			return err
   202  		}
   203  		req.Bindings = nil
   204  		req.BindStage = uploader.stagePath
   205  	} else {
   206  		var err error
   207  		req.Bindings, err = getBindValues(bindings)
   208  		if err != nil {
   209  			return err
   210  		}
   211  		req.BindStage = ""
   212  	}
   213  	return nil
   214  }
   215  
   216  func getBindValues(bindings []driver.NamedValue) (map[string]execBindParameter, error) {
   217  	tsmode := timestampNtzType
   218  	idx := 1
   219  	var err error
   220  	bindValues := make(map[string]execBindParameter, len(bindings))
   221  	for _, binding := range bindings {
   222  		if tnt, ok := binding.Value.(TypedNullTime); ok {
   223  			tsmode = convertTzTypeToSnowflakeType(tnt.TzType)
   224  			binding.Value = tnt.Time
   225  		}
   226  		t := goTypeToSnowflake(binding.Value, tsmode)
   227  		if t == changeType {
   228  			tsmode, err = dataTypeMode(binding.Value)
   229  			if err != nil {
   230  				return nil, err
   231  			}
   232  		} else {
   233  			var val interface{}
   234  			if t == sliceType {
   235  				// retrieve array binding data
   236  				t, val = snowflakeArrayToString(&binding, false)
   237  			} else {
   238  				val, err = valueToString(binding.Value, tsmode)
   239  				if err != nil {
   240  					return nil, err
   241  				}
   242  			}
   243  			if t == nullType || t == unSupportedType {
   244  				t = textType // if null or not supported, pass to GS as text
   245  			}
   246  			bindValues[bindingName(binding, idx)] = execBindParameter{
   247  				Type:  t.String(),
   248  				Value: val,
   249  			}
   250  			idx++
   251  		}
   252  	}
   253  	return bindValues, nil
   254  }
   255  
   256  func bindingName(nv driver.NamedValue, idx int) string {
   257  	if nv.Name != "" {
   258  		return nv.Name
   259  	}
   260  	return strconv.Itoa(idx)
   261  }
   262  
   263  func arrayBindValueCount(bindValues []driver.NamedValue) int {
   264  	if !isArrayBind(bindValues) {
   265  		return 0
   266  	}
   267  	_, arr := snowflakeArrayToString(&bindValues[0], false)
   268  	return len(bindValues) * len(arr)
   269  }
   270  
   271  func isArrayBind(bindings []driver.NamedValue) bool {
   272  	if len(bindings) == 0 {
   273  		return false
   274  	}
   275  	for _, binding := range bindings {
   276  		if supported := supportedArrayBind(&binding); !supported {
   277  			return false
   278  		}
   279  	}
   280  	return true
   281  }
   282  
   283  func supportedArrayBind(nv *driver.NamedValue) bool {
   284  	switch reflect.TypeOf(nv.Value) {
   285  	case reflect.TypeOf(&intArray{}), reflect.TypeOf(&int32Array{}),
   286  		reflect.TypeOf(&int64Array{}), reflect.TypeOf(&float64Array{}),
   287  		reflect.TypeOf(&float32Array{}), reflect.TypeOf(&boolArray{}),
   288  		reflect.TypeOf(&stringArray{}), reflect.TypeOf(&byteArray{}),
   289  		reflect.TypeOf(&timestampNtzArray{}), reflect.TypeOf(&timestampLtzArray{}),
   290  		reflect.TypeOf(&timestampTzArray{}), reflect.TypeOf(&dateArray{}),
   291  		reflect.TypeOf(&timeArray{}):
   292  		return true
   293  	case reflect.TypeOf([]uint8{}):
   294  		// internal binding ts mode
   295  		val, ok := nv.Value.([]uint8)
   296  		if !ok {
   297  			return ok
   298  		}
   299  		if len(val) == 0 {
   300  			return true // for null binds
   301  		}
   302  		if fixedType <= snowflakeType(val[0]) && snowflakeType(val[0]) <= unSupportedType {
   303  			return true
   304  		}
   305  		return false
   306  	default:
   307  		// TODO SNOW-176486 variant, object, array
   308  
   309  		// Support for bulk array binding insertion using []interface{}
   310  		if isInterfaceArrayBinding(nv.Value) {
   311  			return true
   312  		}
   313  		return false
   314  	}
   315  }
   316  
   317  func supportedNullBind(nv *driver.NamedValue) bool {
   318  	switch reflect.TypeOf(nv.Value) {
   319  	case reflect.TypeOf(sql.NullString{}), reflect.TypeOf(sql.NullInt64{}),
   320  		reflect.TypeOf(sql.NullBool{}), reflect.TypeOf(sql.NullFloat64{}), reflect.TypeOf(TypedNullTime{}):
   321  		return true
   322  	}
   323  	return false
   324  }