github.com/NpoolPlatform/chain-middleware@v0.0.0-20240228100535-eb1bcf896eb9/pkg/db/db.go (about)

     1  package db
     2  
     3  import (
     4  	"context"
     5  	"fmt"
     6  
     7  	"github.com/NpoolPlatform/go-service-framework/pkg/logger"
     8  
     9  	"github.com/NpoolPlatform/chain-middleware/pkg/db/ent"
    10  
    11  	"entgo.io/ent/dialect"
    12  	entsql "entgo.io/ent/dialect/sql"
    13  	"github.com/NpoolPlatform/go-service-framework/pkg/mysql"
    14  
    15  	// ent policy runtime
    16  	_ "github.com/NpoolPlatform/chain-middleware/pkg/db/ent/runtime"
    17  )
    18  
    19  func client() (*ent.Client, error) {
    20  	conn, err := mysql.GetConn()
    21  	if err != nil {
    22  		return nil, err
    23  	}
    24  	drv := entsql.OpenDB(dialect.MySQL, conn)
    25  	return ent.NewClient(ent.Driver(drv)), nil
    26  }
    27  
    28  func Init() error {
    29  	cli, err := client()
    30  	if err != nil {
    31  		return err
    32  	}
    33  	return cli.Schema.Create(context.Background())
    34  }
    35  
    36  func Client() (*ent.Client, error) {
    37  	return client()
    38  }
    39  
    40  func WithTx(ctx context.Context, fn func(ctx context.Context, tx *ent.Tx) error) error {
    41  	cli, err := Client()
    42  	if err != nil {
    43  		return err
    44  	}
    45  
    46  	tx, err := cli.Tx(ctx)
    47  	if err != nil {
    48  		return fmt.Errorf("fail get client transaction: %v", err)
    49  	}
    50  
    51  	succ := false
    52  	defer func() {
    53  		if !succ {
    54  			err := tx.Rollback()
    55  			if err != nil {
    56  				logger.Sugar().Errorf("fail rollback: %v", err)
    57  				return
    58  			}
    59  		}
    60  	}()
    61  
    62  	if err := fn(ctx, tx); err != nil {
    63  		return err
    64  	}
    65  
    66  	if err := tx.Commit(); err != nil {
    67  		return fmt.Errorf("committing transaction: %v", err)
    68  	}
    69  
    70  	succ = true
    71  	return nil
    72  }
    73  
    74  func WithClient(ctx context.Context, fn func(ctx context.Context, cli *ent.Client) error) error {
    75  	cli, err := Client()
    76  	if err != nil {
    77  		return fmt.Errorf("fail get db client: %v", err)
    78  	}
    79  
    80  	if err := fn(ctx, cli); err != nil {
    81  		return err
    82  	}
    83  	return nil
    84  }