github.com/snowflakedb/gosnowflake@v1.9.0/statement.go (about) 1 // Copyright (c) 2017-2023 Snowflake Computing Inc. All rights reserved. 2 3 package gosnowflake 4 5 import ( 6 "context" 7 "database/sql/driver" 8 "errors" 9 "fmt" 10 ) 11 12 // SnowflakeStmt represents the prepared statement in driver. 13 type SnowflakeStmt interface { 14 GetQueryID() string 15 } 16 17 type snowflakeStmt struct { 18 sc *snowflakeConn 19 query string 20 lastQueryID string 21 } 22 23 func (stmt *snowflakeStmt) Close() error { 24 logger.WithContext(stmt.sc.ctx).Infoln("Stmt.Close") 25 // noop 26 return nil 27 } 28 29 func (stmt *snowflakeStmt) NumInput() int { 30 logger.WithContext(stmt.sc.ctx).Infoln("Stmt.NumInput") 31 // Go Snowflake doesn't know the number of binding parameters. 32 return -1 33 } 34 35 func (stmt *snowflakeStmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) { 36 logger.WithContext(stmt.sc.ctx).Infoln("Stmt.ExecContext") 37 return stmt.execInternal(ctx, args) 38 } 39 40 func (stmt *snowflakeStmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) { 41 logger.WithContext(stmt.sc.ctx).Infoln("Stmt.QueryContext") 42 rows, err := stmt.sc.QueryContext(ctx, stmt.query, args) 43 if err != nil { 44 stmt.setQueryIDFromError(err) 45 return nil, err 46 } 47 r, ok := rows.(SnowflakeRows) 48 if !ok { 49 return nil, fmt.Errorf("interface convertion. expected type SnowflakeRows but got %T", rows) 50 } 51 stmt.lastQueryID = r.GetQueryID() 52 return rows, nil 53 } 54 55 func (stmt *snowflakeStmt) Exec(args []driver.Value) (driver.Result, error) { 56 logger.WithContext(stmt.sc.ctx).Infoln("Stmt.Exec") 57 return stmt.execInternal(context.Background(), toNamedValues(args)) 58 } 59 60 func (stmt *snowflakeStmt) execInternal(ctx context.Context, args []driver.NamedValue) (driver.Result, error) { 61 logger.WithContext(stmt.sc.ctx).Debugln("Stmt.execInternal") 62 if ctx == nil { 63 ctx = context.Background() 64 } 65 stmtCtx := context.WithValue(ctx, executionType, executionTypeStatement) 66 result, err := stmt.sc.ExecContext(stmtCtx, stmt.query, args) 67 if err != nil { 68 stmt.setQueryIDFromError(err) 69 return nil, err 70 } 71 rnr, ok := result.(*snowflakeResultNoRows) 72 if ok { 73 stmt.lastQueryID = rnr.GetQueryID() 74 return driver.ResultNoRows, nil 75 } 76 r, ok := result.(SnowflakeResult) 77 if !ok { 78 return nil, fmt.Errorf("interface convertion. expected type SnowflakeResult but got %T", result) 79 } 80 stmt.lastQueryID = r.GetQueryID() 81 return result, err 82 } 83 84 func (stmt *snowflakeStmt) Query(args []driver.Value) (driver.Rows, error) { 85 logger.WithContext(stmt.sc.ctx).Infoln("Stmt.Query") 86 rows, err := stmt.sc.Query(stmt.query, args) 87 if err != nil { 88 stmt.setQueryIDFromError(err) 89 return nil, err 90 } 91 r, ok := rows.(SnowflakeRows) 92 if !ok { 93 return nil, fmt.Errorf("interface convertion. expected type SnowflakeRows but got %T", rows) 94 } 95 stmt.lastQueryID = r.GetQueryID() 96 return rows, err 97 } 98 99 func (stmt *snowflakeStmt) GetQueryID() string { 100 return stmt.lastQueryID 101 } 102 103 func (stmt *snowflakeStmt) setQueryIDFromError(err error) { 104 var snowflakeError *SnowflakeError 105 if errors.As(err, &snowflakeError) { 106 stmt.lastQueryID = snowflakeError.QueryID 107 } 108 }