github.com/NpoolPlatform/chain-middleware@v0.0.0-20240228100535-eb1bcf896eb9/pkg/mw/tx/query.go (about) 1 package tx 2 3 import ( 4 "context" 5 "fmt" 6 7 basetypes "github.com/NpoolPlatform/message/npool/basetypes/v1" 8 npool "github.com/NpoolPlatform/message/npool/chain/mw/v1/tx" 9 10 "github.com/NpoolPlatform/chain-middleware/pkg/db" 11 "github.com/NpoolPlatform/chain-middleware/pkg/db/ent" 12 enttx "github.com/NpoolPlatform/chain-middleware/pkg/db/ent/tran" 13 14 txcrud "github.com/NpoolPlatform/chain-middleware/pkg/crud/tx" 15 entcoinbase "github.com/NpoolPlatform/chain-middleware/pkg/db/ent/coinbase" 16 17 "entgo.io/ent/dialect/sql" 18 "github.com/shopspring/decimal" 19 ) 20 21 type queryHandler struct { 22 *Handler 23 stm *ent.TranSelect 24 infos []*npool.Tx 25 total uint32 26 } 27 28 func (h *queryHandler) selectTx(stm *ent.TranQuery) { 29 h.stm = stm.Select( 30 enttx.FieldID, 31 enttx.FieldEntID, 32 enttx.FieldCoinTypeID, 33 enttx.FieldFromAccountID, 34 enttx.FieldToAccountID, 35 enttx.FieldAmount, 36 enttx.FieldFeeAmount, 37 enttx.FieldState, 38 enttx.FieldChainTxID, 39 enttx.FieldType, 40 enttx.FieldExtra, 41 enttx.FieldCreatedAt, 42 enttx.FieldUpdatedAt, 43 ) 44 } 45 46 func (h *queryHandler) queryTx(cli *ent.Client) error { 47 if h.ID == nil && h.EntID == nil { 48 return fmt.Errorf("invalid id") 49 } 50 stm := cli.Tran.Query().Where(enttx.DeletedAt(0)) 51 if h.ID != nil { 52 stm.Where(enttx.ID(*h.ID)) 53 } 54 if h.EntID != nil { 55 stm.Where(enttx.EntID(*h.EntID)) 56 } 57 h.selectTx(stm) 58 return nil 59 } 60 61 func (h *queryHandler) queryTxs(ctx context.Context, cli *ent.Client) error { 62 stm, err := txcrud.SetQueryConds(cli.Tran.Query(), h.Conds) 63 if err != nil { 64 return err 65 } 66 67 total, err := stm.Count(ctx) 68 if err != nil { 69 return err 70 } 71 h.total = uint32(total) 72 73 h.selectTx(stm) 74 return nil 75 } 76 77 func (h *queryHandler) queryJoinCoin(s *sql.Selector) { 78 t := sql.Table(entcoinbase.Table) 79 s. 80 LeftJoin(t). 81 On( 82 s.C(enttx.FieldCoinTypeID), 83 t.C(entcoinbase.FieldEntID), 84 ). 85 AppendSelect( 86 sql.As(t.C(entcoinbase.FieldName), "coin_name"), 87 sql.As(t.C(entcoinbase.FieldLogo), "coin_logo"), 88 sql.As(t.C(entcoinbase.FieldUnit), "coin_unit"), 89 sql.As(t.C(entcoinbase.FieldEnv), "coin_env"), 90 ) 91 } 92 93 func (h *queryHandler) queryJoin() { 94 h.stm.Modify(func(s *sql.Selector) { 95 h.queryJoinCoin(s) 96 }) 97 } 98 99 func (h *queryHandler) scan(ctx context.Context) error { 100 return h.stm.Scan(ctx, &h.infos) 101 } 102 103 func (h *queryHandler) formalize() { 104 for _, info := range h.infos { 105 info.Type = basetypes.TxType(basetypes.TxType_value[info.TypeStr]) 106 info.State = basetypes.TxState(basetypes.TxState_value[info.StateStr]) 107 amount, err := decimal.NewFromString(info.Amount) 108 if err != nil { 109 info.Amount = decimal.NewFromInt(0).String() 110 } else { 111 info.Amount = amount.String() 112 } 113 amount, err = decimal.NewFromString(info.FeeAmount) 114 if err != nil { 115 info.FeeAmount = decimal.NewFromInt(0).String() 116 } else { 117 info.FeeAmount = amount.String() 118 } 119 } 120 } 121 122 func (h *Handler) GetTx(ctx context.Context) (*npool.Tx, error) { 123 handler := &queryHandler{ 124 Handler: h, 125 } 126 127 err := db.WithClient(ctx, func(_ctx context.Context, cli *ent.Client) error { 128 if err := handler.queryTx(cli); err != nil { 129 return err 130 } 131 handler.queryJoin() 132 const singleRowLimit = 2 133 handler.stm.Offset(0).Limit(singleRowLimit) 134 return handler.scan(_ctx) 135 }) 136 if err != nil { 137 return nil, err 138 } 139 if len(handler.infos) == 0 { 140 return nil, nil 141 } 142 if len(handler.infos) > 1 { 143 return nil, fmt.Errorf("too many record") 144 } 145 146 handler.formalize() 147 return handler.infos[0], nil 148 } 149 150 func (h *Handler) GetTxs(ctx context.Context) ([]*npool.Tx, uint32, error) { 151 handler := &queryHandler{ 152 Handler: h, 153 } 154 155 err := db.WithClient(ctx, func(_ctx context.Context, cli *ent.Client) error { 156 if err := handler.queryTxs(ctx, cli); err != nil { 157 return err 158 } 159 handler.queryJoin() 160 handler.stm. 161 Offset(int(h.Offset)). 162 Limit(int(h.Limit)). 163 Order(ent.Desc(enttx.FieldUpdatedAt)) 164 return handler.scan(_ctx) 165 }) 166 if err != nil { 167 return nil, 0, err 168 } 169 170 handler.formalize() 171 return handler.infos, handler.total, nil 172 }