github.com/NpoolPlatform/chain-middleware@v0.0.0-20240228100535-eb1bcf896eb9/pkg/db/db.go (about) 1 package db 2 3 import ( 4 "context" 5 "fmt" 6 7 "github.com/NpoolPlatform/go-service-framework/pkg/logger" 8 9 "github.com/NpoolPlatform/chain-middleware/pkg/db/ent" 10 11 "entgo.io/ent/dialect" 12 entsql "entgo.io/ent/dialect/sql" 13 "github.com/NpoolPlatform/go-service-framework/pkg/mysql" 14 15 // ent policy runtime 16 _ "github.com/NpoolPlatform/chain-middleware/pkg/db/ent/runtime" 17 ) 18 19 func client() (*ent.Client, error) { 20 conn, err := mysql.GetConn() 21 if err != nil { 22 return nil, err 23 } 24 drv := entsql.OpenDB(dialect.MySQL, conn) 25 return ent.NewClient(ent.Driver(drv)), nil 26 } 27 28 func Init() error { 29 cli, err := client() 30 if err != nil { 31 return err 32 } 33 return cli.Schema.Create(context.Background()) 34 } 35 36 func Client() (*ent.Client, error) { 37 return client() 38 } 39 40 func WithTx(ctx context.Context, fn func(ctx context.Context, tx *ent.Tx) error) error { 41 cli, err := Client() 42 if err != nil { 43 return err 44 } 45 46 tx, err := cli.Tx(ctx) 47 if err != nil { 48 return fmt.Errorf("fail get client transaction: %v", err) 49 } 50 51 succ := false 52 defer func() { 53 if !succ { 54 err := tx.Rollback() 55 if err != nil { 56 logger.Sugar().Errorf("fail rollback: %v", err) 57 return 58 } 59 } 60 }() 61 62 if err := fn(ctx, tx); err != nil { 63 return err 64 } 65 66 if err := tx.Commit(); err != nil { 67 return fmt.Errorf("committing transaction: %v", err) 68 } 69 70 succ = true 71 return nil 72 } 73 74 func WithClient(ctx context.Context, fn func(ctx context.Context, cli *ent.Client) error) error { 75 cli, err := Client() 76 if err != nil { 77 return fmt.Errorf("fail get db client: %v", err) 78 } 79 80 if err := fn(ctx, cli); err != nil { 81 return err 82 } 83 return nil 84 }