code.vegaprotocol.io/vega@v0.79.0/datanode/sqlstore/connection_tx.go (about)

     1  // Copyright (C) 2023 Gobalsky Labs Limited
     2  //
     3  // This program is free software: you can redistribute it and/or modify
     4  // it under the terms of the GNU Affero General Public License as
     5  // published by the Free Software Foundation, either version 3 of the
     6  // License, or (at your option) any later version.
     7  //
     8  // This program is distributed in the hope that it will be useful,
     9  // but WITHOUT ANY WARRANTY; without even the implied warranty of
    10  // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    11  // GNU Affero General Public License for more details.
    12  //
    13  // You should have received a copy of the GNU Affero General Public License
    14  // along with this program.  If not, see <http://www.gnu.org/licenses/>.
    15  
    16  package sqlstore
    17  
import (
	"context"
	"fmt"
	"io"
	"sort"
	"sync"
	"sync/atomic"

	"code.vegaprotocol.io/vega/logging"

	"github.com/jackc/pgconn"
	"github.com/jackc/pgx/v4"
	"github.com/jackc/pgx/v4/pgxpool"
	"github.com/pkg/errors"
)
    32  
    33  type ConnectionSource struct {
    34  	pool   *pgxpool.Pool
    35  	log    *logging.Logger
    36  	isTest bool
    37  }
    38  
    39  type wrappedTx struct {
    40  	parent    *wrappedTx
    41  	mu        sync.Mutex
    42  	postHooks []func()
    43  	id        int64
    44  	idgen     *atomic.Int64
    45  	tx        pgx.Tx
    46  	subTx     map[int64]*wrappedTx
    47  }
    48  
    49  type (
    50  	txKey   struct{}
    51  	connKey struct{}
    52  )
    53  
    54  func NewTransactionalConnectionSource(ctx context.Context, log *logging.Logger, connConfig ConnectionConfig) (*ConnectionSource, error) {
    55  	pool, err := CreateConnectionPool(ctx, connConfig)
    56  	if err != nil {
    57  		return nil, fmt.Errorf("failed to create connection pool: %w", err)
    58  	}
    59  	return &ConnectionSource{
    60  		pool: pool,
    61  		log:  log.Named("connection-source"),
    62  	}, nil
    63  }
    64  
    65  func (c *ConnectionSource) ToggleTest() {
    66  	c.isTest = true
    67  }
    68  
    69  func (c *ConnectionSource) WithConnection(ctx context.Context) (context.Context, error) {
    70  	poolConn, err := c.pool.Acquire(ctx)
    71  	if err != nil {
    72  		return context.Background(), errors.Errorf("failed to acquire connection:%s", err)
    73  	}
    74  	return context.WithValue(ctx, connKey{}, &wrappedConn{
    75  		Conn: poolConn.Hijack(),
    76  	}), nil
    77  }
    78  
    79  func (c *ConnectionSource) WithTransaction(ctx context.Context) (context.Context, error) {
    80  	var tx pgx.Tx
    81  	var err error
    82  	nTx := &wrappedTx{
    83  		postHooks: []func(){},
    84  		subTx:     map[int64]*wrappedTx{},
    85  		idgen:     &atomic.Int64{},
    86  	}
    87  	// start id at 0
    88  	nTx.idgen.Store(0)
    89  	if ctxTx, ok := ctx.Value(txKey{}).(*wrappedTx); ok {
    90  		// register sub-transactions
    91  		nTx.id = ctxTx.idgen.Add(1)
    92  		tx, err = ctxTx.tx.Begin(ctx)
    93  		nTx.parent = ctxTx
    94  		if err == nil {
    95  			ctxTx.mu.Lock()
    96  			ctxTx.subTx[nTx.id] = nTx
    97  			ctxTx.mu.Unlock()
    98  		}
    99  	} else if conn, ok := ctx.Value(connKey{}).(*wrappedConn); ok {
   100  		tx, err = conn.Begin(ctx)
   101  	} else {
   102  		tx, err = c.pool.Begin(ctx)
   103  	}
   104  	if err != nil {
   105  		return ctx, errors.Wrapf(err, "failed to start transaction:%s", err)
   106  	}
   107  	nTx.tx = tx
   108  	return context.WithValue(ctx, txKey{}, nTx), nil
   109  }
   110  
   111  func (c *ConnectionSource) AfterCommit(ctx context.Context, f func()) {
   112  	// if the context references an ongoing transaction, append the callback to be invoked on commit
   113  	if cTx, ok := ctx.Value(txKey{}).(*wrappedTx); ok {
   114  		cTx.mu.Lock()
   115  		cTx.postHooks = append(cTx.postHooks, f)
   116  		cTx.mu.Unlock()
   117  		return
   118  	}
   119  	// not in transaction, just call immediately.
   120  	f()
   121  }
   122  
   123  func (c *ConnectionSource) Rollback(ctx context.Context) error {
   124  	// if we're in a transaction, roll it back starting with the sub-transactions.
   125  	tx, ok := ctx.Value(txKey{}).(*wrappedTx)
   126  	if !ok {
   127  		// no tx ongoing
   128  		return fmt.Errorf("no transaction is associated with the context")
   129  	}
   130  	return tx.Rollback(ctx)
   131  }
   132  
   133  func (c *ConnectionSource) Commit(ctx context.Context) error {
   134  	tx, ok := ctx.Value(txKey{}).(*wrappedTx)
   135  	if !ok {
   136  		return fmt.Errorf("no transaction is associated with the context")
   137  	}
   138  	tx.mu.Lock()
   139  	defer tx.mu.Unlock()
   140  	post, err := tx.commit(ctx)
   141  	if err != nil {
   142  		return fmt.Errorf("failed to commit transaction for context: %s, error: %w", ctx, err)
   143  	}
   144  	// invoke all post-commit hooks once the transaction (and its sub transactions) have been committed
   145  	// make an exception for unit tests, so we don't need to commit DB transactions for hooks on the nested transaction.
   146  	if !c.isTest && tx.parent != nil {
   147  		// this is a nested transaction, don't invoke hooks until the parent is committed
   148  		// instead prepend the hooks and return.
   149  		tx.parent.mu.Lock()
   150  		tx.parent.postHooks = append(post, tx.parent.postHooks...)
   151  		// remove the reference to this transaction from its parent
   152  		delete(tx.parent.subTx, tx.id)
   153  		tx.parent.mu.Unlock()
   154  		return nil
   155  	}
   156  	// this is the main transactions, invoke all hooks now
   157  	for _, f := range post {
   158  		f()
   159  	}
   160  	if tx.parent != nil {
   161  		tx.parent.mu.Lock()
   162  		delete(tx.parent.subTx, tx.id)
   163  		tx.parent.mu.Unlock()
   164  	}
   165  	return nil
   166  }
   167  
   168  func (c *ConnectionSource) Query(ctx context.Context, sql string, args ...any) (pgx.Rows, error) {
   169  	// this is nasty, but required for the API tests currently.
   170  	if c.isTest && c.pool == nil {
   171  		return nil, pgx.ErrNoRows
   172  	}
   173  	if tx, ok := ctx.Value(txKey{}).(*wrappedTx); ok {
   174  		return tx.tx.Query(ctx, sql, args...)
   175  	}
   176  	if conn, ok := ctx.Value(connKey{}).(*wrappedConn); ok {
   177  		return conn.Query(ctx, sql, args...)
   178  	}
   179  	return c.pool.Query(ctx, sql, args...)
   180  }
   181  
   182  func (c *ConnectionSource) QueryRow(ctx context.Context, sql string, args ...interface{}) pgx.Row {
   183  	if tx, ok := ctx.Value(txKey{}).(*wrappedTx); ok {
   184  		return tx.tx.QueryRow(ctx, sql, args...)
   185  	}
   186  	if conn, ok := ctx.Value(connKey{}).(*wrappedConn); ok {
   187  		return conn.QueryRow(ctx, sql, args...)
   188  	}
   189  	return c.pool.QueryRow(ctx, sql, args...)
   190  }
   191  
   192  func (c *ConnectionSource) QueryFunc(ctx context.Context, sql string, args []interface{}, scans []interface{}, f func(pgx.QueryFuncRow) error) (pgconn.CommandTag, error) {
   193  	if tx, ok := ctx.Value(txKey{}).(*wrappedTx); ok {
   194  		return tx.tx.QueryFunc(ctx, sql, args, scans, f)
   195  	}
   196  	if conn, ok := ctx.Value(connKey{}).(*wrappedConn); ok {
   197  		return conn.QueryFunc(ctx, sql, args, scans, f)
   198  	}
   199  	return c.pool.QueryFunc(ctx, sql, args, scans, f)
   200  }
   201  
   202  func (c *ConnectionSource) SendBatch(ctx context.Context, b *pgx.Batch) pgx.BatchResults {
   203  	if tx, ok := ctx.Value(txKey{}).(*wrappedTx); ok {
   204  		return tx.tx.SendBatch(ctx, b)
   205  	}
   206  	if conn, ok := ctx.Value(connKey{}).(*wrappedConn); ok {
   207  		return conn.SendBatch(ctx, b)
   208  	}
   209  	return c.pool.SendBatch(ctx, b)
   210  }
   211  
   212  func (c *ConnectionSource) CopyFrom(ctx context.Context, tableName pgx.Identifier, columnNames []string, rowSrc pgx.CopyFromSource) (int64, error) {
   213  	if tx, ok := ctx.Value(txKey{}).(*wrappedTx); ok {
   214  		return tx.tx.CopyFrom(ctx, tableName, columnNames, rowSrc)
   215  	}
   216  	if conn, ok := ctx.Value(connKey{}).(*wrappedConn); ok {
   217  		return conn.CopyFrom(ctx, tableName, columnNames, rowSrc)
   218  	}
   219  	return c.pool.CopyFrom(ctx, tableName, columnNames, rowSrc)
   220  }
   221  
   222  func (c *ConnectionSource) CopyTo(ctx context.Context, w io.Writer, sql string, args ...any) (pgconn.CommandTag, error) {
   223  	// this is nasty, but required for the API tests currently.
   224  	if c.isTest && c.pool == nil {
   225  		return pgconn.CommandTag{}, nil
   226  	}
   227  	var err error
   228  	sql, err = SanitizeSql(sql, args...)
   229  	if err != nil {
   230  		return nil, fmt.Errorf("failed to sanitize sql: %w", err)
   231  	}
   232  	if tx, ok := ctx.Value(txKey{}).(*wrappedTx); ok {
   233  		return tx.tx.Conn().PgConn().CopyTo(ctx, w, sql)
   234  	}
   235  	if conn, ok := ctx.Value(connKey{}).(*wrappedConn); ok {
   236  		return conn.PgConn().CopyTo(ctx, w, sql)
   237  	}
   238  	conn, err := c.pool.Acquire(ctx)
   239  	if err != nil {
   240  		return nil, err
   241  	}
   242  	defer conn.Release()
   243  	return conn.Conn().PgConn().CopyTo(ctx, w, sql)
   244  }
   245  
   246  func (c *ConnectionSource) Exec(ctx context.Context, sql string, args ...any) (pgconn.CommandTag, error) {
   247  	if tx, ok := ctx.Value(txKey{}).(*wrappedTx); ok {
   248  		return tx.tx.Exec(ctx, sql, args...)
   249  	}
   250  	if conn, ok := ctx.Value(connKey{}).(*wrappedConn); ok {
   251  		return conn.Exec(ctx, sql, args...)
   252  	}
   253  	return c.pool.Exec(ctx, sql, args...)
   254  }
   255  
   256  type wrappedConn struct {
   257  	*pgx.Conn
   258  }
   259  
   260  func (c *ConnectionSource) RefreshMaterializedViews(ctx context.Context) error {
   261  	conn := ctx.Value(connKey{}).(*wrappedConn)
   262  	materializedViewsToRefresh := []struct {
   263  		name         string
   264  		concurrently bool
   265  	}{
   266  		{"game_stats", false},
   267  		{"game_stats_current", false},
   268  	}
   269  
   270  	for _, view := range materializedViewsToRefresh {
   271  		sql := "REFRESH MATERIALIZED VIEW "
   272  		if view.concurrently {
   273  			sql += "CONCURRENTLY "
   274  		}
   275  		sql += view.name
   276  
   277  		_, err := conn.Exec(ctx, sql)
   278  		if err != nil {
   279  			return fmt.Errorf("failed to refresh materialized view %s: %w", view.name, err)
   280  		}
   281  	}
   282  	return nil
   283  }
   284  
   285  func (c *ConnectionSource) Close() {
   286  	c.pool.Close()
   287  }
   288  
   289  func (c *ConnectionSource) wrapE(err error) error {
   290  	return wrapE(err)
   291  }
   292  
   293  func (t *wrappedTx) commit(ctx context.Context) ([]func(), error) {
   294  	// return callbacks so we only invoke them if no errors occurred
   295  	ret := t.postHooks
   296  	for id, sTx := range t.subTx {
   297  		// acquire the lock, release it as soon as possible
   298  		sTx.mu.Lock()
   299  		subCB, err := sTx.commit(ctx)
   300  		if err != nil {
   301  			sTx.mu.Unlock()
   302  			return nil, err
   303  		}
   304  		sTx.mu.Unlock()
   305  		delete(t.subTx, id)
   306  		// prepend callbacks from sub transactions
   307  		ret = append(subCB, ret...)
   308  	}
   309  	// actually commit this transaction
   310  	if err := t.tx.Commit(ctx); err != nil {
   311  		return nil, err
   312  	}
   313  	return ret, nil
   314  }
   315  
   316  func (t *wrappedTx) Rollback(ctx context.Context) error {
   317  	for _, sTx := range t.subTx {
   318  		if err := sTx.Rollback(ctx); err != nil {
   319  			return err
   320  		}
   321  	}
   322  	if err := t.tx.Rollback(ctx); err != nil {
   323  		return fmt.Errorf("failed to rollback transaction for context:%s, error:%w", ctx, err)
   324  	}
   325  	if t.parent != nil {
   326  		t.parent.rmSubTx(t.id)
   327  	}
   328  	return nil
   329  }
   330  
   331  func (t *wrappedTx) rmSubTx(id int64) {
   332  	t.mu.Lock()
   333  	defer t.mu.Unlock()
   334  	// this is called from Rollback, which is recursive already.
   335  	// no need to recursively remove the sub-tx
   336  	delete(t.subTx, id)
   337  }