github.com/NpoolPlatform/chain-middleware@v0.0.0-20240228100535-eb1bcf896eb9/pkg/mw/chain/query.go (about) 1 package chain 2 3 import ( 4 "context" 5 "fmt" 6 7 "entgo.io/ent/dialect/sql" 8 chaincrud "github.com/NpoolPlatform/chain-middleware/pkg/crud/chain" 9 "github.com/NpoolPlatform/chain-middleware/pkg/db" 10 "github.com/NpoolPlatform/chain-middleware/pkg/db/ent" 11 basetypes "github.com/NpoolPlatform/message/npool/basetypes/v1" 12 npool "github.com/NpoolPlatform/message/npool/chain/mw/v1/chain" 13 14 entchain "github.com/NpoolPlatform/chain-middleware/pkg/db/ent/chainbase" 15 ) 16 17 type queryHandler struct { 18 *Handler 19 stm *ent.ChainBaseSelect 20 infos []*npool.Chain 21 total uint32 22 } 23 24 func (h *queryHandler) selectChainBase(stm *ent.ChainBaseQuery) { 25 h.stm = stm.Select( 26 entchain.FieldID, 27 entchain.FieldEntID, 28 entchain.FieldLogo, 29 entchain.FieldNativeUnit, 30 entchain.FieldAtomicUnit, 31 entchain.FieldUnitExp, 32 entchain.FieldEnv, 33 entchain.FieldChainID, 34 entchain.FieldNickname, 35 entchain.FieldGasType, 36 entchain.FieldCreatedAt, 37 entchain.FieldUpdatedAt, 38 ).Modify(func(s *sql.Selector) { 39 t := sql.Table(entchain.Table) 40 s.AppendSelect( 41 sql.As(t.C(entchain.FieldName), "chain_type"), 42 ) 43 }) 44 } 45 46 func (h *queryHandler) queryChainBase(cli *ent.Client) error { 47 if h.ID == nil && h.EntID == nil { 48 return fmt.Errorf("invalid id") 49 } 50 stm := cli.ChainBase.Query().Where(entchain.DeletedAt(0)) 51 if h.ID != nil { 52 stm.Where(entchain.ID(*h.ID)) 53 } 54 if h.EntID != nil { 55 stm.Where(entchain.EntID(*h.EntID)) 56 } 57 h.selectChainBase(stm) 58 return nil 59 } 60 61 func (h *queryHandler) queryChainBases(ctx context.Context, cli *ent.Client) error { 62 stm, err := chaincrud.SetQueryConds(cli.ChainBase.Query(), h.Conds) 63 if err != nil { 64 return err 65 } 66 total, err := stm.Count(ctx) 67 if err != nil { 68 return err 69 } 70 h.total = uint32(total) 71 h.selectChainBase(stm) 72 return nil 73 } 74 75 func (h *queryHandler) queryJoin() { 76 h.stm.Modify(func(s *sql.Selector) {}) 77 } 78 79 func (h *queryHandler) scan(ctx context.Context) error { 80 return h.stm.Scan(ctx, &h.infos) 81 } 82 83 func (h *queryHandler) formalize() { 84 for _, info := range h.infos { 85 info.GasType = basetypes.GasType(basetypes.GasType_value[info.GasTypeStr]) 86 } 87 } 88 89 func (h *Handler) GetChain(ctx context.Context) (*npool.Chain, error) { 90 handler := &queryHandler{ 91 Handler: h, 92 } 93 94 err := db.WithClient(ctx, func(_ctx context.Context, cli *ent.Client) error { 95 if err := handler.queryChainBase(cli); err != nil { 96 return err 97 } 98 handler.queryJoin() 99 const singleRowLimit = 2 100 handler.stm.Offset(0).Limit(singleRowLimit) 101 return handler.scan(_ctx) 102 }) 103 if err != nil { 104 return nil, err 105 } 106 if len(handler.infos) == 0 { 107 return nil, nil 108 } 109 if len(handler.infos) > 1 { 110 return nil, fmt.Errorf("too many record") 111 } 112 113 handler.formalize() 114 return handler.infos[0], nil 115 } 116 117 func (h *Handler) GetChains(ctx context.Context) ([]*npool.Chain, uint32, error) { 118 handler := &queryHandler{ 119 Handler: h, 120 } 121 122 err := db.WithClient(ctx, func(_ctx context.Context, cli *ent.Client) error { 123 if err := handler.queryChainBases(_ctx, cli); err != nil { 124 return err 125 } 126 handler.queryJoin() 127 handler.stm. 128 Offset(int(h.Offset)). 129 Limit(int(h.Limit)) 130 return handler.scan(_ctx) 131 }) 132 if err != nil { 133 return nil, 0, err 134 } 135 136 handler.formalize() 137 return handler.infos, handler.total, nil 138 }