github.com/NpoolPlatform/chain-middleware@v0.0.0-20240228100535-eb1bcf896eb9/pkg/mw/coin/fiat/currency/query.go (about)

     1  package currency
     2  
     3  import (
     4  	"context"
     5  	"fmt"
     6  
     7  	basetypes "github.com/NpoolPlatform/message/npool/basetypes/v1"
     8  	npool "github.com/NpoolPlatform/message/npool/chain/mw/v1/coin/fiat/currency"
     9  
    10  	"github.com/NpoolPlatform/chain-middleware/pkg/db"
    11  	"github.com/NpoolPlatform/chain-middleware/pkg/db/ent"
    12  
    13  	currencycrud "github.com/NpoolPlatform/chain-middleware/pkg/crud/coin/fiat/currency"
    14  	entcoinbase "github.com/NpoolPlatform/chain-middleware/pkg/db/ent/coinbase"
    15  	entcurrency "github.com/NpoolPlatform/chain-middleware/pkg/db/ent/coinfiatcurrency"
    16  	entfiat "github.com/NpoolPlatform/chain-middleware/pkg/db/ent/fiat"
    17  
    18  	"entgo.io/ent/dialect/sql"
    19  )
    20  
    21  type queryHandler struct {
    22  	*Handler
    23  	stm   *ent.CoinFiatCurrencySelect
    24  	infos []*npool.Currency
    25  	total uint32
    26  }
    27  
    28  func (h *queryHandler) selectCurrency(stm *ent.CoinFiatCurrencyQuery) {
    29  	h.stm = stm.Select(
    30  		entcurrency.FieldID,
    31  		entcurrency.FieldEntID,
    32  		entcurrency.FieldCoinTypeID,
    33  		entcurrency.FieldFiatID,
    34  		entcurrency.FieldFeedType,
    35  		entcurrency.FieldMarketValueHigh,
    36  		entcurrency.FieldMarketValueLow,
    37  		entcurrency.FieldCreatedAt,
    38  		entcurrency.FieldUpdatedAt,
    39  	)
    40  }
    41  
    42  func (h *queryHandler) queryCurrency(cli *ent.Client) error {
    43  	if h.ID == nil && h.EntID == nil {
    44  		return fmt.Errorf("invalid id")
    45  	}
    46  	stm := cli.CoinFiatCurrency.Query().Where(entcurrency.DeletedAt(0))
    47  	if h.ID != nil {
    48  		stm.Where(entcurrency.ID(*h.ID))
    49  	}
    50  	if h.EntID != nil {
    51  		stm.Where(entcurrency.EntID(*h.EntID))
    52  	}
    53  	h.selectCurrency(stm)
    54  	return nil
    55  }
    56  
    57  func (h *queryHandler) queryCurrencies(ctx context.Context, cli *ent.Client) error {
    58  	stm, err := currencycrud.SetQueryConds(cli.CoinFiatCurrency.Query(), h.Conds)
    59  	if err != nil {
    60  		return err
    61  	}
    62  
    63  	total, err := stm.Count(ctx)
    64  	if err != nil {
    65  		return err
    66  	}
    67  	h.total = uint32(total)
    68  
    69  	h.selectCurrency(stm)
    70  	return nil
    71  }
    72  
    73  func (h *queryHandler) queryJoinCoin(s *sql.Selector) {
    74  	t := sql.Table(entcoinbase.Table)
    75  	s.
    76  		LeftJoin(t).
    77  		On(
    78  			s.C(entcurrency.FieldCoinTypeID),
    79  			t.C(entcoinbase.FieldEntID),
    80  		).
    81  		AppendSelect(
    82  			sql.As(t.C(entcoinbase.FieldName), "coin_name"),
    83  			sql.As(t.C(entcoinbase.FieldLogo), "coin_logo"),
    84  			sql.As(t.C(entcoinbase.FieldUnit), "coin_unit"),
    85  			sql.As(t.C(entcoinbase.FieldEnv), "coin_env"),
    86  		)
    87  }
    88  
    89  func (h *queryHandler) queryJoinFiat(s *sql.Selector) {
    90  	t := sql.Table(entfiat.Table)
    91  	s.
    92  		LeftJoin(t).
    93  		On(
    94  			s.C(entcurrency.FieldFiatID),
    95  			t.C(entfiat.FieldEntID),
    96  		).
    97  		AppendSelect(
    98  			sql.As(t.C(entfiat.FieldName), "fiat_name"),
    99  			sql.As(t.C(entfiat.FieldLogo), "fiat_logo"),
   100  			sql.As(t.C(entfiat.FieldUnit), "fiat_unit"),
   101  		)
   102  }
   103  
   104  func (h *queryHandler) queryJoin() {
   105  	h.stm.Modify(func(s *sql.Selector) {
   106  		h.queryJoinCoin(s)
   107  		h.queryJoinFiat(s)
   108  	})
   109  }
   110  
   111  func (h *queryHandler) scan(ctx context.Context) error {
   112  	return h.stm.Scan(ctx, &h.infos)
   113  }
   114  
   115  func (h *queryHandler) formalize() {
   116  	for _, info := range h.infos {
   117  		info.FeedType = basetypes.CurrencyFeedType(basetypes.CurrencyFeedType_value[info.FeedTypeStr])
   118  	}
   119  }
   120  
   121  func (h *Handler) GetCurrency(ctx context.Context) (*npool.Currency, error) {
   122  	handler := &queryHandler{
   123  		Handler: h,
   124  	}
   125  
   126  	err := db.WithClient(ctx, func(_ctx context.Context, cli *ent.Client) error {
   127  		if err := handler.queryCurrency(cli); err != nil {
   128  			return err
   129  		}
   130  		handler.queryJoin()
   131  		const singleRowLimit = 2
   132  		handler.stm.Offset(0).Limit(singleRowLimit)
   133  		return handler.scan(_ctx)
   134  	})
   135  	if err != nil {
   136  		return nil, err
   137  	}
   138  	if len(handler.infos) == 0 {
   139  		return nil, nil
   140  	}
   141  	if len(handler.infos) > 1 {
   142  		return nil, fmt.Errorf("too many record")
   143  	}
   144  
   145  	handler.formalize()
   146  	return handler.infos[0], nil
   147  }
   148  
   149  func (h *Handler) GetCurrencies(ctx context.Context) ([]*npool.Currency, uint32, error) {
   150  	handler := &queryHandler{
   151  		Handler: h,
   152  	}
   153  
   154  	err := db.WithClient(ctx, func(_ctx context.Context, cli *ent.Client) error {
   155  		if err := handler.queryCurrencies(ctx, cli); err != nil {
   156  			return err
   157  		}
   158  		handler.queryJoin()
   159  		handler.stm.
   160  			Offset(int(h.Offset)).
   161  			Limit(int(h.Limit))
   162  		return handler.scan(_ctx)
   163  	})
   164  	if err != nil {
   165  		return nil, 0, err
   166  	}
   167  
   168  	handler.formalize()
   169  	return handler.infos, handler.total, nil
   170  }