github.com/eden-framework/sqlx@v0.0.2/postgresqlconnector/driver.go (about)

     1  package postgresqlconnector
     2  
     3  import (
     4  	"bytes"
     5  	"context"
     6  	"database/sql/driver"
     7  	"fmt"
     8  	"sort"
     9  	"strconv"
    10  	"strings"
    11  	"time"
    12  
    13  	"github.com/lib/pq"
    14  	"github.com/sirupsen/logrus"
    15  )
    16  
    17  var _ interface {
    18  	driver.Driver
    19  } = (*PostgreSQLLoggingDriver)(nil)
    20  
    21  type PostgreSQLLoggingDriver struct {
    22  	Logger *logrus.Logger
    23  	Driver *pq.Driver
    24  }
    25  
    // FromConfigString parses a space-separated `key=value` connection string,
// such as the output of pq.ParseURL, into PostgreSQLOpts.
//
// Spaces inside a single-quoted value do not end the pair, and a backslash
// inside quotes escapes the next character. Only the first '=' separates the
// key from the value, so values such as `password='a=b c'` survive intact and
// cannot spill fragments into bogus keys. Values keep their raw text, including
// any surrounding quotes. Fields without '=' (including the empty fields
// produced by repeated spaces) are ignored. If a key appears more than once,
// the last occurrence wins.
func FromConfigString(s string) PostgreSQLOpts {
	opts := PostgreSQLOpts{}
	add := func(field string) {
		if kv := strings.SplitN(field, "=", 2); len(kv) > 1 {
			opts[kv[0]] = kv[1]
		}
	}

	start, inQuote := 0, false
	for i := 0; i < len(s); i++ {
		switch c := s[i]; {
		case inQuote && c == '\\':
			i++ // the escaped character (possibly a quote) cannot close the value
		case c == '\'':
			inQuote = !inQuote
		case c == ' ' && !inQuote:
			add(s[start:i])
			start = i + 1
		}
	}
	add(s[start:])
	return opts
}

// PostgreSQLOpts holds connection options keyed by parameter name, as produced
// by FromConfigString.
type PostgreSQLOpts map[string]string

// String renders the options as space-separated `key=value` pairs, with keys
// sorted so the output is deterministic. Values are written exactly as stored;
// no quoting or escaping is added. An empty set renders as "".
func (opts PostgreSQLOpts) String() string {
	keys := make([]string, 0, len(opts))
	for k := range opts {
		keys = append(keys, k)
	}
	sort.Strings(keys)

	var buf bytes.Buffer
	for i, k := range keys {
		if i > 0 {
			buf.WriteByte(' ')
		}
		buf.WriteString(k)
		buf.WriteByte('=')
		buf.WriteString(opts[k])
	}
	return buf.String()
}
    59  
    60  func (d *PostgreSQLLoggingDriver) Open(dsn string) (driver.Conn, error) {
    61  	conf, err := pq.ParseURL(dsn)
    62  	if err != nil {
    63  		panic(err)
    64  	}
    65  	opts := FromConfigString(conf)
    66  	if pass, ok := opts["password"]; ok {
    67  		opts["password"] = strings.Repeat("*", len(pass))
    68  	}
    69  	conn, err := d.Driver.Open(conf)
    70  	if err != nil {
    71  		d.Logger.Errorf("failed to open connection: %s %s", opts, err)
    72  		return nil, err
    73  	}
    74  	d.Logger.Debugf("connected %s", opts)
    75  	return &loggerConn{Conn: conn, cfg: opts, logger: d.Logger.WithField("driver", "postgres")}, nil
    76  }
    77  
    78  var _ interface {
    79  	driver.ConnBeginTx
    80  	driver.ExecerContext
    81  	driver.QueryerContext
    82  } = (*loggerConn)(nil)
    83  
    84  type loggerConn struct {
    85  	logger *logrus.Entry
    86  	cfg    PostgreSQLOpts
    87  	driver.Conn
    88  }
    89  
    90  func (c *loggerConn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) {
    91  	logger := c.logger.WithContext(ctx)
    92  	logger.Debug("=========== Beginning Transaction ===========")
    93  	tx, err := c.Conn.(driver.ConnBeginTx).BeginTx(ctx, opts)
    94  	if err != nil {
    95  		logger.Errorf("failed to begin transaction: %s", err)
    96  		return nil, err
    97  	}
    98  	return &loggingTx{tx: tx, logger: logger}, nil
    99  }
   100  
   101  func (c *loggerConn) Close() error {
   102  	if err := c.Conn.Close(); err != nil {
   103  		c.logger.Errorf("failed to close connection: %s", err)
   104  		return err
   105  	}
   106  	return nil
   107  }
   108  
   109  func (c *loggerConn) Prepare(query string) (driver.Stmt, error) {
   110  	panic(fmt.Errorf("don't use Prepare"))
   111  }
   112  
   113  func (c *loggerConn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (rows driver.Rows, err error) {
   114  	cost := startTimer()
   115  	logger := c.logger.WithContext(ctx)
   116  
   117  	defer func() {
   118  		q := interpolateParams(query, args)
   119  
   120  		if err != nil {
   121  			if pgErr, ok := err.(*pq.Error); !ok {
   122  				logger.Errorf("failed query %s: %s", err, q)
   123  			} else {
   124  				logger.Warnf("failed query %s: %s", pgErr, q)
   125  			}
   126  		} else {
   127  			logger.WithField("cost", cost().String()).Debugf("%s", q)
   128  		}
   129  	}()
   130  
   131  	rows, err = c.Conn.(driver.QueryerContext).QueryContext(ctx, replaceValueHolder(query), args)
   132  	return
   133  }
   134  
   135  func (c *loggerConn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (result driver.Result, err error) {
   136  	cost := startTimer()
   137  	logger := c.logger.WithContext(ctx)
   138  
   139  	defer func() {
   140  		q := interpolateParams(query, args)
   141  
   142  		if err != nil {
   143  			if pgError, ok := err.(*pq.Error); !ok {
   144  				logger.Errorf("failed exec %s: %s", err, q)
   145  			} else if pgError.Code == "23505" {
   146  				logger.Warnf("failed exec %s: %s", err, q)
   147  			} else {
   148  				logger.Errorf("failed exec %s: %s", pgError, q)
   149  			}
   150  			return
   151  		}
   152  
   153  		logger.WithField("cost", cost().String()).Debugf("%s", q)
   154  	}()
   155  
   156  	result, err = c.Conn.(driver.ExecerContext).ExecContext(ctx, replaceValueHolder(query), args)
   157  	return
   158  }
   159  
   // replaceValueHolder rewrites the `?` placeholders in query into PostgreSQL's
// positional `$1`, `$2`, ... form, numbering them left to right.
//
// A `?` inside a single-quoted string literal or a double-quoted identifier is
// left untouched. It is data, not a placeholder, and renumbering it would also
// shift every later parameter. Doubled quotes ('' or "") inside such a section
// are handled naturally, because they close and immediately reopen it.
// NOTE(review): a `?` in comments, dollar-quoted strings, E'' strings with
// backslash-escaped quotes, or the jsonb `?`, `?|`, `?&` operators is still
// rewritten.
func replaceValueHolder(query string) string {
	var b strings.Builder
	b.Grow(len(query))

	n := 0
	var quote byte // opening quote of the section being scanned, or 0 outside one
	for i := 0; i < len(query); i++ {
		ch := query[i]
		switch {
		case quote != 0:
			if ch == quote {
				quote = 0
			}
		case ch == '\'' || ch == '"':
			quote = ch
		case ch == '?':
			n++
			b.WriteByte('$')
			b.WriteString(strconv.Itoa(n))
			continue
		}
		b.WriteByte(ch)
	}
	return b.String()
}
   180  
   // startTimer records the current instant and returns a function that, each
// time it is called, reports how long has passed since that instant.
func startTimer() func() time.Duration {
	begin := time.Now()
	return func() time.Duration { return time.Since(begin) }
}
   187  
   188  type loggingTx struct {
   189  	logger *logrus.Entry
   190  	tx     driver.Tx
   191  }
   192  
   193  func (tx *loggingTx) Commit() error {
   194  	if err := tx.tx.Commit(); err != nil {
   195  		tx.logger.Debugf("failed to commit transaction: %s", err)
   196  		return err
   197  	}
   198  	tx.logger.Debug("=========== Committed Transaction ===========")
   199  	return nil
   200  }
   201  
   202  func (tx *loggingTx) Rollback() error {
   203  	if err := tx.tx.Rollback(); err != nil {
   204  		tx.logger.Debugf("failed to rollback transaction: %s", err)
   205  		return err
   206  	}
   207  	tx.logger.Debug("=========== Rollback Transaction ===========")
   208  	return nil
   209  }