github.com/erda-project/erda-infra@v1.0.9/providers/mysql/v2/tx.go (about)

     1  // Copyright (c) 2021 Terminus, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //      http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package v2
    16  
    17  import (
    18  	"github.com/pkg/errors"
    19  	"gorm.io/gorm"
    20  )
    21  
    22  // ErrInvalidTransaction means the transaction is already committed or rolled back.
    23  var ErrInvalidTransaction = errors.New("invalid transaction, it is already committed or roll backed")
    24  
    25  // TX wraps a *gorm.DB and contains the CRUD APIs.
    26  type TX struct {
    27  	Error error // error recorded by write operations such as Create and CreateInBatches; consulted by Commit and CommitOrRollback
    28  
    29  	db    *gorm.DB // handle every operation runs on; returned unchanged by DB()
    30  	inTx  bool     // whether this TX is in a transaction; NewTx leaves it false — presumably set elsewhere in the package, TODO confirm
    31  	valid bool     // set to true by NewTx and to false once Commit, Rollback or CommitOrRollback has run
    32  }
    33  
    34  // NewTx returns a *TX
    35  func NewTx(db *gorm.DB) *TX {
    36  	return &TX{db: db, valid: true}
    37  }
    38  
    39  // Create inserts a row
    40  func (tx *TX) Create(i interface{}) error {
    41  	if tx.inTx && !tx.valid {
    42  		return ErrInvalidTransaction
    43  	}
    44  	tx.Error = tx.db.Create(i).Error
    45  	return tx.Error
    46  }
    47  
    48  // CreateInBatches inserts multi rows
    49  func (tx *TX) CreateInBatches(i interface{}, size int) error {
    50  	if tx.inTx && !tx.valid {
    51  		return ErrInvalidTransaction
    52  	}
    53  	tx.Error = tx.db.CreateInBatches(i, size).Error
    54  	return tx.Error
    55  }
    56  
    57  // Delete deletes rows with conditions
    58  func (tx *TX) Delete(i interface{}, options ...Option) (int64, error) {
    59  	if tx.inTx && !tx.valid {
    60  		return 0, ErrInvalidTransaction
    61  	}
    62  	var db = tx.DB()
    63  	for _, opt := range options {
    64  		db = opt(db)
    65  	}
    66  	db = db.Delete(i)
    67  	return db.RowsAffected, db.Error
    68  }
    69  
    70  // Updates updates the model i with the given value. v can be a map or a model struct.
    71  // options is conditions.
    72  func (tx *TX) Updates(i, v interface{}, options ...Option) error {
    73  	if tx.inTx && !tx.valid {
    74  		return ErrInvalidTransaction
    75  	}
    76  	var db = tx.DB()
    77  	for _, opt := range options {
    78  		db = opt(db)
    79  	}
    80  	return db.Model(i).Updates(v).Error
    81  }
    82  
    83  // SetColumns is used to set columns.
    84  // At least one SetColumn Option in the options.
    85  func (tx *TX) SetColumns(i interface{}, options ...Option) error {
    86  	if tx.inTx && !tx.valid {
    87  		return ErrInvalidTransaction
    88  	}
    89  	var db = tx.DB()
    90  	db = db.Model(i)
    91  	for _, opt := range options {
    92  		db = opt(db)
    93  	}
    94  	return db.Error
    95  }
    96  
    97  // List lists records.
    98  func (tx *TX) List(i interface{}, options ...Option) (int64, error) {
    99  	var total int64
   100  	var db = tx.DB()
   101  	for _, opt := range options {
   102  		db = opt(db)
   103  	}
   104  
   105  	err := db.Find(i).Count(&total).Error
   106  	if err == nil {
   107  		return total, nil
   108  	}
   109  	if errors.Is(err, gorm.ErrRecordNotFound) {
   110  		return 0, nil
   111  	}
   112  	return 0, err
   113  }
   114  
   115  // Get gets the first record.
   116  func (tx *TX) Get(i interface{}, options ...Option) (bool, error) {
   117  	var db = tx.DB()
   118  	for _, opt := range options {
   119  		db = opt(db)
   120  	}
   121  
   122  	err := db.First(i).Error
   123  	if err == nil {
   124  		return true, nil
   125  	}
   126  	if errors.Is(err, gorm.ErrRecordNotFound) {
   127  		return false, nil
   128  	}
   129  	return false, err
   130  }
   131  
   132  // Commit commits the transaction.
   133  func (tx *TX) Commit() error {
   134  	if !tx.inTx {
   135  		return errors.New("not in transaction")
   136  	}
   137  	if !tx.valid {
   138  		return ErrInvalidTransaction
   139  	}
   140  	if tx.Error != nil {
   141  		return errors.Wrap(tx.Error, "can not commit with error")
   142  	}
   143  	tx.db.Commit()
   144  	tx.valid = false
   145  	return nil
   146  }
   147  
   148  // Rollback rollbacks the transaction.
   149  func (tx *TX) Rollback() error {
   150  	if !tx.inTx {
   151  		return errors.New("not in transaction")
   152  	}
   153  	if !tx.valid {
   154  		return ErrInvalidTransaction
   155  	}
   156  	tx.db.Rollback()
   157  	tx.valid = false
   158  	return nil
   159  }
   160  
   161  // CommitOrRollback commits the transaction if db.Error is nil,
   162  // or rollbacks if the db.Error is not nil.
   163  func (tx *TX) CommitOrRollback() {
   164  	if tx.inTx && !tx.valid {
   165  		return
   166  	}
   167  	if tx.Error == nil {
   168  		tx.db.Commit()
   169  	} else {
   170  		tx.db.Rollback()
   171  	}
   172  	tx.valid = false
   173  }
   174  
   175  // DB returns the raw *gorm.DB this TX operates on, without copying or wrapping it.
   176  func (tx *TX) DB() *gorm.DB {
   177  	return tx.db
   178  }