github.com/NpoolPlatform/chain-middleware@v0.0.0-20240228100535-eb1bcf896eb9/pkg/mw/coin/currency/feed/query.go (about)

     1  package currencyfeed
     2  
     3  import (
     4  	"context"
     5  	"fmt"
     6  
     7  	basetypes "github.com/NpoolPlatform/message/npool/basetypes/v1"
     8  	npool "github.com/NpoolPlatform/message/npool/chain/mw/v1/coin/currency/feed"
     9  
    10  	"github.com/NpoolPlatform/chain-middleware/pkg/db"
    11  	"github.com/NpoolPlatform/chain-middleware/pkg/db/ent"
    12  
    13  	currencyfeedcrud "github.com/NpoolPlatform/chain-middleware/pkg/crud/coin/currency/feed"
    14  	entcoinbase "github.com/NpoolPlatform/chain-middleware/pkg/db/ent/coinbase"
    15  	entcurrencyfeed "github.com/NpoolPlatform/chain-middleware/pkg/db/ent/currencyfeed"
    16  
    17  	"entgo.io/ent/dialect/sql"
    18  )
    19  
// queryHandler holds the working state of one currency-feed read.
// It embeds the request Handler so the query helpers can read its fields
// (ID, EntID, Conds, Offset, Limit).
type queryHandler struct {
	*Handler
	// stm is the select statement built by selectFeed and extended by queryJoin.
	stm *ent.CurrencyFeedSelect
	// infos is the destination that scan fills with result rows.
	infos []*npool.Feed
	// total is the row count for the list conditions; only queryFeeds sets it.
	total uint32
}
    26  
    27  func (h *queryHandler) selectFeed(stm *ent.CurrencyFeedQuery) {
    28  	h.stm = stm.Select(
    29  		entcurrencyfeed.FieldID,
    30  		entcurrencyfeed.FieldEntID,
    31  		entcurrencyfeed.FieldCoinTypeID,
    32  		entcurrencyfeed.FieldFeedType,
    33  		entcurrencyfeed.FieldFeedCoinName,
    34  		entcurrencyfeed.FieldDisabled,
    35  		entcurrencyfeed.FieldCreatedAt,
    36  		entcurrencyfeed.FieldUpdatedAt,
    37  	)
    38  }
    39  
    40  func (h *queryHandler) queryFeed(cli *ent.Client) error {
    41  	if h.ID == nil && h.EntID == nil {
    42  		return fmt.Errorf("invalid id")
    43  	}
    44  	stm := cli.CurrencyFeed.Query().Where(entcurrencyfeed.DeletedAt(0))
    45  	if h.ID != nil {
    46  		stm.Where(entcurrencyfeed.ID(*h.ID))
    47  	}
    48  	if h.EntID != nil {
    49  		stm.Where(entcurrencyfeed.EntID(*h.EntID))
    50  	}
    51  	h.selectFeed(stm)
    52  	return nil
    53  }
    54  
    55  func (h *queryHandler) queryFeeds(ctx context.Context, cli *ent.Client) error {
    56  	stm, err := currencyfeedcrud.SetQueryConds(cli.CurrencyFeed.Query(), h.Conds)
    57  	if err != nil {
    58  		return err
    59  	}
    60  
    61  	total, err := stm.Count(ctx)
    62  	if err != nil {
    63  		return err
    64  	}
    65  	h.total = uint32(total)
    66  
    67  	h.selectFeed(stm)
    68  	return nil
    69  }
    70  
    71  func (h *queryHandler) queryJoinCoin(s *sql.Selector) {
    72  	t := sql.Table(entcoinbase.Table)
    73  	s.
    74  		LeftJoin(t).
    75  		On(
    76  			s.C(entcurrencyfeed.FieldCoinTypeID),
    77  			t.C(entcoinbase.FieldEntID),
    78  		).
    79  		AppendSelect(
    80  			sql.As(t.C(entcoinbase.FieldName), "coin_name"),
    81  			sql.As(t.C(entcoinbase.FieldLogo), "coin_logo"),
    82  			sql.As(t.C(entcoinbase.FieldUnit), "coin_unit"),
    83  			sql.As(t.C(entcoinbase.FieldEnv), "coin_env"),
    84  		)
    85  }
    86  
    87  func (h *queryHandler) queryJoin() {
    88  	h.stm.Modify(func(s *sql.Selector) {
    89  		h.queryJoinCoin(s)
    90  	})
    91  }
    92  
    93  func (h *queryHandler) scan(ctx context.Context) error {
    94  	return h.stm.Scan(ctx, &h.infos)
    95  }
    96  
    97  func (h *queryHandler) formalize() {
    98  	for _, info := range h.infos {
    99  		info.FeedType = basetypes.CurrencyFeedType(basetypes.CurrencyFeedType_value[info.FeedTypeStr])
   100  	}
   101  }
   102  
   103  func (h *Handler) GetFeed(ctx context.Context) (*npool.Feed, error) {
   104  	handler := &queryHandler{
   105  		Handler: h,
   106  	}
   107  
   108  	err := db.WithClient(ctx, func(_ctx context.Context, cli *ent.Client) error {
   109  		if err := handler.queryFeed(cli); err != nil {
   110  			return err
   111  		}
   112  		handler.queryJoin()
   113  		const singleRowLimit = 2
   114  		handler.stm.Offset(0).Limit(singleRowLimit)
   115  		return handler.scan(_ctx)
   116  	})
   117  	if err != nil {
   118  		return nil, err
   119  	}
   120  	if len(handler.infos) == 0 {
   121  		return nil, nil
   122  	}
   123  	if len(handler.infos) > 1 {
   124  		return nil, fmt.Errorf("too many record")
   125  	}
   126  
   127  	handler.formalize()
   128  	return handler.infos[0], nil
   129  }
   130  
   131  func (h *Handler) GetFeeds(ctx context.Context) ([]*npool.Feed, uint32, error) {
   132  	handler := &queryHandler{
   133  		Handler: h,
   134  	}
   135  
   136  	err := db.WithClient(ctx, func(_ctx context.Context, cli *ent.Client) error {
   137  		if err := handler.queryFeeds(ctx, cli); err != nil {
   138  			return err
   139  		}
   140  		handler.queryJoin()
   141  		handler.stm.
   142  			Offset(int(h.Offset)).
   143  			Limit(int(h.Limit))
   144  		return handler.scan(_ctx)
   145  	})
   146  	if err != nil {
   147  		return nil, 0, err
   148  	}
   149  
   150  	handler.formalize()
   151  	return handler.infos, handler.total, nil
   152  }