github.com/code-to-go/safepool.lib@v0.0.0-20221205180519-ee25e63c226e/sql/common.go (about) 1 package sql 2 3 import ( 4 "database/sql" 5 "encoding/base64" 6 "fmt" 7 "reflect" 8 "strings" 9 "time" 10 11 "github.com/sirupsen/logrus" 12 ) 13 14 var queriesCache = map[string]string{} 15 var stmtCache = map[string]*sql.Stmt{} 16 var ErrNoRows = sql.ErrNoRows 17 18 func prepareStatement(key, s string, line int) { 19 if _, ok := stmtCache[key]; ok { 20 logrus.Panicf("duplicate SQL statement for key '%s' (line %d)", s, line) 21 panic(key) 22 } 23 24 stmt, err := db.Prepare(s) 25 if err != nil { 26 logrus.Panicf("cannot compile SQL statement (%d) '%s': %v", line, s, err) 27 panic(err) 28 } 29 stmtCache[key] = stmt 30 queriesCache[key] = s 31 } 32 33 func getStatement(key string) *sql.Stmt { 34 if v, ok := stmtCache[key]; ok { 35 return v 36 } else { 37 logrus.Panicf("missing SQL statement for key '%s'", key) 38 panic(key) 39 } 40 } 41 42 type Args map[string]any 43 44 func named(m Args) []any { 45 var args []any 46 for k, v := range m { 47 args = append(args, sql.Named(k, v)) 48 } 49 return args 50 } 51 52 func trace(key string, m Args, err error) { 53 if logrus.IsLevelEnabled(logrus.InfoLevel) { 54 q := queriesCache[key] 55 for k, v := range m { 56 q = strings.ReplaceAll(q, ":"+k, fmt.Sprintf("%v", v)) 57 } 58 logrus.Infof("SQL: %s: %v", q, err) 59 } 60 } 61 62 func Exec(key string, m Args) (sql.Result, error) { 63 res, err := getStatement(key).Exec(named(m)...) 64 trace(key, m, err) 65 return res, err 66 } 67 68 func QueryRow(key string, m Args, dest ...any) error { 69 row := getStatement(key).QueryRow(named(m)...) 70 err := row.Err() 71 trace(key, m, err) 72 if err == nil { 73 return row.Scan(dest...) 74 } 75 return err 76 } 77 78 func Query(key string, m Args) (*sql.Rows, error) { 79 rows, err := getStatement(key).Query(named(m)...) 80 trace(key, m, err) 81 return rows, err 82 } 83 84 func QueryEx[T any](key string, m Args, f func(dest ...any) T) ([]T, error) { 85 rows, err := getStatement(key).Query(named(m)...) 86 if err != nil { 87 return nil, err 88 } 89 90 var res []T 91 var dest []any 92 t := reflect.TypeOf(f) 93 for i := 0; i < t.NumIn(); i++ { 94 dest = append(dest, reflect.New(t.In(i))) 95 } 96 97 for rows.Next() { 98 rows.Scan(dest...) 99 if err == nil { 100 res = append(res, f(dest...)) 101 } 102 } 103 return res, nil 104 } 105 106 func EncodeBase64(data []byte) string { 107 return base64.StdEncoding.EncodeToString(data) 108 } 109 110 func DecodeBase64(data string) []byte { 111 b, err := base64.StdEncoding.DecodeString(data) 112 if err != nil { 113 return nil 114 } 115 return b 116 } 117 118 func EncodeTime(t time.Time) int64 { 119 return t.Unix() 120 } 121 122 func DecodeTime(v int64) time.Time { 123 return time.Unix(v, 0) 124 }