github.com/NpoolPlatform/chain-middleware@v0.0.0-20240228100535-eb1bcf896eb9/pkg/mw/coin/currency/query.go (about) 1 package currency 2 3 import ( 4 "context" 5 "fmt" 6 "time" 7 8 basetypes "github.com/NpoolPlatform/message/npool/basetypes/v1" 9 npool "github.com/NpoolPlatform/message/npool/chain/mw/v1/coin/currency" 10 11 "github.com/NpoolPlatform/chain-middleware/pkg/db" 12 "github.com/NpoolPlatform/chain-middleware/pkg/db/ent" 13 "github.com/NpoolPlatform/libent-cruder/pkg/cruder" 14 15 coincrud "github.com/NpoolPlatform/chain-middleware/pkg/crud/coin" 16 currencycrud "github.com/NpoolPlatform/chain-middleware/pkg/crud/coin/currency" 17 entcoinbase "github.com/NpoolPlatform/chain-middleware/pkg/db/ent/coinbase" 18 entcoinextra "github.com/NpoolPlatform/chain-middleware/pkg/db/ent/coinextra" 19 entcurrency "github.com/NpoolPlatform/chain-middleware/pkg/db/ent/currency" 20 21 "entgo.io/ent/dialect/sql" 22 "github.com/google/uuid" 23 "github.com/shopspring/decimal" 24 ) 25 26 type queryHandler struct { 27 *Handler 28 stmSelect *ent.CoinBaseSelect 29 stmCount *ent.CoinBaseSelect 30 infos []*npool.Currency 31 total uint32 32 } 33 34 func (h *queryHandler) selectCoinBase(stm *ent.CoinBaseQuery) *ent.CoinBaseSelect { 35 return stm.Select(entcoinbase.FieldCreatedAt) 36 } 37 38 func (h *queryHandler) queryCoinBase(ctx context.Context, cli *ent.Client) error { 39 _stm1, err := currencycrud.SetQueryConds(cli.Currency.Query(), ¤cycrud.Conds{ 40 EntID: &cruder.Cond{Op: cruder.EQ, Val: *h.EntID}, 41 }) 42 if err != nil { 43 return err 44 } 45 _info1, err := _stm1.Only(ctx) 46 if err != nil { 47 return err 48 } 49 50 _stm2, err := coincrud.SetQueryConds(cli.CoinBase.Query(), &coincrud.Conds{ 51 EntID: &cruder.Cond{Op: cruder.EQ, Val: _info1.CoinTypeID}, 52 }) 53 if err != nil { 54 return err 55 } 56 57 h.stmSelect = h.selectCoinBase(_stm2) 58 return nil 59 } 60 61 func (h *queryHandler) queryCoinBases(cli *ent.Client) (*ent.CoinBaseSelect, error) { 62 stm, err := coincrud.SetQueryConds(cli.CoinBase.Query(), &coincrud.Conds{ 63 EntID: h.Conds.CoinTypeID, 64 EntIDs: h.Conds.CoinTypeIDs, 65 }) 66 if err != nil { 67 return nil, err 68 } 69 70 return h.selectCoinBase(stm), nil 71 } 72 73 func (h *queryHandler) queryJoinMyself(s *sql.Selector) { 74 t := sql.Table(entcoinbase.Table) 75 s.AppendSelect( 76 sql.As(t.C(entcoinbase.FieldEntID), "coin_type_id"), 77 sql.As(t.C(entcoinbase.FieldName), "coin_name"), 78 sql.As(t.C(entcoinbase.FieldLogo), "coin_logo"), 79 sql.As(t.C(entcoinbase.FieldUnit), "coin_unit"), 80 sql.As(t.C(entcoinbase.FieldEnv), "coin_env"), 81 ) 82 } 83 84 func (h *queryHandler) queryJoinCoinExtra(s *sql.Selector) { 85 t := sql.Table(entcoinextra.Table) 86 s.LeftJoin(t). 87 On( 88 s.C(entcoinbase.FieldEntID), 89 t.C(entcoinextra.FieldCoinTypeID), 90 ). 91 AppendSelect( 92 sql.As(t.C(entcoinextra.FieldStableUsd), "stable_usd"), 93 ) 94 } 95 96 func (h *queryHandler) queryJoinCurrency(s *sql.Selector) error { 97 t := sql.Table(entcurrency.Table) 98 s.LeftJoin(t). 99 On( 100 s.C(entcoinbase.FieldEntID), 101 t.C(entcurrency.FieldCoinTypeID), 102 ). 103 OnP( 104 sql.EQ(t.C(entcurrency.FieldDeletedAt), 0), 105 ). 106 AppendSelect( 107 sql.As(t.C(entcurrency.FieldID), "id"), 108 sql.As(t.C(entcurrency.FieldEntID), "ent_id"), 109 sql.As(t.C(entcurrency.FieldFeedType), "feed_type"), 110 sql.As(t.C(entcurrency.FieldMarketValueHigh), "market_value_high"), 111 sql.As(t.C(entcurrency.FieldMarketValueLow), "market_value_low"), 112 sql.As(t.C(entcurrency.FieldCreatedAt), "created_at"), 113 sql.As(t.C(entcurrency.FieldUpdatedAt), "updated_at"), 114 ) 115 116 if h.Conds != nil && h.Conds.EntID != nil { 117 id, ok := h.Conds.EntID.Val.(uuid.UUID) 118 if !ok { 119 return fmt.Errorf("invalid entid") 120 } 121 switch h.Conds.EntID.Op { 122 case cruder.EQ: 123 s.Where( 124 sql.EQ(t.C(entcurrency.FieldEntID), id), 125 ) 126 default: 127 return fmt.Errorf("invalid currency field op") 128 } 129 } 130 if h.Conds != nil && h.Conds.FeedType != nil { 131 feedType, ok := h.Conds.FeedType.Val.(basetypes.CurrencyFeedType) 132 if !ok { 133 return fmt.Errorf("invalid feedtype") 134 } 135 switch h.Conds.FeedType.Op { 136 case cruder.EQ: 137 s.Where( 138 sql.EQ(t.C(entcurrency.FieldFeedType), feedType.String()), 139 ) 140 default: 141 return fmt.Errorf("invalid currency field op") 142 } 143 } 144 return nil 145 } 146 147 func (h *queryHandler) queryJoin() { 148 h.stmSelect.Modify(func(s *sql.Selector) { 149 h.queryJoinMyself(s) 150 h.queryJoinCoinExtra(s) 151 if err := h.queryJoinCurrency(s); err != nil { 152 return 153 } 154 }) 155 if h.stmCount == nil { 156 return 157 } 158 h.stmCount.Modify(func(s *sql.Selector) { 159 if err := h.queryJoinCurrency(s); err != nil { 160 return 161 } 162 }) 163 } 164 165 func (h *queryHandler) scan(ctx context.Context) error { 166 return h.stmSelect.Scan(ctx, &h.infos) 167 } 168 169 func (h *queryHandler) formalize() { 170 for _, info := range h.infos { 171 info.FeedType = basetypes.CurrencyFeedType(basetypes.CurrencyFeedType_value[info.FeedTypeStr]) 172 if info.StableUSD { 173 info.MarketValueHigh = decimal.NewFromInt(1).String() 174 info.MarketValueLow = decimal.NewFromInt(1).String() 175 info.CreatedAt = uint32(time.Now().Unix()) 176 info.UpdatedAt = uint32(time.Now().Unix()) 177 info.FeedType = basetypes.CurrencyFeedType_StableUSDHardCode 178 } 179 if _, err := decimal.NewFromString(info.MarketValueHigh); err != nil { 180 info.MarketValueHigh = decimal.NewFromInt(0).String() 181 } 182 if _, err := decimal.NewFromString(info.MarketValueLow); err != nil { 183 info.MarketValueLow = decimal.NewFromInt(0).String() 184 } 185 if _, err := uuid.Parse(info.EntID); err != nil { 186 info.EntID = uuid.Nil.String() 187 } 188 } 189 } 190 191 func (h *Handler) GetCurrency(ctx context.Context) (*npool.Currency, error) { 192 if h.EntID == nil { 193 return nil, fmt.Errorf("invalid entid") 194 } 195 196 handler := &queryHandler{ 197 Handler: h, 198 } 199 200 err := db.WithClient(ctx, func(_ctx context.Context, cli *ent.Client) error { 201 if err := handler.queryCoinBase(_ctx, cli); err != nil { 202 return err 203 } 204 handler.queryJoin() 205 const singleRowLimit = 1 206 handler.stmSelect. 207 Offset(0). 208 Limit(singleRowLimit) 209 return handler.scan(_ctx) 210 }) 211 if err != nil { 212 return nil, err 213 } 214 if len(handler.infos) == 0 { 215 return nil, nil 216 } 217 218 handler.formalize() 219 return handler.infos[0], nil 220 } 221 222 func (h *Handler) GetCurrencies(ctx context.Context) ([]*npool.Currency, uint32, error) { 223 handler := &queryHandler{ 224 Handler: h, 225 } 226 227 var err error 228 err = db.WithClient(ctx, func(_ctx context.Context, cli *ent.Client) error { 229 handler.stmSelect, err = handler.queryCoinBases(cli) 230 if err != nil { 231 return err 232 } 233 handler.stmCount, err = handler.queryCoinBases(cli) 234 if err != nil { 235 return err 236 } 237 238 handler.queryJoin() 239 _total, err := handler.stmCount.Count(_ctx) 240 if err != nil { 241 return err 242 } 243 handler.total = uint32(_total) 244 245 handler.stmSelect. 246 Offset(int(h.Offset)). 247 Limit(int(h.Limit)) 248 return handler.scan(_ctx) 249 }) 250 if err != nil { 251 return nil, 0, err 252 } 253 254 handler.formalize() 255 return handler.infos, handler.total, nil 256 }