github.com/octohelm/storage@v0.0.0-20240516030302-1ac2cc1ea347/internal/sql/loggingdriver/driver.go (about)

     1  package loggingdriver
     2  
     3  import (
     4  	"bytes"
     5  	"context"
     6  	"fmt"
     7  	"strconv"
     8  	"time"
     9  
    10  	"database/sql/driver"
    11  
    12  	"github.com/go-courier/logr"
    13  	"github.com/pkg/errors"
    14  )
    15  
    16  type ErrorLevel func(error error) int
    17  
    18  func Wrap(d driver.Driver, name string, errorLevel func(error error) int) driver.DriverContext {
    19  	return &loggerConnector{
    20  		driver: d,
    21  		opt: &opt{
    22  			name:       name,
    23  			errorLevel: errorLevel,
    24  		},
    25  	}
    26  }
    27  
    28  type opt struct {
    29  	name       string
    30  	errorLevel ErrorLevel
    31  }
    32  
    33  func (o opt) ErrorLevel(err error) int {
    34  	if o.errorLevel != nil {
    35  		return o.errorLevel(err)
    36  	}
    37  	return 1
    38  }
    39  
    40  type loggerConnector struct {
    41  	driver driver.Driver
    42  	opt    *opt
    43  	dsn    string
    44  }
    45  
    46  func (c *loggerConnector) OpenConnector(dsn string) (driver.Connector, error) {
    47  	return &loggerConnector{
    48  		driver: c.driver,
    49  		opt:    c.opt,
    50  		dsn:    dsn,
    51  	}, nil
    52  }
    53  
    54  func (c *loggerConnector) Connect(ctx context.Context) (driver.Conn, error) {
    55  	return c.Open(c.dsn)
    56  }
    57  
    58  func (c *loggerConnector) Driver() driver.Driver {
    59  	return c
    60  }
    61  
    62  func (c *loggerConnector) Open(dsn string) (driver.Conn, error) {
    63  	conn, err := c.driver.Open(dsn)
    64  	if err != nil {
    65  		return nil, errors.Wrapf(err, "failed to open connection")
    66  	}
    67  	return &loggerConn{Conn: conn, opt: c.opt}, nil
    68  }
    69  
    70  var _ interface {
    71  	driver.ConnBeginTx
    72  	driver.ExecerContext
    73  	driver.QueryerContext
    74  } = (*loggerConn)(nil)
    75  
    76  type loggerConn struct {
    77  	driver.Conn
    78  	opt *opt
    79  }
    80  
    81  func (c *loggerConn) Close() error {
    82  	if err := c.Conn.Close(); err != nil {
    83  		return err
    84  	}
    85  	return nil
    86  }
    87  
    88  func (c *loggerConn) Prepare(query string) (driver.Stmt, error) {
    89  	panic(fmt.Errorf("don't use Prepare"))
    90  }
    91  
    92  func (c *loggerConn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (rows driver.Rows, err error) {
    93  	_, logger := logr.Start(ctx, "SQLQuery")
    94  	cost := startTimer()
    95  
    96  	defer func() {
    97  		q := interpolateParams(query, args)
    98  
    99  		l := logger.WithValues("driver", c.opt.name, "sql", q)
   100  		if err != nil {
   101  			if c.opt.ErrorLevel(err) > 0 {
   102  				l.Error(errors.Wrapf(err, "query failed"))
   103  			} else {
   104  				l.Warn(errors.Wrapf(err, "query failed"))
   105  			}
   106  		} else {
   107  			l.WithValues("cost", cost().String()).Debug("")
   108  		}
   109  
   110  		logger.End()
   111  	}()
   112  
   113  	rows, err = c.Conn.(driver.QueryerContext).QueryContext(context.Background(), replaceValueHolder(query), args)
   114  	return
   115  }
   116  
   117  func (c *loggerConn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (result driver.Result, err error) {
   118  	cost := startTimer()
   119  	_, logger := logr.Start(ctx, "SQLExec")
   120  
   121  	defer func() {
   122  		q := interpolateParams(query, args)
   123  		l := logger.WithValues("driver", c.opt.name, "sql", q)
   124  		if err != nil {
   125  			if c.opt.ErrorLevel(err) > 0 {
   126  				l.Error(errors.Wrap(err, "exec failed"))
   127  			} else {
   128  				l.Warn(errors.Wrapf(err, "exec failed"))
   129  			}
   130  		} else {
   131  			l.WithValues("cost", cost().String()).Debug("")
   132  		}
   133  
   134  		logger.End()
   135  	}()
   136  
   137  	result, err = c.Conn.(driver.ExecerContext).ExecContext(context.Background(), replaceValueHolder(query), args)
   138  	return
   139  }
   140  
   141  func replaceValueHolder(query string) string {
   142  	index := 0
   143  	data := []byte(query)
   144  
   145  	e := bytes.NewBufferString("")
   146  
   147  	for i := range data {
   148  		c := data[i]
   149  		switch c {
   150  		case '?':
   151  			e.WriteByte('$')
   152  			e.WriteString(strconv.FormatInt(int64(index+1), 10))
   153  			index++
   154  		default:
   155  			e.WriteByte(c)
   156  		}
   157  	}
   158  
   159  	return e.String()
   160  }
   161  
   162  func startTimer() func() time.Duration {
   163  	startTime := time.Now()
   164  	return func() time.Duration {
   165  		return time.Since(startTime)
   166  	}
   167  }
   168  
   169  func (c *loggerConn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) {
   170  	logger := logr.FromContext(ctx)
   171  
   172  	logger.Debug("=========== Beginning Transaction ===========")
   173  
   174  	// don't pass ctx into real driver to avoid connect discount
   175  	tx, err := c.Conn.(driver.ConnBeginTx).BeginTx(ctx, opts)
   176  	if err != nil {
   177  		logger.Error(errors.Wrap(err, "failed to begin transaction"))
   178  		return nil, err
   179  	}
   180  
   181  	return &loggingTx{tx: tx, logger: logger}, nil
   182  }
   183  
   184  type loggingTx struct {
   185  	logger logr.Logger
   186  	tx     driver.Tx
   187  }
   188  
   189  func (tx *loggingTx) Commit() error {
   190  	if err := tx.tx.Commit(); err != nil {
   191  		tx.logger.Debug("failed to commit transaction: %s", err)
   192  		return err
   193  	}
   194  	tx.logger.Debug("=========== Committed Transaction ===========")
   195  	return nil
   196  }
   197  
   198  func (tx *loggingTx) Rollback() error {
   199  	if err := tx.tx.Rollback(); err != nil {
   200  		tx.logger.Debug("failed to rollback transaction: %s", err)
   201  		return err
   202  	}
   203  	tx.logger.Debug("=========== Rollback Transaction ===========")
   204  	return nil
   205  }