github.com/oinume/lekcije@v0.0.0-20231017100347-5b4c5eb6ab24/backend/infrastructure/mysql/db.go (about)

     1  package mysql
     2  
     3  import (
     4  	"context"
     5  	"database/sql"
     6  
     7  	"github.com/oinume/lekcije/backend/domain/repository"
     8  )
     9  
    10  type dbRepository struct {
    11  	db *sql.DB
    12  }
    13  
    14  func NewDBRepository(db *sql.DB) repository.DB {
    15  	return &dbRepository{db: db}
    16  }
    17  
    18  func (r *dbRepository) Transaction(ctx context.Context, f func(exec repository.Executor) error) error {
    19  	return r.TransactionWithOptions(ctx, &sql.TxOptions{}, f)
    20  }
    21  
    22  func (r *dbRepository) TransactionWithOptions(ctx context.Context, opts *sql.TxOptions, f func(exec repository.Executor) error) error {
    23  	return transactionWithOptions(ctx, r.db, opts, f)
    24  }
    25  
    26  func transaction(ctx context.Context, db *sql.DB, f func(exec repository.Executor) error) error {
    27  	return transactionWithOptions(ctx, db, &sql.TxOptions{}, f)
    28  }
    29  
    30  func transactionWithOptions(
    31  	ctx context.Context,
    32  	db *sql.DB,
    33  	opts *sql.TxOptions,
    34  	f func(exec repository.Executor) error,
    35  ) error {
    36  	tx, err := db.BeginTx(ctx, opts)
    37  	if err != nil {
    38  		return err
    39  	}
    40  	if err := f(tx); err != nil {
    41  		if err := tx.Rollback(); err != nil {
    42  			return err
    43  		}
    44  		return err
    45  	}
    46  	return tx.Commit()
    47  }