github.com/go-courier/sqlx/v2@v2.23.13/connectors/postgresql/driver.go (about) 1 package postgresql 2 3 import ( 4 "bytes" 5 "context" 6 "database/sql/driver" 7 "fmt" 8 "strconv" 9 "strings" 10 "time" 11 12 "github.com/go-courier/sqlx/v2" 13 14 "github.com/go-courier/logr" 15 "github.com/lib/pq" 16 "github.com/pkg/errors" 17 ) 18 19 var _ interface { 20 driver.Driver 21 } = (*PostgreSQLLoggingDriver)(nil) 22 23 type PostgreSQLLoggingDriver struct { 24 driver pq.Driver 25 } 26 27 func (d *PostgreSQLLoggingDriver) Open(dsn string) (driver.Conn, error) { 28 config, err := pq.ParseURL(dsn) 29 if err != nil { 30 return nil, err 31 } 32 33 opts := FromConfigString(config) 34 if pass, ok := opts["password"]; ok { 35 opts["password"] = strings.Repeat("*", len(pass)) 36 } 37 38 conn, err := d.driver.Open(config) 39 if err != nil { 40 return nil, errors.Wrapf(err, "failed to open connection: %s", opts) 41 } 42 43 return &loggerConn{Conn: conn, cfg: opts}, nil 44 } 45 46 var _ interface { 47 driver.ConnBeginTx 48 driver.ExecerContext 49 driver.QueryerContext 50 } = (*loggerConn)(nil) 51 52 type loggerConn struct { 53 cfg PostgreSQLOpts 54 driver.Conn 55 } 56 57 func (c *loggerConn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) { 58 logger := logr.FromContext(ctx) 59 60 logger.Debug("=========== Beginning Transaction ===========") 61 tx, err := c.Conn.(driver.ConnBeginTx).BeginTx(ctx, opts) 62 if err != nil { 63 logger.Error(errors.Wrap(err, "failed to begin transaction")) 64 return nil, err 65 } 66 return &loggingTx{tx: tx, logger: logger}, nil 67 } 68 69 func (c *loggerConn) Close() error { 70 if err := c.Conn.Close(); err != nil { 71 return err 72 } 73 return nil 74 } 75 76 func (c *loggerConn) Prepare(query string) (driver.Stmt, error) { 77 panic(fmt.Errorf("don't use Prepare")) 78 } 79 80 func (c *loggerConn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (rows driver.Rows, err error) { 81 newCtx, logger := logr.Start(ctx, "Query") 82 cost := startTimer() 83 84 defer func() { 85 q := interpolateParams(query, args) 86 87 if err != nil { 88 if pgErr, ok := sqlx.UnwrapAll(err).(*pq.Error); !ok { 89 logger.Error(errors.Wrapf(err, "query failed: %s", q)) 90 } else { 91 logger.Warn(errors.Wrapf(pgErr, "query failed: %s", q)) 92 } 93 } else { 94 logger.WithValues("cost", cost().String()).Debug("%s", q) 95 } 96 97 logger.End() 98 }() 99 100 rows, err = c.Conn.(driver.QueryerContext).QueryContext(newCtx, replaceValueHolder(query), args) 101 return 102 } 103 104 func (c *loggerConn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (result driver.Result, err error) { 105 cost := startTimer() 106 newCtx, logger := logr.Start(ctx, "Exec") 107 108 defer func() { 109 q := interpolateParams(query, args) 110 111 if err != nil { 112 if pgError, ok := sqlx.UnwrapAll(err).(*pq.Error); !ok { 113 logger.Error(errors.Wrapf(err, "exec failed: %s", q)) 114 } else if pgError.Code == "23505" { 115 logger.Warn(errors.Wrapf(pgError, "exec failed: %s", q)) 116 } else { 117 logger.Error(errors.Wrapf(pgError, "exec failed: %s", q)) 118 } 119 return 120 } 121 122 logger.WithValues("cost", cost().String()).Debug(q.String()) 123 124 logger.End() 125 }() 126 127 result, err = c.Conn.(driver.ExecerContext).ExecContext(newCtx, replaceValueHolder(query), args) 128 return 129 } 130 131 func replaceValueHolder(query string) string { 132 index := 0 133 data := []byte(query) 134 135 e := bytes.NewBufferString("") 136 137 for i := range data { 138 c := data[i] 139 switch c { 140 case '?': 141 e.WriteByte('$') 142 e.WriteString(strconv.FormatInt(int64(index+1), 10)) 143 index++ 144 default: 145 e.WriteByte(c) 146 } 147 } 148 149 return e.String() 150 } 151 152 func startTimer() func() time.Duration { 153 startTime := time.Now() 154 return func() time.Duration { 155 return time.Since(startTime) 156 } 157 } 158 159 type loggingTx struct { 160 logger logr.Logger 161 tx driver.Tx 162 } 163 164 func (tx *loggingTx) Commit() error { 165 if err := tx.tx.Commit(); err != nil { 166 tx.logger.Debug("failed to commit transaction: %s", err) 167 return err 168 } 169 tx.logger.Debug("=========== Committed Transaction ===========") 170 return nil 171 } 172 173 func (tx *loggingTx) Rollback() error { 174 if err := tx.tx.Rollback(); err != nil { 175 tx.logger.Debug("failed to rollback transaction: %s", err) 176 return err 177 } 178 tx.logger.Debug("=========== Rollback Transaction ===========") 179 return nil 180 }