github.com/kaydxh/golang@v0.0.131/pkg/database/mysql/mysql_transaction.go (about)

     1  /*
     2   *Copyright (c) 2022, kaydxh
     3   *
     4   *Permission is hereby granted, free of charge, to any person obtaining a copy
     5   *of this software and associated documentation files (the "Software"), to deal
     6   *in the Software without restriction, including without limitation the rights
     7   *to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
     8   *copies of the Software, and to permit persons to whom the Software is
     9   *furnished to do so, subject to the following conditions:
    10   *
    11   *The above copyright notice and this permission notice shall be included in all
    12   *copies or substantial portions of the Software.
    13   *
    14   *THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
    15   *IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
    16   *FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
    17   *AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
    18   *LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
    19   *OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
    20   *SOFTWARE.
    21   */
    22  package mysql
    23  
    24  import (
    25  	"context"
    26  	"database/sql"
    27  	"fmt"
    28  
    29  	"github.com/jmoiron/sqlx"
    30  	"github.com/sirupsen/logrus"
    31  )
    32  
    33  type TxDao struct {
    34  	*sqlx.Tx
    35  }
    36  
    37  func (d *TxDao) Begin(ctx context.Context, db *sqlx.DB, opts *sql.TxOptions) error {
    38  	if db == nil {
    39  		return fmt.Errorf("unexpected err: db is nil")
    40  	}
    41  
    42  	tx, err := db.BeginTxx(ctx, opts)
    43  	if err != nil {
    44  		return err
    45  	}
    46  
    47  	d.Tx = tx
    48  
    49  	return nil
    50  }
    51  
    52  func (d *TxDao) Commit() error {
    53  	if d.Tx == nil {
    54  		return fmt.Errorf("unexpected err: tx is nil")
    55  	}
    56  
    57  	err := d.Tx.Commit()
    58  	if err != nil {
    59  		return err
    60  	}
    61  	return nil
    62  }
    63  
    64  // Rollback ...
    65  func (d *TxDao) Rollback() error {
    66  	if d.Tx == nil {
    67  		return fmt.Errorf("unexpected err: tx is nil")
    68  	}
    69  
    70  	err := d.Tx.Rollback()
    71  	if err != nil {
    72  		return err
    73  	}
    74  	return nil
    75  }
    76  
    77  func TxPipelined(ctx context.Context, db *sqlx.DB, fn func(*sqlx.Tx) error) (err error) {
    78  	var tx TxDao
    79  	err = tx.Begin(ctx, db, nil)
    80  	if err != nil {
    81  		logrus.WithError(err).Errorf("failed to transaction begin")
    82  		return err
    83  	}
    84  
    85  	defer func() {
    86  		if err != nil {
    87  			if txErr := tx.Rollback(); txErr != nil {
    88  				logrus.WithError(err).Errorf("failed to rollback, err: %v", txErr)
    89  				return
    90  			}
    91  			return
    92  		}
    93  
    94  		if err = tx.Commit(); err != nil {
    95  			logrus.WithError(err).Errorf("failed to commit")
    96  			return
    97  		}
    98  
    99  	}()
   100  
   101  	return fn(tx.Tx)
   102  }