github.com/hashicorp/vault/sdk@v0.11.0/helper/dbtxn/dbtxn.go (about) 1 // Copyright (c) HashiCorp, Inc. 2 // SPDX-License-Identifier: MPL-2.0 3 4 package dbtxn 5 6 import ( 7 "context" 8 "database/sql" 9 "fmt" 10 "strings" 11 ) 12 13 // ExecuteDBQuery handles executing one single statement while properly releasing its resources. 14 // - ctx: Required 15 // - db: Required 16 // - config: Optional, may be nil 17 // - query: Required 18 func ExecuteDBQuery(ctx context.Context, db *sql.DB, params map[string]string, query string) error { 19 parsedQuery := parseQuery(params, query) 20 21 stmt, err := db.PrepareContext(ctx, parsedQuery) 22 if err != nil { 23 return err 24 } 25 defer stmt.Close() 26 27 return execute(ctx, stmt) 28 } 29 30 // ExecuteDBQueryDirect handles executing one single statement without preparing the query 31 // before executing it, which can be more efficient. 32 // - ctx: Required 33 // - db: Required 34 // - config: Optional, may be nil 35 // - query: Required 36 func ExecuteDBQueryDirect(ctx context.Context, db *sql.DB, params map[string]string, query string) error { 37 parsedQuery := parseQuery(params, query) 38 _, err := db.ExecContext(ctx, parsedQuery) 39 return err 40 } 41 42 // ExecuteTxQuery handles executing one single statement while properly releasing its resources. 43 // - ctx: Required 44 // - tx: Required 45 // - config: Optional, may be nil 46 // - query: Required 47 func ExecuteTxQuery(ctx context.Context, tx *sql.Tx, params map[string]string, query string) error { 48 parsedQuery := parseQuery(params, query) 49 50 stmt, err := tx.PrepareContext(ctx, parsedQuery) 51 if err != nil { 52 return err 53 } 54 defer stmt.Close() 55 56 return execute(ctx, stmt) 57 } 58 59 // ExecuteTxQueryDirect handles executing one single statement. 
//   - ctx: Required
//   - tx: Required
//   - params: Optional, may be nil
//   - query: Required
//
// Each "{{key}}" placeholder in query is replaced with params[key] before the
// statement is passed to tx.ExecContext. The transaction is neither committed
// nor rolled back here.
func ExecuteTxQueryDirect(ctx context.Context, tx *sql.Tx, params map[string]string, query string) error {
	parsedQuery := parseQuery(params, query)
	_, err := tx.ExecContext(ctx, parsedQuery)
	return err
}

// execute runs an already prepared statement with no bind arguments and
// returns the execution error, discarding the sql.Result.
func execute(ctx context.Context, stmt *sql.Stmt) error {
	_, err := stmt.ExecContext(ctx)
	return err
}

// parseQuery returns tpl with every "{{key}}" placeholder replaced by m[key].
// A nil or empty map returns tpl unchanged, and placeholders without a
// matching key are left in place.
//
// All substitutions happen in a single pass over tpl, so text inserted for one
// key is never re-scanned for other placeholders. Replacing key by key while
// ranging over the map would make the result depend on Go's randomized map
// iteration order whenever a value itself contains a placeholder.
//
// Values are inserted verbatim with no quoting or escaping; callers are
// responsible for making sure they are safe to embed in a SQL statement.
func parseQuery(m map[string]string, tpl string) string {
	// len of a nil map is 0, so this covers both nil and empty.
	if len(m) == 0 {
		return tpl
	}

	// strings.NewReplacer takes alternating old/new arguments.
	// NOTE(review): argument order only matters if one placeholder is a prefix
	// of another, which requires a key containing "}}" — not expected here.
	pairs := make([]string, 0, 2*len(m))
	for k, v := range m {
		pairs = append(pairs, fmt.Sprintf("{{%s}}", k), v)
	}
	return strings.NewReplacer(pairs...).Replace(tpl)
}