gitee.com/go-genie/sqlx@v1.0.3/connectors/postgresql/driver.go (about)

     1  package postgresql
     2  
     3  import (
     4  	"bytes"
     5  	"context"
     6  	"database/sql/driver"
     7  	"fmt"
     8  	"strconv"
     9  	"strings"
    10  	"time"
    11  
    12  	"gitee.com/go-genie/sqlx"
    13  
    14  	"gitee.com/go-genie/logr"
    15  	"github.com/lib/pq"
    16  	"github.com/pkg/errors"
    17  )
    18  
    19  var _ interface {
    20  	driver.Driver
    21  } = (*PostgreSQLLoggingDriver)(nil)
    22  
    23  type PostgreSQLLoggingDriver struct {
    24  	driver pq.Driver
    25  }
    26  
    27  func (d *PostgreSQLLoggingDriver) Open(dsn string) (driver.Conn, error) {
    28  	config, err := pq.ParseURL(dsn)
    29  	if err != nil {
    30  		return nil, err
    31  	}
    32  
    33  	opts := FromConfigString(config)
    34  	if pass, ok := opts["password"]; ok {
    35  		opts["password"] = strings.Repeat("*", len(pass))
    36  	}
    37  
    38  	conn, err := d.driver.Open(config)
    39  	if err != nil {
    40  		return nil, errors.Wrapf(err, "failed to open connection: %s", opts)
    41  	}
    42  
    43  	return &loggerConn{Conn: conn, cfg: opts}, nil
    44  }
    45  
    46  var _ interface {
    47  	driver.ConnBeginTx
    48  	driver.ExecerContext
    49  	driver.QueryerContext
    50  } = (*loggerConn)(nil)
    51  
    52  type loggerConn struct {
    53  	cfg PostgreSQLOpts
    54  	driver.Conn
    55  }
    56  
    57  func (c *loggerConn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) {
    58  	logger := logr.FromContext(ctx)
    59  
    60  	logger.Debug("=========== Beginning Transaction ===========")
    61  	tx, err := c.Conn.(driver.ConnBeginTx).BeginTx(ctx, opts)
    62  	if err != nil {
    63  		logger.Error(errors.Wrap(err, "failed to begin transaction"))
    64  		return nil, err
    65  	}
    66  	return &loggingTx{tx: tx, logger: logger}, nil
    67  }
    68  
    69  func (c *loggerConn) Close() error {
    70  	if err := c.Conn.Close(); err != nil {
    71  		return err
    72  	}
    73  	return nil
    74  }
    75  
    76  func (c *loggerConn) Prepare(query string) (driver.Stmt, error) {
    77  	panic(fmt.Errorf("don't use Prepare"))
    78  }
    79  
    80  func (c *loggerConn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (rows driver.Rows, err error) {
    81  	newCtx, logger := logr.Start(ctx, "Query")
    82  	cost := startTimer()
    83  
    84  	defer func() {
    85  		q := interpolateParams(query, args)
    86  
    87  		if err != nil {
    88  			if pgErr, ok := sqlx.UnwrapAll(err).(*pq.Error); !ok {
    89  				logger.Error(errors.Wrapf(err, "query failed: %s", q))
    90  			} else {
    91  				logger.Warn(errors.Wrapf(pgErr, "query failed: %s", q))
    92  			}
    93  		} else {
    94  			logger.WithValues("cost", cost().String()).Debug("%s", q)
    95  		}
    96  
    97  		logger.End()
    98  	}()
    99  
   100  	rows, err = c.Conn.(driver.QueryerContext).QueryContext(newCtx, replaceValueHolder(query), args)
   101  	return
   102  }
   103  
   104  func (c *loggerConn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (result driver.Result, err error) {
   105  	cost := startTimer()
   106  	newCtx, logger := logr.Start(ctx, "Exec")
   107  
   108  	defer func() {
   109  		q := interpolateParams(query, args)
   110  
   111  		if err != nil {
   112  			if pgError, ok := sqlx.UnwrapAll(err).(*pq.Error); !ok {
   113  				logger.Error(errors.Wrapf(err, "exec failed: %s", q))
   114  			} else if pgError.Code == "23505" {
   115  				logger.Warn(errors.Wrapf(pgError, "exec failed: %s", q))
   116  			} else {
   117  				logger.Error(errors.Wrapf(pgError, "exec failed: %s", q))
   118  			}
   119  			return
   120  		}
   121  
   122  		logger.WithValues("cost", cost().String()).Debug(q.String())
   123  
   124  		logger.End()
   125  	}()
   126  
   127  	result, err = c.Conn.(driver.ExecerContext).ExecContext(newCtx, replaceValueHolder(query), args)
   128  	return
   129  }
   130  
   131  func replaceValueHolder(query string) string {
   132  	index := 0
   133  	data := []byte(query)
   134  
   135  	e := bytes.NewBufferString("")
   136  
   137  	for i := range data {
   138  		c := data[i]
   139  		switch c {
   140  		case '?':
   141  			e.WriteByte('$')
   142  			e.WriteString(strconv.FormatInt(int64(index+1), 10))
   143  			index++
   144  		default:
   145  			e.WriteByte(c)
   146  		}
   147  	}
   148  
   149  	return e.String()
   150  }
   151  
   152  func startTimer() func() time.Duration {
   153  	startTime := time.Now()
   154  	return func() time.Duration {
   155  		return time.Since(startTime)
   156  	}
   157  }
   158  
   159  type loggingTx struct {
   160  	logger logr.Logger
   161  	tx     driver.Tx
   162  }
   163  
   164  func (tx *loggingTx) Commit() error {
   165  	if err := tx.tx.Commit(); err != nil {
   166  		tx.logger.Debug("failed to commit transaction: %s", err)
   167  		return err
   168  	}
   169  	tx.logger.Debug("=========== Committed Transaction ===========")
   170  	return nil
   171  }
   172  
   173  func (tx *loggingTx) Rollback() error {
   174  	if err := tx.tx.Rollback(); err != nil {
   175  		tx.logger.Debug("failed to rollback transaction: %s", err)
   176  		return err
   177  	}
   178  	tx.logger.Debug("=========== Rollback Transaction ===========")
   179  	return nil
   180  }