github.com/NpoolPlatform/chain-middleware@v0.0.0-20240228100535-eb1bcf896eb9/pkg/mw/coin/currency/history/query.go (about)

     1  package currencyhistory
     2  
     3  import (
     4  	"context"
     5  	"fmt"
     6  
     7  	basetypes "github.com/NpoolPlatform/message/npool/basetypes/v1"
     8  	npool "github.com/NpoolPlatform/message/npool/chain/mw/v1/coin/currency"
     9  
    10  	"github.com/NpoolPlatform/chain-middleware/pkg/db"
    11  	"github.com/NpoolPlatform/chain-middleware/pkg/db/ent"
    12  
    13  	historycrud "github.com/NpoolPlatform/chain-middleware/pkg/crud/coin/currency/history"
    14  	entcoinbase "github.com/NpoolPlatform/chain-middleware/pkg/db/ent/coinbase"
    15  	entcurrencyhis "github.com/NpoolPlatform/chain-middleware/pkg/db/ent/currencyhistory"
    16  
    17  	"entgo.io/ent/dialect/sql"
    18  )
    19  
    20  type queryHandler struct {
    21  	*Handler
    22  	stmSelect *ent.CurrencyHistorySelect
    23  	stmCount  *ent.CurrencyHistorySelect
    24  	infos     []*npool.Currency
    25  	total     uint32
    26  }
    27  
    28  func (h *queryHandler) selectCurrencyHistory(stm *ent.CurrencyHistoryQuery) *ent.CurrencyHistorySelect {
    29  	return stm.Select(entcurrencyhis.FieldID)
    30  }
    31  
    32  func (h *queryHandler) queryCurrencyHistories(cli *ent.Client) (*ent.CurrencyHistorySelect, error) {
    33  	stm, err := historycrud.SetQueryConds(cli.CurrencyHistory.Query(), h.Conds)
    34  	if err != nil {
    35  		return nil, err
    36  	}
    37  
    38  	return h.selectCurrencyHistory(stm), nil
    39  }
    40  
    41  func (h *queryHandler) queryJoinMyself(s *sql.Selector) {
    42  	t := sql.Table(entcurrencyhis.Table)
    43  	s.AppendSelect(
    44  		sql.As(t.C(entcurrencyhis.FieldEntID), "ent_id"),
    45  		sql.As(t.C(entcurrencyhis.FieldCoinTypeID), "coin_type_id"),
    46  		sql.As(t.C(entcurrencyhis.FieldFeedType), "feed_type"),
    47  		sql.As(t.C(entcurrencyhis.FieldMarketValueHigh), "market_value_high"),
    48  		sql.As(t.C(entcurrencyhis.FieldMarketValueLow), "market_value_low"),
    49  		sql.As(t.C(entcurrencyhis.FieldCreatedAt), "created_at"),
    50  		sql.As(t.C(entcurrencyhis.FieldUpdatedAt), "updated_at"),
    51  	)
    52  }
    53  
    54  func (h *queryHandler) queryJoinCoin(s *sql.Selector) error {
    55  	t := sql.Table(entcoinbase.Table)
    56  	s.LeftJoin(t).
    57  		On(
    58  			s.C(entcurrencyhis.FieldCoinTypeID),
    59  			t.C(entcoinbase.FieldEntID),
    60  		)
    61  
    62  	if h.Conds.CoinNames != nil {
    63  		names, ok := h.Conds.CoinNames.Val.([]string)
    64  		if !ok {
    65  			return fmt.Errorf("invalid coinnames")
    66  		}
    67  		_names := []interface{}{}
    68  		for _, _name := range names {
    69  			_names = append(_names, _name)
    70  		}
    71  		s.Where(
    72  			sql.In(t.C(entcoinbase.FieldName), _names...),
    73  		)
    74  	}
    75  
    76  	s.AppendSelect(
    77  		sql.As(t.C(entcoinbase.FieldName), "coin_name"),
    78  		sql.As(t.C(entcoinbase.FieldLogo), "coin_logo"),
    79  		sql.As(t.C(entcoinbase.FieldUnit), "coin_unit"),
    80  		sql.As(t.C(entcoinbase.FieldEnv), "coin_env"),
    81  	)
    82  	return nil
    83  }
    84  
    85  func (h *queryHandler) queryJoin() (err error) {
    86  	h.stmSelect.Modify(func(s *sql.Selector) {
    87  		h.queryJoinMyself(s)
    88  		if err = h.queryJoinCoin(s); err != nil {
    89  			return
    90  		}
    91  	})
    92  	h.stmCount.Modify(func(s *sql.Selector) {
    93  		if err = h.queryJoinCoin(s); err != nil {
    94  			return
    95  		}
    96  	})
    97  	return err
    98  }
    99  
   100  func (h *queryHandler) scan(ctx context.Context) error {
   101  	return h.stmSelect.Scan(ctx, &h.infos)
   102  }
   103  
   104  func (h *queryHandler) formalize() {
   105  	for _, info := range h.infos {
   106  		info.FeedType = basetypes.CurrencyFeedType(basetypes.CurrencyFeedType_value[info.FeedTypeStr])
   107  	}
   108  }
   109  
   110  func (h *Handler) GetCurrencies(ctx context.Context) ([]*npool.Currency, uint32, error) {
   111  	handler := &queryHandler{
   112  		Handler: h,
   113  	}
   114  
   115  	err := db.WithClient(ctx, func(_ctx context.Context, cli *ent.Client) error {
   116  		_stm, err := handler.queryCurrencyHistories(cli)
   117  		if err != nil {
   118  			return err
   119  		}
   120  		handler.stmSelect = _stm
   121  
   122  		_stm, err = handler.queryCurrencyHistories(cli)
   123  		if err != nil {
   124  			return err
   125  		}
   126  		handler.stmCount = _stm
   127  
   128  		if err := handler.queryJoin(); err != nil {
   129  			return err
   130  		}
   131  
   132  		_total, err := handler.stmCount.Count(ctx)
   133  		if err != nil {
   134  			return err
   135  		}
   136  		handler.total = uint32(_total)
   137  
   138  		handler.stmSelect.
   139  			Order(ent.Asc(entcurrencyhis.FieldCreatedAt)).
   140  			Offset(int(h.Offset)).
   141  			Limit(int(h.Limit))
   142  		return handler.scan(_ctx)
   143  	})
   144  	if err != nil {
   145  		return nil, 0, err
   146  	}
   147  
   148  	handler.formalize()
   149  	return handler.infos, handler.total, nil
   150  }