github.com/kaleido-io/firefly@v0.0.0-20210622132723-8b4b6aacb971/internal/database/sqlcommon/sqlcommon.go (about)

     1  // Copyright © 2021 Kaleido, Inc.
     2  //
     3  // SPDX-License-Identifier: Apache-2.0
     4  //
     5  // Licensed under the Apache License, Version 2.0 (the "License");
     6  // you may not use this file except in compliance with the License.
     7  // You may obtain a copy of the License at
     8  //
     9  //     http://www.apache.org/licenses/LICENSE-2.0
    10  //
    11  // Unless required by applicable law or agreed to in writing, software
    12  // distributed under the License is distributed on an "AS IS" BASIS,
    13  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14  // See the License for the specific language governing permissions and
    15  // limitations under the License.
    16  
    17  package sqlcommon
    18  
    19  import (
    20  	"context"
    21  	"database/sql"
    22  
    23  	sq "github.com/Masterminds/squirrel"
    24  	"github.com/golang-migrate/migrate/v4"
    25  	"github.com/kaleido-io/firefly/internal/config"
    26  	"github.com/kaleido-io/firefly/internal/i18n"
    27  	"github.com/kaleido-io/firefly/internal/log"
    28  	"github.com/kaleido-io/firefly/pkg/database"
    29  	"github.com/kaleido-io/firefly/pkg/fftypes"
    30  
    31  	// Import migrate file source
    32  	_ "github.com/golang-migrate/migrate/v4/source/file"
    33  )
    34  
// SQLCommon holds the SQL handling shared across database providers:
// connection setup and optional migrations (Init), transaction management,
// and logged query/insert/update/delete helpers. The Provider supplies the
// database-specific pieces (placeholder format, sequence handling, driver).
type SQLCommon struct {
	db           *sql.DB                // handle opened by Init through the provider
	capabilities *database.Capabilities // returned as-is by Capabilities()
	callbacks    database.Callbacks     // callbacks supplied to Init
	provider     Provider               // database-specific SQL behaviour, validated in Init
}
    41  
// txContextKey is the context key under which the active *txWrapper is stored.
// Using an unexported type keeps the key from colliding with keys defined in
// other packages.
type txContextKey struct{}
    43  
// txWrapper pairs an open database transaction with the callbacks to run once
// that transaction has been committed successfully.
type txWrapper struct {
	// sqlTX is the underlying database/sql transaction.
	sqlTX      *sql.Tx
	// postCommit holds callbacks registered via postCommitEvent; commitTx runs
	// them in order after a successful commit.
	postCommit []func()
}
    48  
    49  func (s *SQLCommon) Init(ctx context.Context, provider Provider, prefix config.Prefix, callbacks database.Callbacks, capabilities *database.Capabilities) (err error) {
    50  	s.capabilities = capabilities
    51  	s.callbacks = callbacks
    52  	s.provider = provider
    53  	if s.provider == nil || s.provider.PlaceholderFormat() == nil || s.provider.SequenceField("") == "" {
    54  		log.L(ctx).Errorf("Invalid SQL options from provider '%T'", s.provider)
    55  		return i18n.NewError(ctx, i18n.MsgDBInitFailed)
    56  	}
    57  
    58  	if s.db, err = provider.Open(prefix.GetString(SQLConfDatasourceURL)); err != nil {
    59  		return i18n.WrapError(ctx, err, i18n.MsgDBInitFailed)
    60  	}
    61  
    62  	if prefix.GetBool(SQLConfMigrationsAuto) {
    63  		if err = s.applyDBMigrations(ctx, prefix, provider); err != nil {
    64  			return i18n.WrapError(ctx, err, i18n.MsgDBMigrationFailed)
    65  		}
    66  	}
    67  
    68  	return nil
    69  }
    70  
    71  func (s *SQLCommon) Capabilities() *database.Capabilities { return s.capabilities }
    72  
    73  func (s *SQLCommon) RunAsGroup(ctx context.Context, fn func(ctx context.Context) error) error {
    74  	ctx, tx, _, err := s.beginOrUseTx(ctx)
    75  	if err != nil {
    76  		return err
    77  	}
    78  	defer s.rollbackTx(ctx, tx, false /* we _are_ the auto-committer */)
    79  
    80  	if err = fn(ctx); err != nil {
    81  		return err
    82  	}
    83  
    84  	return s.commitTx(ctx, tx, false /* we _are_ the auto-committer */)
    85  }
    86  
    87  func (s *SQLCommon) applyDBMigrations(ctx context.Context, prefix config.Prefix, provider Provider) error {
    88  	driver, err := provider.GetMigrationDriver(s.db)
    89  	if err == nil {
    90  		var m *migrate.Migrate
    91  		m, err = migrate.NewWithDatabaseInstance(
    92  			"file://"+prefix.GetString(SQLConfMigrationsDirectory),
    93  			provider.Name(), driver)
    94  		if err == nil {
    95  			err = m.Up()
    96  		}
    97  	}
    98  	if err != nil && err != migrate.ErrNoChange {
    99  		return i18n.WrapError(ctx, err, i18n.MsgDBMigrationFailed)
   100  	}
   101  	return nil
   102  }
   103  
   104  func getTXFromContext(ctx context.Context) *txWrapper {
   105  	ctxKey := txContextKey{}
   106  	txi := ctx.Value(ctxKey)
   107  	if txi != nil {
   108  		if tx, ok := txi.(*txWrapper); ok {
   109  			return tx
   110  		}
   111  	}
   112  	return nil
   113  }
   114  
   115  func (s *SQLCommon) beginOrUseTx(ctx context.Context) (ctx1 context.Context, tx *txWrapper, autoCommit bool, err error) {
   116  
   117  	tx = getTXFromContext(ctx)
   118  	if tx != nil {
   119  		// There is s transaction on the context already.
   120  		// return existing with auto-commit flag, to prevent early commit
   121  		return ctx, tx, true, nil
   122  	}
   123  
   124  	l := log.L(ctx).WithField("dbtx", fftypes.ShortID())
   125  	ctx1 = log.WithLogger(ctx, l)
   126  	l.Debugf("SQL-> begin")
   127  	sqlTX, err := s.db.Begin()
   128  	if err != nil {
   129  		return ctx1, nil, false, i18n.WrapError(ctx1, err, i18n.MsgDBBeginFailed)
   130  	}
   131  	tx = &txWrapper{
   132  		sqlTX: sqlTX,
   133  	}
   134  	ctx1 = context.WithValue(ctx, txContextKey{}, tx)
   135  	l.Debugf("SQL<- begin")
   136  	return ctx1, tx, false, err
   137  }
   138  
   139  func (s *SQLCommon) queryTx(ctx context.Context, tx *txWrapper, q sq.SelectBuilder) (*sql.Rows, error) {
   140  	if tx == nil {
   141  		// If there is a transaction in the context, we should use it to provide consistency
   142  		// in the read operations (read after insert for example).
   143  		tx = getTXFromContext(ctx)
   144  	}
   145  
   146  	l := log.L(ctx)
   147  	sqlQuery, args, err := q.PlaceholderFormat(s.provider.PlaceholderFormat()).ToSql()
   148  	if err != nil {
   149  		return nil, i18n.WrapError(ctx, err, i18n.MsgDBQueryBuildFailed)
   150  	}
   151  	l.Debugf(`SQL-> query: %s`, sqlQuery)
   152  	l.Tracef(`SQL-> query args: %+v`, args)
   153  	var rows *sql.Rows
   154  	if tx != nil {
   155  		rows, err = tx.sqlTX.QueryContext(ctx, sqlQuery, args...)
   156  	} else {
   157  		rows, err = s.db.QueryContext(ctx, sqlQuery, args...)
   158  	}
   159  	if err != nil {
   160  		l.Errorf(`SQL query failed: %s sql=[ %s ]`, err, sqlQuery)
   161  		return nil, i18n.WrapError(ctx, err, i18n.MsgDBQueryFailed)
   162  	}
   163  	l.Debugf(`SQL<- query`)
   164  	return rows, nil
   165  }
   166  
   167  func (s *SQLCommon) query(ctx context.Context, q sq.SelectBuilder) (*sql.Rows, error) {
   168  	return s.queryTx(ctx, nil, q)
   169  }
   170  
   171  func (s *SQLCommon) insertTx(ctx context.Context, tx *txWrapper, q sq.InsertBuilder) (int64, error) {
   172  	l := log.L(ctx)
   173  	q, useQuery := s.provider.UpdateInsertForSequenceReturn(q)
   174  
   175  	sqlQuery, args, err := q.PlaceholderFormat(s.provider.PlaceholderFormat()).ToSql()
   176  	if err != nil {
   177  		return -1, i18n.WrapError(ctx, err, i18n.MsgDBQueryBuildFailed)
   178  	}
   179  	l.Debugf(`SQL-> insert: %s`, sqlQuery)
   180  	l.Tracef(`SQL-> insert args: %+v`, args)
   181  	var sequence int64
   182  	if useQuery {
   183  		err := tx.sqlTX.QueryRowContext(ctx, sqlQuery, args...).Scan(&sequence)
   184  		if err != nil {
   185  			l.Errorf(`SQL insert failed: %s sql=[ %s ]: %s`, err, sqlQuery, err)
   186  			return -1, i18n.WrapError(ctx, err, i18n.MsgDBInsertFailed)
   187  		}
   188  	} else {
   189  		res, err := tx.sqlTX.ExecContext(ctx, sqlQuery, args...)
   190  		if err != nil {
   191  			l.Errorf(`SQL insert failed: %s sql=[ %s ]: %s`, err, sqlQuery, err)
   192  			return -1, i18n.WrapError(ctx, err, i18n.MsgDBInsertFailed)
   193  		}
   194  		sequence, _ = res.LastInsertId()
   195  	}
   196  	l.Debugf(`SQL<- inserted sequence=%d`, sequence)
   197  	return sequence, nil
   198  }
   199  
   200  func (s *SQLCommon) deleteTx(ctx context.Context, tx *txWrapper, q sq.DeleteBuilder) error {
   201  	l := log.L(ctx)
   202  	sqlQuery, args, err := q.PlaceholderFormat(s.provider.PlaceholderFormat()).ToSql()
   203  	if err != nil {
   204  		return i18n.WrapError(ctx, err, i18n.MsgDBQueryBuildFailed)
   205  	}
   206  	l.Debugf(`SQL-> delete: %s`, sqlQuery)
   207  	l.Tracef(`SQL-> delete args: %+v`, args)
   208  	res, err := tx.sqlTX.ExecContext(ctx, sqlQuery, args...)
   209  	if err != nil {
   210  		l.Errorf(`SQL delete failed: %s sql=[ %s ]: %s`, err, sqlQuery, err)
   211  		return i18n.WrapError(ctx, err, i18n.MsgDBDeleteFailed)
   212  	}
   213  	ra, _ := res.RowsAffected()
   214  	l.Debugf(`SQL<- delete affected=%d`, ra)
   215  	if ra < 1 {
   216  		return database.DeleteRecordNotFound
   217  	}
   218  	return nil
   219  }
   220  
   221  func (s *SQLCommon) updateTx(ctx context.Context, tx *txWrapper, q sq.UpdateBuilder) error {
   222  	l := log.L(ctx)
   223  	sqlQuery, args, err := q.PlaceholderFormat(s.provider.PlaceholderFormat()).ToSql()
   224  	if err != nil {
   225  		return i18n.WrapError(ctx, err, i18n.MsgDBQueryBuildFailed)
   226  	}
   227  	l.Debugf(`SQL-> update: %s`, sqlQuery)
   228  	l.Tracef(`SQL-> update args: %+v`, args)
   229  	res, err := tx.sqlTX.ExecContext(ctx, sqlQuery, args...)
   230  	if err != nil {
   231  		l.Errorf(`SQL update failed: %s sql=[ %s ]`, err, sqlQuery)
   232  		return i18n.WrapError(ctx, err, i18n.MsgDBUpdateFailed)
   233  	}
   234  	ra, _ := res.RowsAffected() // currently only used for debugging
   235  	l.Debugf(`SQL<- update affected=%d`, ra)
   236  	return nil
   237  }
   238  
   239  func (s *SQLCommon) postCommitEvent(tx *txWrapper, fn func()) {
   240  	tx.postCommit = append(tx.postCommit, fn)
   241  }
   242  
   243  // rollbackTx be safely called as a defer, as it is a cheap no-op if the transaction is complete
   244  func (s *SQLCommon) rollbackTx(ctx context.Context, tx *txWrapper, autoCommit bool) {
   245  	if autoCommit {
   246  		// We're inside of a wide transaction boundary with an auto-commit
   247  		return
   248  	}
   249  
   250  	err := tx.sqlTX.Rollback()
   251  	if err == nil {
   252  		log.L(ctx).Warnf("SQL! transaction rollback")
   253  	}
   254  	if err != nil && err != sql.ErrTxDone {
   255  		log.L(ctx).Errorf(`SQL rollback failed: %s`, err)
   256  	}
   257  }
   258  
   259  func (s *SQLCommon) commitTx(ctx context.Context, tx *txWrapper, autoCommit bool) error {
   260  	if autoCommit {
   261  		// We're inside of a wide transaction boundary with an auto-commit
   262  		return nil
   263  	}
   264  
   265  	l := log.L(ctx)
   266  	l.Debugf(`SQL-> commit`)
   267  	err := tx.sqlTX.Commit()
   268  	if err != nil {
   269  		l.Errorf(`SQL commit failed: %s`, err)
   270  		return i18n.WrapError(ctx, err, i18n.MsgDBCommitFailed)
   271  	}
   272  	l.Debugf(`SQL<- commit`)
   273  
   274  	// Emit any post commit events (these aren't currently allowed to cause errors)
   275  	for i, pce := range tx.postCommit {
   276  		l.Tracef(`-> post commit event %d`, i)
   277  		pce()
   278  		l.Tracef(`<- post commit event %d`, i)
   279  	}
   280  
   281  	return nil
   282  }
   283  
   284  func (s *SQLCommon) DB() *sql.DB {
   285  	return s.db
   286  }
   287  
   288  func (s *SQLCommon) Close() {
   289  	if s.db != nil {
   290  		err := s.db.Close()
   291  		log.L(context.Background()).Debugf("Database closed (err=%v)", err)
   292  	}
   293  }