github.com/gogf/gf@v1.16.9/database/gdb/gdb_core_transaction.go (about)

     1  // Copyright GoFrame Author(https://goframe.org). All Rights Reserved.
     2  //
     3  // This Source Code Form is subject to the terms of the MIT License.
     4  // If a copy of the MIT was not distributed with this file,
     5  // You can obtain one at https://github.com/gogf/gf.
     6  
     7  package gdb
     8  
     9  import (
    10  	"context"
    11  	"database/sql"
    12  	"fmt"
    13  	"reflect"
    14  
    15  	"github.com/gogf/gf/container/gtype"
    16  	"github.com/gogf/gf/os/gtime"
    17  	"github.com/gogf/gf/util/gconv"
    18  	"github.com/gogf/gf/util/guid"
    19  
    20  	"github.com/gogf/gf/text/gregex"
    21  )
    22  
// TX is the struct for transaction management.
//
// A TX wraps one underlying *sql.Tx. Nested transactions are emulated with
// savepoints: TX.Begin / TX.Transaction issue SAVEPOINT and increase
// transactionCount, while Commit / Rollback release or roll back to the most
// recent savepoint as long as transactionCount > 0, and only finish the real
// transaction once it is back at 0.
type TX struct {
	db               DB              // db is the gdb database manager that began this transaction.
	tx               *sql.Tx         // tx is the raw and underlying database/sql transaction.
	ctx              context.Context // ctx is the context used for statements on this transaction; replaced by Ctx, Transaction and TXFromCtx.
	master           *sql.DB         // master is the raw master database handle the transaction was begun on.
	transactionId    string          // transactionId is a unique id (guid) generated for this transaction.
	transactionCount int             // transactionCount is the current savepoint nesting depth (0 means not nested).
	isClosed         bool            // isClosed marks that the outermost transaction has been committed or rolled back.
}
    33  
const (
	// transactionPointerPrefix prefixes the savepoint names used for nested
	// transactions; the nesting depth is appended (see transactionKeyForNestedPoint).
	transactionPointerPrefix    = "transaction"
	// contextTransactionKeyPrefix prefixes the context key under which a
	// database group's *TX is stored; the group name is appended
	// (see transactionKeyForContext).
	contextTransactionKeyPrefix = "TransactionObjectForGroup_"
	// transactionIdForLoggerCtx is the context key under which doBeginCtx stores
	// the counter value taken from transactionIdGenerator.
	// NOTE(review): a plain string is used as a context.WithValue key, which can
	// collide with keys from other packages; an unexported key type would avoid
	// that, but every reader of this key would have to change with it.
	transactionIdForLoggerCtx   = "TransactionId"
)

var (
	// transactionIdGenerator is a package-level counter; doBeginCtx increments it
	// once per successful BEGIN and stores the new value in the transaction's context.
	transactionIdGenerator = gtype.NewUint64()
)
    43  
    44  // Begin starts and returns the transaction object.
    45  // You should call Commit or Rollback functions of the transaction object
    46  // if you no longer use the transaction. Commit or Rollback functions will also
    47  // close the transaction automatically.
    48  func (c *Core) Begin() (tx *TX, err error) {
    49  	return c.doBeginCtx(c.GetCtx())
    50  }
    51  
    52  func (c *Core) doBeginCtx(ctx context.Context) (*TX, error) {
    53  	if master, err := c.db.Master(); err != nil {
    54  		return nil, err
    55  	} else {
    56  		var (
    57  			tx         *TX
    58  			sqlStr     = "BEGIN"
    59  			mTime1     = gtime.TimestampMilli()
    60  			rawTx, err = master.Begin()
    61  			mTime2     = gtime.TimestampMilli()
    62  			sqlObj     = &Sql{
    63  				Sql:           sqlStr,
    64  				Type:          "DB.Begin",
    65  				Args:          nil,
    66  				Format:        sqlStr,
    67  				Error:         err,
    68  				Start:         mTime1,
    69  				End:           mTime2,
    70  				Group:         c.db.GetGroup(),
    71  				IsTransaction: true,
    72  			}
    73  		)
    74  		if err == nil {
    75  			tx = &TX{
    76  				db:            c.db,
    77  				tx:            rawTx,
    78  				ctx:           context.WithValue(ctx, transactionIdForLoggerCtx, transactionIdGenerator.Add(1)),
    79  				master:        master,
    80  				transactionId: guid.S(),
    81  			}
    82  			ctx = tx.ctx
    83  		}
    84  		// Tracing and logging.
    85  		c.addSqlToTracing(ctx, sqlObj)
    86  		if c.db.GetDebug() {
    87  			c.writeSqlToLogger(ctx, sqlObj)
    88  		}
    89  		return tx, err
    90  	}
    91  }
    92  
    93  // Transaction wraps the transaction logic using function `f`.
    94  // It rollbacks the transaction and returns the error from function `f` if
    95  // it returns non-nil error. It commits the transaction and returns nil if
    96  // function `f` returns nil.
    97  //
    98  // Note that, you should not Commit or Rollback the transaction in function `f`
    99  // as it is automatically handled by this function.
   100  func (c *Core) Transaction(ctx context.Context, f func(ctx context.Context, tx *TX) error) (err error) {
   101  	var (
   102  		tx *TX
   103  	)
   104  	if ctx == nil {
   105  		ctx = c.GetCtx()
   106  	}
   107  	// Check transaction object from context.
   108  	tx = TXFromCtx(ctx, c.db.GetGroup())
   109  	if tx != nil {
   110  		return tx.Transaction(ctx, f)
   111  	}
   112  	tx, err = c.doBeginCtx(ctx)
   113  	if err != nil {
   114  		return err
   115  	}
   116  	// Inject transaction object into context.
   117  	tx.ctx = WithTX(tx.ctx, tx)
   118  	defer func() {
   119  		if err == nil {
   120  			if e := recover(); e != nil {
   121  				err = fmt.Errorf("%v", e)
   122  			}
   123  		}
   124  		if err != nil {
   125  			if e := tx.Rollback(); e != nil {
   126  				err = e
   127  			}
   128  		} else {
   129  			if e := tx.Commit(); e != nil {
   130  				err = e
   131  			}
   132  		}
   133  	}()
   134  	err = f(tx.ctx, tx)
   135  	return
   136  }
   137  
   138  // WithTX injects given transaction object into context and returns a new context.
   139  func WithTX(ctx context.Context, tx *TX) context.Context {
   140  	if tx == nil {
   141  		return ctx
   142  	}
   143  	// Check repeat injection from given.
   144  	group := tx.db.GetGroup()
   145  	if tx := TXFromCtx(ctx, group); tx != nil && tx.db.GetGroup() == group {
   146  		return ctx
   147  	}
   148  	dbCtx := tx.db.GetCtx()
   149  	if tx := TXFromCtx(dbCtx, group); tx != nil && tx.db.GetGroup() == group {
   150  		return dbCtx
   151  	}
   152  	// Inject transaction object and id into context.
   153  	ctx = context.WithValue(ctx, transactionKeyForContext(group), tx)
   154  	return ctx
   155  }
   156  
   157  // TXFromCtx retrieves and returns transaction object from context.
   158  // It is usually used in nested transaction feature, and it returns nil if it is not set previously.
   159  func TXFromCtx(ctx context.Context, group string) *TX {
   160  	if ctx == nil {
   161  		return nil
   162  	}
   163  	v := ctx.Value(transactionKeyForContext(group))
   164  	if v != nil {
   165  		tx := v.(*TX)
   166  		if tx.IsClosed() {
   167  			return nil
   168  		}
   169  		tx.ctx = ctx
   170  		return tx
   171  	}
   172  	return nil
   173  }
   174  
   175  // transactionKeyForContext forms and returns a string for storing transaction object of certain database group into context.
   176  func transactionKeyForContext(group string) string {
   177  	return contextTransactionKeyPrefix + group
   178  }
   179  
   180  // transactionKeyForNestedPoint forms and returns the transaction key at current save point.
   181  func (tx *TX) transactionKeyForNestedPoint() string {
   182  	return tx.db.GetCore().QuoteWord(transactionPointerPrefix + gconv.String(tx.transactionCount))
   183  }
   184  
   185  // Ctx sets the context for current transaction.
   186  func (tx *TX) Ctx(ctx context.Context) *TX {
   187  	tx.ctx = ctx
   188  	return tx
   189  }
   190  
   191  // Commit commits current transaction.
   192  // Note that it releases previous saved transaction point if it's in a nested transaction procedure,
   193  // or else it commits the hole transaction.
   194  func (tx *TX) Commit() error {
   195  	if tx.transactionCount > 0 {
   196  		tx.transactionCount--
   197  		_, err := tx.Exec("RELEASE SAVEPOINT " + tx.transactionKeyForNestedPoint())
   198  		return err
   199  	}
   200  	var (
   201  		sqlStr = "COMMIT"
   202  		mTime1 = gtime.TimestampMilli()
   203  		err    = tx.tx.Commit()
   204  		mTime2 = gtime.TimestampMilli()
   205  		sqlObj = &Sql{
   206  			Sql:           sqlStr,
   207  			Type:          "TX.Commit",
   208  			Args:          nil,
   209  			Format:        sqlStr,
   210  			Error:         err,
   211  			Start:         mTime1,
   212  			End:           mTime2,
   213  			Group:         tx.db.GetGroup(),
   214  			IsTransaction: true,
   215  		}
   216  	)
   217  	tx.isClosed = true
   218  	tx.db.GetCore().addSqlToTracing(tx.ctx, sqlObj)
   219  	if tx.db.GetDebug() {
   220  		tx.db.GetCore().writeSqlToLogger(tx.ctx, sqlObj)
   221  	}
   222  	return err
   223  }
   224  
   225  // Rollback aborts current transaction.
   226  // Note that it aborts current transaction if it's in a nested transaction procedure,
   227  // or else it aborts the hole transaction.
   228  func (tx *TX) Rollback() error {
   229  	if tx.transactionCount > 0 {
   230  		tx.transactionCount--
   231  		_, err := tx.Exec("ROLLBACK TO SAVEPOINT " + tx.transactionKeyForNestedPoint())
   232  		return err
   233  	}
   234  	var (
   235  		sqlStr = "ROLLBACK"
   236  		mTime1 = gtime.TimestampMilli()
   237  		err    = tx.tx.Rollback()
   238  		mTime2 = gtime.TimestampMilli()
   239  		sqlObj = &Sql{
   240  			Sql:           sqlStr,
   241  			Type:          "TX.Rollback",
   242  			Args:          nil,
   243  			Format:        sqlStr,
   244  			Error:         err,
   245  			Start:         mTime1,
   246  			End:           mTime2,
   247  			Group:         tx.db.GetGroup(),
   248  			IsTransaction: true,
   249  		}
   250  	)
   251  	tx.isClosed = true
   252  	tx.db.GetCore().addSqlToTracing(tx.ctx, sqlObj)
   253  	if tx.db.GetDebug() {
   254  		tx.db.GetCore().writeSqlToLogger(tx.ctx, sqlObj)
   255  	}
   256  	return err
   257  }
   258  
   259  // IsClosed checks and returns this transaction has already been committed or rolled back.
   260  func (tx *TX) IsClosed() bool {
   261  	return tx.isClosed
   262  }
   263  
   264  // Begin starts a nested transaction procedure.
   265  func (tx *TX) Begin() error {
   266  	_, err := tx.Exec("SAVEPOINT " + tx.transactionKeyForNestedPoint())
   267  	if err != nil {
   268  		return err
   269  	}
   270  	tx.transactionCount++
   271  	return nil
   272  }
   273  
   274  // SavePoint performs `SAVEPOINT xxx` SQL statement that saves transaction at current point.
   275  // The parameter `point` specifies the point name that will be saved to server.
   276  func (tx *TX) SavePoint(point string) error {
   277  	_, err := tx.Exec("SAVEPOINT " + tx.db.GetCore().QuoteWord(point))
   278  	return err
   279  }
   280  
   281  // RollbackTo performs `ROLLBACK TO SAVEPOINT xxx` SQL statement that rollbacks to specified saved transaction.
   282  // The parameter `point` specifies the point name that was saved previously.
   283  func (tx *TX) RollbackTo(point string) error {
   284  	_, err := tx.Exec("ROLLBACK TO SAVEPOINT " + tx.db.GetCore().QuoteWord(point))
   285  	return err
   286  }
   287  
   288  // Transaction wraps the transaction logic using function `f`.
   289  // It rollbacks the transaction and returns the error from function `f` if
   290  // it returns non-nil error. It commits the transaction and returns nil if
   291  // function `f` returns nil.
   292  //
   293  // Note that, you should not Commit or Rollback the transaction in function `f`
   294  // as it is automatically handled by this function.
   295  func (tx *TX) Transaction(ctx context.Context, f func(ctx context.Context, tx *TX) error) (err error) {
   296  	if ctx != nil {
   297  		tx.ctx = ctx
   298  	}
   299  	// Check transaction object from context.
   300  	if TXFromCtx(tx.ctx, tx.db.GetGroup()) == nil {
   301  		// Inject transaction object into context.
   302  		tx.ctx = WithTX(tx.ctx, tx)
   303  	}
   304  	err = tx.Begin()
   305  	if err != nil {
   306  		return err
   307  	}
   308  	defer func() {
   309  		if err == nil {
   310  			if e := recover(); e != nil {
   311  				err = fmt.Errorf("%v", e)
   312  			}
   313  		}
   314  		if err != nil {
   315  			if e := tx.Rollback(); e != nil {
   316  				err = e
   317  			}
   318  		} else {
   319  			if e := tx.Commit(); e != nil {
   320  				err = e
   321  			}
   322  		}
   323  	}()
   324  	err = f(tx.ctx, tx)
   325  	return
   326  }
   327  
   328  // Query does query operation on transaction.
   329  // See Core.Query.
   330  func (tx *TX) Query(sql string, args ...interface{}) (rows *sql.Rows, err error) {
   331  	return tx.db.DoQuery(tx.ctx, &txLink{tx.tx}, sql, args...)
   332  }
   333  
   334  // Exec does none query operation on transaction.
   335  // See Core.Exec.
   336  func (tx *TX) Exec(sql string, args ...interface{}) (sql.Result, error) {
   337  	return tx.db.DoExec(tx.ctx, &txLink{tx.tx}, sql, args...)
   338  }
   339  
   340  // Prepare creates a prepared statement for later queries or executions.
   341  // Multiple queries or executions may be run concurrently from the
   342  // returned statement.
   343  // The caller must call the statement's Close method
   344  // when the statement is no longer needed.
   345  func (tx *TX) Prepare(sql string) (*Stmt, error) {
   346  	return tx.db.DoPrepare(tx.ctx, &txLink{tx.tx}, sql)
   347  }
   348  
   349  // GetAll queries and returns data records from database.
   350  func (tx *TX) GetAll(sql string, args ...interface{}) (Result, error) {
   351  	rows, err := tx.Query(sql, args...)
   352  	if err != nil || rows == nil {
   353  		return nil, err
   354  	}
   355  	defer rows.Close()
   356  	return tx.db.GetCore().convertRowsToResult(rows)
   357  }
   358  
   359  // GetOne queries and returns one record from database.
   360  func (tx *TX) GetOne(sql string, args ...interface{}) (Record, error) {
   361  	list, err := tx.GetAll(sql, args...)
   362  	if err != nil {
   363  		return nil, err
   364  	}
   365  	if len(list) > 0 {
   366  		return list[0], nil
   367  	}
   368  	return nil, nil
   369  }
   370  
   371  // GetStruct queries one record from database and converts it to given struct.
   372  // The parameter `pointer` should be a pointer to struct.
   373  func (tx *TX) GetStruct(obj interface{}, sql string, args ...interface{}) error {
   374  	one, err := tx.GetOne(sql, args...)
   375  	if err != nil {
   376  		return err
   377  	}
   378  	return one.Struct(obj)
   379  }
   380  
   381  // GetStructs queries records from database and converts them to given struct.
   382  // The parameter `pointer` should be type of struct slice: []struct/[]*struct.
   383  func (tx *TX) GetStructs(objPointerSlice interface{}, sql string, args ...interface{}) error {
   384  	all, err := tx.GetAll(sql, args...)
   385  	if err != nil {
   386  		return err
   387  	}
   388  	return all.Structs(objPointerSlice)
   389  }
   390  
   391  // GetScan queries one or more records from database and converts them to given struct or
   392  // struct array.
   393  //
   394  // If parameter `pointer` is type of struct pointer, it calls GetStruct internally for
   395  // the conversion. If parameter `pointer` is type of slice, it calls GetStructs internally
   396  // for conversion.
   397  func (tx *TX) GetScan(pointer interface{}, sql string, args ...interface{}) error {
   398  	t := reflect.TypeOf(pointer)
   399  	k := t.Kind()
   400  	if k != reflect.Ptr {
   401  		return fmt.Errorf("params should be type of pointer, but got: %v", k)
   402  	}
   403  	k = t.Elem().Kind()
   404  	switch k {
   405  	case reflect.Array, reflect.Slice:
   406  		return tx.GetStructs(pointer, sql, args...)
   407  	case reflect.Struct:
   408  		return tx.GetStruct(pointer, sql, args...)
   409  	default:
   410  		return fmt.Errorf("element type should be type of struct/slice, unsupported: %v", k)
   411  	}
   412  }
   413  
   414  // GetValue queries and returns the field value from database.
   415  // The sql should queries only one field from database, or else it returns only one
   416  // field of the result.
   417  func (tx *TX) GetValue(sql string, args ...interface{}) (Value, error) {
   418  	one, err := tx.GetOne(sql, args...)
   419  	if err != nil {
   420  		return nil, err
   421  	}
   422  	for _, v := range one {
   423  		return v, nil
   424  	}
   425  	return nil, nil
   426  }
   427  
   428  // GetCount queries and returns the count from database.
   429  func (tx *TX) GetCount(sql string, args ...interface{}) (int, error) {
   430  	if !gregex.IsMatchString(`(?i)SELECT\s+COUNT\(.+\)\s+FROM`, sql) {
   431  		sql, _ = gregex.ReplaceString(`(?i)(SELECT)\s+(.+)\s+(FROM)`, `$1 COUNT($2) $3`, sql)
   432  	}
   433  	value, err := tx.GetValue(sql, args...)
   434  	if err != nil {
   435  		return 0, err
   436  	}
   437  	return value.Int(), nil
   438  }
   439  
   440  // Insert does "INSERT INTO ..." statement for the table.
   441  // If there's already one unique record of the data in the table, it returns error.
   442  //
   443  // The parameter `data` can be type of map/gmap/struct/*struct/[]map/[]struct, etc.
   444  // Eg:
   445  // Data(g.Map{"uid": 10000, "name":"john"})
   446  // Data(g.Slice{g.Map{"uid": 10000, "name":"john"}, g.Map{"uid": 20000, "name":"smith"})
   447  //
   448  // The parameter `batch` specifies the batch operation count when given data is slice.
   449  func (tx *TX) Insert(table string, data interface{}, batch ...int) (sql.Result, error) {
   450  	if len(batch) > 0 {
   451  		return tx.Model(table).Ctx(tx.ctx).Data(data).Batch(batch[0]).Insert()
   452  	}
   453  	return tx.Model(table).Ctx(tx.ctx).Data(data).Insert()
   454  }
   455  
   456  // InsertIgnore does "INSERT IGNORE INTO ..." statement for the table.
   457  // If there's already one unique record of the data in the table, it ignores the inserting.
   458  //
   459  // The parameter `data` can be type of map/gmap/struct/*struct/[]map/[]struct, etc.
   460  // Eg:
   461  // Data(g.Map{"uid": 10000, "name":"john"})
   462  // Data(g.Slice{g.Map{"uid": 10000, "name":"john"}, g.Map{"uid": 20000, "name":"smith"})
   463  //
   464  // The parameter `batch` specifies the batch operation count when given data is slice.
   465  func (tx *TX) InsertIgnore(table string, data interface{}, batch ...int) (sql.Result, error) {
   466  	if len(batch) > 0 {
   467  		return tx.Model(table).Ctx(tx.ctx).Data(data).Batch(batch[0]).InsertIgnore()
   468  	}
   469  	return tx.Model(table).Ctx(tx.ctx).Data(data).InsertIgnore()
   470  }
   471  
   472  // InsertAndGetId performs action Insert and returns the last insert id that automatically generated.
   473  func (tx *TX) InsertAndGetId(table string, data interface{}, batch ...int) (int64, error) {
   474  	if len(batch) > 0 {
   475  		return tx.Model(table).Ctx(tx.ctx).Data(data).Batch(batch[0]).InsertAndGetId()
   476  	}
   477  	return tx.Model(table).Ctx(tx.ctx).Data(data).InsertAndGetId()
   478  }
   479  
   480  // Replace does "REPLACE INTO ..." statement for the table.
   481  // If there's already one unique record of the data in the table, it deletes the record
   482  // and inserts a new one.
   483  //
   484  // The parameter `data` can be type of map/gmap/struct/*struct/[]map/[]struct, etc.
   485  // Eg:
   486  // Data(g.Map{"uid": 10000, "name":"john"})
   487  // Data(g.Slice{g.Map{"uid": 10000, "name":"john"}, g.Map{"uid": 20000, "name":"smith"})
   488  //
   489  // The parameter `data` can be type of map/gmap/struct/*struct/[]map/[]struct, etc.
   490  // If given data is type of slice, it then does batch replacing, and the optional parameter
   491  // `batch` specifies the batch operation count.
   492  func (tx *TX) Replace(table string, data interface{}, batch ...int) (sql.Result, error) {
   493  	if len(batch) > 0 {
   494  		return tx.Model(table).Ctx(tx.ctx).Data(data).Batch(batch[0]).Replace()
   495  	}
   496  	return tx.Model(table).Ctx(tx.ctx).Data(data).Replace()
   497  }
   498  
   499  // Save does "INSERT INTO ... ON DUPLICATE KEY UPDATE..." statement for the table.
   500  // It updates the record if there's primary or unique index in the saving data,
   501  // or else it inserts a new record into the table.
   502  //
   503  // The parameter `data` can be type of map/gmap/struct/*struct/[]map/[]struct, etc.
   504  // Eg:
   505  // Data(g.Map{"uid": 10000, "name":"john"})
   506  // Data(g.Slice{g.Map{"uid": 10000, "name":"john"}, g.Map{"uid": 20000, "name":"smith"})
   507  //
   508  // If given data is type of slice, it then does batch saving, and the optional parameter
   509  // `batch` specifies the batch operation count.
   510  func (tx *TX) Save(table string, data interface{}, batch ...int) (sql.Result, error) {
   511  	if len(batch) > 0 {
   512  		return tx.Model(table).Ctx(tx.ctx).Data(data).Batch(batch[0]).Save()
   513  	}
   514  	return tx.Model(table).Ctx(tx.ctx).Data(data).Save()
   515  }
   516  
   517  // Update does "UPDATE ... " statement for the table.
   518  //
   519  // The parameter `data` can be type of string/map/gmap/struct/*struct, etc.
   520  // Eg: "uid=10000", "uid", 10000, g.Map{"uid": 10000, "name":"john"}
   521  //
   522  // The parameter `condition` can be type of string/map/gmap/slice/struct/*struct, etc.
   523  // It is commonly used with parameter `args`.
   524  // Eg:
   525  // "uid=10000",
   526  // "uid", 10000
   527  // "money>? AND name like ?", 99999, "vip_%"
   528  // "status IN (?)", g.Slice{1,2,3}
   529  // "age IN(?,?)", 18, 50
   530  // User{ Id : 1, UserName : "john"}
   531  func (tx *TX) Update(table string, data interface{}, condition interface{}, args ...interface{}) (sql.Result, error) {
   532  	return tx.Model(table).Ctx(tx.ctx).Data(data).Where(condition, args...).Update()
   533  }
   534  
   535  // Delete does "DELETE FROM ... " statement for the table.
   536  //
   537  // The parameter `condition` can be type of string/map/gmap/slice/struct/*struct, etc.
   538  // It is commonly used with parameter `args`.
   539  // Eg:
   540  // "uid=10000",
   541  // "uid", 10000
   542  // "money>? AND name like ?", 99999, "vip_%"
   543  // "status IN (?)", g.Slice{1,2,3}
   544  // "age IN(?,?)", 18, 50
   545  // User{ Id : 1, UserName : "john"}
   546  func (tx *TX) Delete(table string, condition interface{}, args ...interface{}) (sql.Result, error) {
   547  	return tx.Model(table).Ctx(tx.ctx).Where(condition, args...).Delete()
   548  }