github.com/snowflakedb/gosnowflake@v1.9.0/multistatement.go (about)

     1  // Copyright (c) 2021-2022 Snowflake Computing Inc. All rights reserved.
     2  
     3  package gosnowflake
     4  
     5  import (
     6  	"context"
     7  	"database/sql/driver"
     8  	"fmt"
     9  	"strconv"
    10  	"strings"
    11  )
    12  
    13  type childResult struct {
    14  	id  string
    15  	typ string
    16  }
    17  
    18  func getChildResults(IDs string, types string) []childResult {
    19  	if IDs == "" {
    20  		return nil
    21  	}
    22  	queryIDs := strings.Split(IDs, ",")
    23  	resultTypes := strings.Split(types, ",")
    24  	res := make([]childResult, len(queryIDs))
    25  	for i, id := range queryIDs {
    26  		res[i] = childResult{id, resultTypes[i]}
    27  	}
    28  	return res
    29  }
    30  
    31  func (sc *snowflakeConn) handleMultiExec(
    32  	ctx context.Context,
    33  	data execResponseData) (
    34  	driver.Result, error) {
    35  	if data.ResultIDs == "" {
    36  		return nil, (&SnowflakeError{
    37  			Number:   ErrNoResultIDs,
    38  			SQLState: data.SQLState,
    39  			Message:  errMsgNoResultIDs,
    40  			QueryID:  data.QueryID,
    41  		}).exceptionTelemetry(sc)
    42  	}
    43  	var updatedRows int64
    44  	childResults := getChildResults(data.ResultIDs, data.ResultTypes)
    45  	for _, child := range childResults {
    46  		resultPath := fmt.Sprintf(urlQueriesResultFmt, child.id)
    47  		childResultType, err := strconv.ParseInt(child.typ, 10, 64)
    48  		if err != nil {
    49  			return nil, err
    50  		}
    51  		if isDml(childResultType) {
    52  			childData, err := sc.getQueryResultResp(ctx, resultPath)
    53  			if err != nil {
    54  				logger.Errorf("error: %v", err)
    55  				return nil, err
    56  			}
    57  			if childData != nil && !childData.Success {
    58  				code, err := strconv.Atoi(childData.Code)
    59  				if err != nil {
    60  					return nil, err
    61  				}
    62  				return nil, (&SnowflakeError{
    63  					Number:   code,
    64  					SQLState: childData.Data.SQLState,
    65  					Message:  childData.Message,
    66  					QueryID:  childData.Data.QueryID,
    67  				}).exceptionTelemetry(sc)
    68  			}
    69  			count, err := updateRows(childData.Data)
    70  			if err != nil {
    71  				logger.WithContext(ctx).Errorf("error: %v", err)
    72  				return nil, err
    73  			}
    74  			updatedRows += count
    75  		}
    76  	}
    77  	logger.WithContext(ctx).Infof("number of updated rows: %#v", updatedRows)
    78  	return &snowflakeResult{
    79  		affectedRows: updatedRows,
    80  		insertID:     -1,
    81  		queryID:      data.QueryID,
    82  	}, nil
    83  }
    84  
    85  // Fill the corresponding rows and add chunk downloader into the rows when
    86  // iterating across the childResults
    87  func (sc *snowflakeConn) handleMultiQuery(
    88  	ctx context.Context,
    89  	data execResponseData,
    90  	rows *snowflakeRows) error {
    91  	if data.ResultIDs == "" {
    92  		return (&SnowflakeError{
    93  			Number:   ErrNoResultIDs,
    94  			SQLState: data.SQLState,
    95  			Message:  errMsgNoResultIDs,
    96  			QueryID:  data.QueryID,
    97  		}).exceptionTelemetry(sc)
    98  	}
    99  	childResults := getChildResults(data.ResultIDs, data.ResultTypes)
   100  	for _, child := range childResults {
   101  		if err := sc.rowsForRunningQuery(ctx, child.id, rows); err != nil {
   102  			return err
   103  		}
   104  	}
   105  	return nil
   106  }