github.com/snowflakedb/gosnowflake@v1.9.0/statement.go (about)

     1  // Copyright (c) 2017-2023 Snowflake Computing Inc. All rights reserved.
     2  
     3  package gosnowflake
     4  
     5  import (
     6  	"context"
     7  	"database/sql/driver"
     8  	"errors"
     9  	"fmt"
    10  )
    11  
    12  // SnowflakeStmt represents the prepared statement in driver.
    13  type SnowflakeStmt interface {
    14  	GetQueryID() string
    15  }
    16  
    17  type snowflakeStmt struct {
    18  	sc          *snowflakeConn
    19  	query       string
    20  	lastQueryID string
    21  }
    22  
    23  func (stmt *snowflakeStmt) Close() error {
    24  	logger.WithContext(stmt.sc.ctx).Infoln("Stmt.Close")
    25  	// noop
    26  	return nil
    27  }
    28  
    29  func (stmt *snowflakeStmt) NumInput() int {
    30  	logger.WithContext(stmt.sc.ctx).Infoln("Stmt.NumInput")
    31  	// Go Snowflake doesn't know the number of binding parameters.
    32  	return -1
    33  }
    34  
    35  func (stmt *snowflakeStmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) {
    36  	logger.WithContext(stmt.sc.ctx).Infoln("Stmt.ExecContext")
    37  	return stmt.execInternal(ctx, args)
    38  }
    39  
    40  func (stmt *snowflakeStmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) {
    41  	logger.WithContext(stmt.sc.ctx).Infoln("Stmt.QueryContext")
    42  	rows, err := stmt.sc.QueryContext(ctx, stmt.query, args)
    43  	if err != nil {
    44  		stmt.setQueryIDFromError(err)
    45  		return nil, err
    46  	}
    47  	r, ok := rows.(SnowflakeRows)
    48  	if !ok {
    49  		return nil, fmt.Errorf("interface convertion. expected type SnowflakeRows but got %T", rows)
    50  	}
    51  	stmt.lastQueryID = r.GetQueryID()
    52  	return rows, nil
    53  }
    54  
    55  func (stmt *snowflakeStmt) Exec(args []driver.Value) (driver.Result, error) {
    56  	logger.WithContext(stmt.sc.ctx).Infoln("Stmt.Exec")
    57  	return stmt.execInternal(context.Background(), toNamedValues(args))
    58  }
    59  
    60  func (stmt *snowflakeStmt) execInternal(ctx context.Context, args []driver.NamedValue) (driver.Result, error) {
    61  	logger.WithContext(stmt.sc.ctx).Debugln("Stmt.execInternal")
    62  	if ctx == nil {
    63  		ctx = context.Background()
    64  	}
    65  	stmtCtx := context.WithValue(ctx, executionType, executionTypeStatement)
    66  	result, err := stmt.sc.ExecContext(stmtCtx, stmt.query, args)
    67  	if err != nil {
    68  		stmt.setQueryIDFromError(err)
    69  		return nil, err
    70  	}
    71  	rnr, ok := result.(*snowflakeResultNoRows)
    72  	if ok {
    73  		stmt.lastQueryID = rnr.GetQueryID()
    74  		return driver.ResultNoRows, nil
    75  	}
    76  	r, ok := result.(SnowflakeResult)
    77  	if !ok {
    78  		return nil, fmt.Errorf("interface convertion. expected type SnowflakeResult but got %T", result)
    79  	}
    80  	stmt.lastQueryID = r.GetQueryID()
    81  	return result, err
    82  }
    83  
    84  func (stmt *snowflakeStmt) Query(args []driver.Value) (driver.Rows, error) {
    85  	logger.WithContext(stmt.sc.ctx).Infoln("Stmt.Query")
    86  	rows, err := stmt.sc.Query(stmt.query, args)
    87  	if err != nil {
    88  		stmt.setQueryIDFromError(err)
    89  		return nil, err
    90  	}
    91  	r, ok := rows.(SnowflakeRows)
    92  	if !ok {
    93  		return nil, fmt.Errorf("interface convertion. expected type SnowflakeRows but got %T", rows)
    94  	}
    95  	stmt.lastQueryID = r.GetQueryID()
    96  	return rows, err
    97  }
    98  
    99  func (stmt *snowflakeStmt) GetQueryID() string {
   100  	return stmt.lastQueryID
   101  }
   102  
   103  func (stmt *snowflakeStmt) setQueryIDFromError(err error) {
   104  	var snowflakeError *SnowflakeError
   105  	if errors.As(err, &snowflakeError) {
   106  		stmt.lastQueryID = snowflakeError.QueryID
   107  	}
   108  }