github.com/code-to-go/safepool.lib@v0.0.0-20221205180519-ee25e63c226e/sql/common.go (about)

     1  package sql
     2  
import (
	"database/sql"
	"encoding/base64"
	"fmt"
	"reflect"
	"sort"
	"strings"
	"time"

	"github.com/sirupsen/logrus"
)
    13  
    14  var queriesCache = map[string]string{}
    15  var stmtCache = map[string]*sql.Stmt{}
    16  var ErrNoRows = sql.ErrNoRows
    17  
    18  func prepareStatement(key, s string, line int) {
    19  	if _, ok := stmtCache[key]; ok {
    20  		logrus.Panicf("duplicate SQL statement for key '%s' (line %d)", s, line)
    21  		panic(key)
    22  	}
    23  
    24  	stmt, err := db.Prepare(s)
    25  	if err != nil {
    26  		logrus.Panicf("cannot compile SQL statement (%d) '%s': %v", line, s, err)
    27  		panic(err)
    28  	}
    29  	stmtCache[key] = stmt
    30  	queriesCache[key] = s
    31  }
    32  
    33  func getStatement(key string) *sql.Stmt {
    34  	if v, ok := stmtCache[key]; ok {
    35  		return v
    36  	} else {
    37  		logrus.Panicf("missing SQL statement for key '%s'", key)
    38  		panic(key)
    39  	}
    40  }
    41  
    42  type Args map[string]any
    43  
    44  func named(m Args) []any {
    45  	var args []any
    46  	for k, v := range m {
    47  		args = append(args, sql.Named(k, v))
    48  	}
    49  	return args
    50  }
    51  
    52  func trace(key string, m Args, err error) {
    53  	if logrus.IsLevelEnabled(logrus.InfoLevel) {
    54  		q := queriesCache[key]
    55  		for k, v := range m {
    56  			q = strings.ReplaceAll(q, ":"+k, fmt.Sprintf("%v", v))
    57  		}
    58  		logrus.Infof("SQL: %s: %v", q, err)
    59  	}
    60  }
    61  
    62  func Exec(key string, m Args) (sql.Result, error) {
    63  	res, err := getStatement(key).Exec(named(m)...)
    64  	trace(key, m, err)
    65  	return res, err
    66  }
    67  
    68  func QueryRow(key string, m Args, dest ...any) error {
    69  	row := getStatement(key).QueryRow(named(m)...)
    70  	err := row.Err()
    71  	trace(key, m, err)
    72  	if err == nil {
    73  		return row.Scan(dest...)
    74  	}
    75  	return err
    76  }
    77  
    78  func Query(key string, m Args) (*sql.Rows, error) {
    79  	rows, err := getStatement(key).Query(named(m)...)
    80  	trace(key, m, err)
    81  	return rows, err
    82  }
    83  
    84  func QueryEx[T any](key string, m Args, f func(dest ...any) T) ([]T, error) {
    85  	rows, err := getStatement(key).Query(named(m)...)
    86  	if err != nil {
    87  		return nil, err
    88  	}
    89  
    90  	var res []T
    91  	var dest []any
    92  	t := reflect.TypeOf(f)
    93  	for i := 0; i < t.NumIn(); i++ {
    94  		dest = append(dest, reflect.New(t.In(i)))
    95  	}
    96  
    97  	for rows.Next() {
    98  		rows.Scan(dest...)
    99  		if err == nil {
   100  			res = append(res, f(dest...))
   101  		}
   102  	}
   103  	return res, nil
   104  }
   105  
   106  func EncodeBase64(data []byte) string {
   107  	return base64.StdEncoding.EncodeToString(data)
   108  }
   109  
   110  func DecodeBase64(data string) []byte {
   111  	b, err := base64.StdEncoding.DecodeString(data)
   112  	if err != nil {
   113  		return nil
   114  	}
   115  	return b
   116  }
   117  
   118  func EncodeTime(t time.Time) int64 {
   119  	return t.Unix()
   120  }
   121  
   122  func DecodeTime(v int64) time.Time {
   123  	return time.Unix(v, 0)
   124  }