code.vegaprotocol.io/vega@v0.79.0/datanode/sqlstore/connection_tx.go (about) 1 // Copyright (C) 2023 Gobalsky Labs Limited 2 // 3 // This program is free software: you can redistribute it and/or modify 4 // it under the terms of the GNU Affero General Public License as 5 // published by the Free Software Foundation, either version 3 of the 6 // License, or (at your option) any later version. 7 // 8 // This program is distributed in the hope that it will be useful, 9 // but WITHOUT ANY WARRANTY; without even the implied warranty of 10 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 11 // GNU Affero General Public License for more details. 12 // 13 // You should have received a copy of the GNU Affero General Public License 14 // along with this program. If not, see <http://www.gnu.org/licenses/>. 15 16 package sqlstore 17 18 import ( 19 "context" 20 "fmt" 21 "io" 22 "sync" 23 "sync/atomic" 24 25 "code.vegaprotocol.io/vega/logging" 26 27 "github.com/jackc/pgconn" 28 "github.com/jackc/pgx/v4" 29 "github.com/jackc/pgx/v4/pgxpool" 30 "github.com/pkg/errors" 31 ) 32 33 type ConnectionSource struct { 34 pool *pgxpool.Pool 35 log *logging.Logger 36 isTest bool 37 } 38 39 type wrappedTx struct { 40 parent *wrappedTx 41 mu sync.Mutex 42 postHooks []func() 43 id int64 44 idgen *atomic.Int64 45 tx pgx.Tx 46 subTx map[int64]*wrappedTx 47 } 48 49 type ( 50 txKey struct{} 51 connKey struct{} 52 ) 53 54 func NewTransactionalConnectionSource(ctx context.Context, log *logging.Logger, connConfig ConnectionConfig) (*ConnectionSource, error) { 55 pool, err := CreateConnectionPool(ctx, connConfig) 56 if err != nil { 57 return nil, fmt.Errorf("failed to create connection pool: %w", err) 58 } 59 return &ConnectionSource{ 60 pool: pool, 61 log: log.Named("connection-source"), 62 }, nil 63 } 64 65 func (c *ConnectionSource) ToggleTest() { 66 c.isTest = true 67 } 68 69 func (c *ConnectionSource) WithConnection(ctx context.Context) (context.Context, error) { 70 poolConn, err := c.pool.Acquire(ctx) 71 if err != nil { 72 return context.Background(), errors.Errorf("failed to acquire connection:%s", err) 73 } 74 return context.WithValue(ctx, connKey{}, &wrappedConn{ 75 Conn: poolConn.Hijack(), 76 }), nil 77 } 78 79 func (c *ConnectionSource) WithTransaction(ctx context.Context) (context.Context, error) { 80 var tx pgx.Tx 81 var err error 82 nTx := &wrappedTx{ 83 postHooks: []func(){}, 84 subTx: map[int64]*wrappedTx{}, 85 idgen: &atomic.Int64{}, 86 } 87 // start id at 0 88 nTx.idgen.Store(0) 89 if ctxTx, ok := ctx.Value(txKey{}).(*wrappedTx); ok { 90 // register sub-transactions 91 nTx.id = ctxTx.idgen.Add(1) 92 tx, err = ctxTx.tx.Begin(ctx) 93 nTx.parent = ctxTx 94 if err == nil { 95 ctxTx.mu.Lock() 96 ctxTx.subTx[nTx.id] = nTx 97 ctxTx.mu.Unlock() 98 } 99 } else if conn, ok := ctx.Value(connKey{}).(*wrappedConn); ok { 100 tx, err = conn.Begin(ctx) 101 } else { 102 tx, err = c.pool.Begin(ctx) 103 } 104 if err != nil { 105 return ctx, errors.Wrapf(err, "failed to start transaction:%s", err) 106 } 107 nTx.tx = tx 108 return context.WithValue(ctx, txKey{}, nTx), nil 109 } 110 111 func (c *ConnectionSource) AfterCommit(ctx context.Context, f func()) { 112 // if the context references an ongoing transaction, append the callback to be invoked on commit 113 if cTx, ok := ctx.Value(txKey{}).(*wrappedTx); ok { 114 cTx.mu.Lock() 115 cTx.postHooks = append(cTx.postHooks, f) 116 cTx.mu.Unlock() 117 return 118 } 119 // not in transaction, just call immediately. 120 f() 121 } 122 123 func (c *ConnectionSource) Rollback(ctx context.Context) error { 124 // if we're in a transaction, roll it back starting with the sub-transactions. 125 tx, ok := ctx.Value(txKey{}).(*wrappedTx) 126 if !ok { 127 // no tx ongoing 128 return fmt.Errorf("no transaction is associated with the context") 129 } 130 return tx.Rollback(ctx) 131 } 132 133 func (c *ConnectionSource) Commit(ctx context.Context) error { 134 tx, ok := ctx.Value(txKey{}).(*wrappedTx) 135 if !ok { 136 return fmt.Errorf("no transaction is associated with the context") 137 } 138 tx.mu.Lock() 139 defer tx.mu.Unlock() 140 post, err := tx.commit(ctx) 141 if err != nil { 142 return fmt.Errorf("failed to commit transaction for context: %s, error: %w", ctx, err) 143 } 144 // invoke all post-commit hooks once the transaction (and its sub transactions) have been committed 145 // make an exception for unit tests, so we don't need to commit DB transactions for hooks on the nested transaction. 146 if !c.isTest && tx.parent != nil { 147 // this is a nested transaction, don't invoke hooks until the parent is committed 148 // instead prepend the hooks and return. 149 tx.parent.mu.Lock() 150 tx.parent.postHooks = append(post, tx.parent.postHooks...) 151 // remove the reference to this transaction from its parent 152 delete(tx.parent.subTx, tx.id) 153 tx.parent.mu.Unlock() 154 return nil 155 } 156 // this is the main transactions, invoke all hooks now 157 for _, f := range post { 158 f() 159 } 160 if tx.parent != nil { 161 tx.parent.mu.Lock() 162 delete(tx.parent.subTx, tx.id) 163 tx.parent.mu.Unlock() 164 } 165 return nil 166 } 167 168 func (c *ConnectionSource) Query(ctx context.Context, sql string, args ...any) (pgx.Rows, error) { 169 // this is nasty, but required for the API tests currently. 170 if c.isTest && c.pool == nil { 171 return nil, pgx.ErrNoRows 172 } 173 if tx, ok := ctx.Value(txKey{}).(*wrappedTx); ok { 174 return tx.tx.Query(ctx, sql, args...) 175 } 176 if conn, ok := ctx.Value(connKey{}).(*wrappedConn); ok { 177 return conn.Query(ctx, sql, args...) 178 } 179 return c.pool.Query(ctx, sql, args...) 180 } 181 182 func (c *ConnectionSource) QueryRow(ctx context.Context, sql string, args ...interface{}) pgx.Row { 183 if tx, ok := ctx.Value(txKey{}).(*wrappedTx); ok { 184 return tx.tx.QueryRow(ctx, sql, args...) 185 } 186 if conn, ok := ctx.Value(connKey{}).(*wrappedConn); ok { 187 return conn.QueryRow(ctx, sql, args...) 188 } 189 return c.pool.QueryRow(ctx, sql, args...) 190 } 191 192 func (c *ConnectionSource) QueryFunc(ctx context.Context, sql string, args []interface{}, scans []interface{}, f func(pgx.QueryFuncRow) error) (pgconn.CommandTag, error) { 193 if tx, ok := ctx.Value(txKey{}).(*wrappedTx); ok { 194 return tx.tx.QueryFunc(ctx, sql, args, scans, f) 195 } 196 if conn, ok := ctx.Value(connKey{}).(*wrappedConn); ok { 197 return conn.QueryFunc(ctx, sql, args, scans, f) 198 } 199 return c.pool.QueryFunc(ctx, sql, args, scans, f) 200 } 201 202 func (c *ConnectionSource) SendBatch(ctx context.Context, b *pgx.Batch) pgx.BatchResults { 203 if tx, ok := ctx.Value(txKey{}).(*wrappedTx); ok { 204 return tx.tx.SendBatch(ctx, b) 205 } 206 if conn, ok := ctx.Value(connKey{}).(*wrappedConn); ok { 207 return conn.SendBatch(ctx, b) 208 } 209 return c.pool.SendBatch(ctx, b) 210 } 211 212 func (c *ConnectionSource) CopyFrom(ctx context.Context, tableName pgx.Identifier, columnNames []string, rowSrc pgx.CopyFromSource) (int64, error) { 213 if tx, ok := ctx.Value(txKey{}).(*wrappedTx); ok { 214 return tx.tx.CopyFrom(ctx, tableName, columnNames, rowSrc) 215 } 216 if conn, ok := ctx.Value(connKey{}).(*wrappedConn); ok { 217 return conn.CopyFrom(ctx, tableName, columnNames, rowSrc) 218 } 219 return c.pool.CopyFrom(ctx, tableName, columnNames, rowSrc) 220 } 221 222 func (c *ConnectionSource) CopyTo(ctx context.Context, w io.Writer, sql string, args ...any) (pgconn.CommandTag, error) { 223 // this is nasty, but required for the API tests currently. 224 if c.isTest && c.pool == nil { 225 return pgconn.CommandTag{}, nil 226 } 227 var err error 228 sql, err = SanitizeSql(sql, args...) 229 if err != nil { 230 return nil, fmt.Errorf("failed to sanitize sql: %w", err) 231 } 232 if tx, ok := ctx.Value(txKey{}).(*wrappedTx); ok { 233 return tx.tx.Conn().PgConn().CopyTo(ctx, w, sql) 234 } 235 if conn, ok := ctx.Value(connKey{}).(*wrappedConn); ok { 236 return conn.PgConn().CopyTo(ctx, w, sql) 237 } 238 conn, err := c.pool.Acquire(ctx) 239 if err != nil { 240 return nil, err 241 } 242 defer conn.Release() 243 return conn.Conn().PgConn().CopyTo(ctx, w, sql) 244 } 245 246 func (c *ConnectionSource) Exec(ctx context.Context, sql string, args ...any) (pgconn.CommandTag, error) { 247 if tx, ok := ctx.Value(txKey{}).(*wrappedTx); ok { 248 return tx.tx.Exec(ctx, sql, args...) 249 } 250 if conn, ok := ctx.Value(connKey{}).(*wrappedConn); ok { 251 return conn.Exec(ctx, sql, args...) 252 } 253 return c.pool.Exec(ctx, sql, args...) 254 } 255 256 type wrappedConn struct { 257 *pgx.Conn 258 } 259 260 func (c *ConnectionSource) RefreshMaterializedViews(ctx context.Context) error { 261 conn := ctx.Value(connKey{}).(*wrappedConn) 262 materializedViewsToRefresh := []struct { 263 name string 264 concurrently bool 265 }{ 266 {"game_stats", false}, 267 {"game_stats_current", false}, 268 } 269 270 for _, view := range materializedViewsToRefresh { 271 sql := "REFRESH MATERIALIZED VIEW " 272 if view.concurrently { 273 sql += "CONCURRENTLY " 274 } 275 sql += view.name 276 277 _, err := conn.Exec(ctx, sql) 278 if err != nil { 279 return fmt.Errorf("failed to refresh materialized view %s: %w", view.name, err) 280 } 281 } 282 return nil 283 } 284 285 func (c *ConnectionSource) Close() { 286 c.pool.Close() 287 } 288 289 func (c *ConnectionSource) wrapE(err error) error { 290 return wrapE(err) 291 } 292 293 func (t *wrappedTx) commit(ctx context.Context) ([]func(), error) { 294 // return callbacks so we only invoke them if no errors occurred 295 ret := t.postHooks 296 for id, sTx := range t.subTx { 297 // acquire the lock, release it as soon as possible 298 sTx.mu.Lock() 299 subCB, err := sTx.commit(ctx) 300 if err != nil { 301 sTx.mu.Unlock() 302 return nil, err 303 } 304 sTx.mu.Unlock() 305 delete(t.subTx, id) 306 // prepend callbacks from sub transactions 307 ret = append(subCB, ret...) 308 } 309 // actually commit this transaction 310 if err := t.tx.Commit(ctx); err != nil { 311 return nil, err 312 } 313 return ret, nil 314 } 315 316 func (t *wrappedTx) Rollback(ctx context.Context) error { 317 for _, sTx := range t.subTx { 318 if err := sTx.Rollback(ctx); err != nil { 319 return err 320 } 321 } 322 if err := t.tx.Rollback(ctx); err != nil { 323 return fmt.Errorf("failed to rollback transaction for context:%s, error:%w", ctx, err) 324 } 325 if t.parent != nil { 326 t.parent.rmSubTx(t.id) 327 } 328 return nil 329 } 330 331 func (t *wrappedTx) rmSubTx(id int64) { 332 t.mu.Lock() 333 defer t.mu.Unlock() 334 // this is called from Rollback, which is recursive already. 335 // no need to recursively remove the sub-tx 336 delete(t.subTx, id) 337 }