github.com/NpoolPlatform/chain-middleware@v0.0.0-20240228100535-eb1bcf896eb9/pkg/mw/coin/currency/query.go (about)

     1  package currency
     2  
     3  import (
     4  	"context"
     5  	"fmt"
     6  	"time"
     7  
     8  	basetypes "github.com/NpoolPlatform/message/npool/basetypes/v1"
     9  	npool "github.com/NpoolPlatform/message/npool/chain/mw/v1/coin/currency"
    10  
    11  	"github.com/NpoolPlatform/chain-middleware/pkg/db"
    12  	"github.com/NpoolPlatform/chain-middleware/pkg/db/ent"
    13  	"github.com/NpoolPlatform/libent-cruder/pkg/cruder"
    14  
    15  	coincrud "github.com/NpoolPlatform/chain-middleware/pkg/crud/coin"
    16  	currencycrud "github.com/NpoolPlatform/chain-middleware/pkg/crud/coin/currency"
    17  	entcoinbase "github.com/NpoolPlatform/chain-middleware/pkg/db/ent/coinbase"
    18  	entcoinextra "github.com/NpoolPlatform/chain-middleware/pkg/db/ent/coinextra"
    19  	entcurrency "github.com/NpoolPlatform/chain-middleware/pkg/db/ent/currency"
    20  
    21  	"entgo.io/ent/dialect/sql"
    22  	"github.com/google/uuid"
    23  	"github.com/shopspring/decimal"
    24  )
    25  
    26  type queryHandler struct {
    27  	*Handler
    28  	stmSelect *ent.CoinBaseSelect
    29  	stmCount  *ent.CoinBaseSelect
    30  	infos     []*npool.Currency
    31  	total     uint32
    32  }
    33  
    34  func (h *queryHandler) selectCoinBase(stm *ent.CoinBaseQuery) *ent.CoinBaseSelect {
    35  	return stm.Select(entcoinbase.FieldCreatedAt)
    36  }
    37  
    38  func (h *queryHandler) queryCoinBase(ctx context.Context, cli *ent.Client) error {
    39  	_stm1, err := currencycrud.SetQueryConds(cli.Currency.Query(), &currencycrud.Conds{
    40  		EntID: &cruder.Cond{Op: cruder.EQ, Val: *h.EntID},
    41  	})
    42  	if err != nil {
    43  		return err
    44  	}
    45  	_info1, err := _stm1.Only(ctx)
    46  	if err != nil {
    47  		return err
    48  	}
    49  
    50  	_stm2, err := coincrud.SetQueryConds(cli.CoinBase.Query(), &coincrud.Conds{
    51  		EntID: &cruder.Cond{Op: cruder.EQ, Val: _info1.CoinTypeID},
    52  	})
    53  	if err != nil {
    54  		return err
    55  	}
    56  
    57  	h.stmSelect = h.selectCoinBase(_stm2)
    58  	return nil
    59  }
    60  
    61  func (h *queryHandler) queryCoinBases(cli *ent.Client) (*ent.CoinBaseSelect, error) {
    62  	stm, err := coincrud.SetQueryConds(cli.CoinBase.Query(), &coincrud.Conds{
    63  		EntID:  h.Conds.CoinTypeID,
    64  		EntIDs: h.Conds.CoinTypeIDs,
    65  	})
    66  	if err != nil {
    67  		return nil, err
    68  	}
    69  
    70  	return h.selectCoinBase(stm), nil
    71  }
    72  
    73  func (h *queryHandler) queryJoinMyself(s *sql.Selector) {
    74  	t := sql.Table(entcoinbase.Table)
    75  	s.AppendSelect(
    76  		sql.As(t.C(entcoinbase.FieldEntID), "coin_type_id"),
    77  		sql.As(t.C(entcoinbase.FieldName), "coin_name"),
    78  		sql.As(t.C(entcoinbase.FieldLogo), "coin_logo"),
    79  		sql.As(t.C(entcoinbase.FieldUnit), "coin_unit"),
    80  		sql.As(t.C(entcoinbase.FieldEnv), "coin_env"),
    81  	)
    82  }
    83  
    84  func (h *queryHandler) queryJoinCoinExtra(s *sql.Selector) {
    85  	t := sql.Table(entcoinextra.Table)
    86  	s.LeftJoin(t).
    87  		On(
    88  			s.C(entcoinbase.FieldEntID),
    89  			t.C(entcoinextra.FieldCoinTypeID),
    90  		).
    91  		AppendSelect(
    92  			sql.As(t.C(entcoinextra.FieldStableUsd), "stable_usd"),
    93  		)
    94  }
    95  
    96  func (h *queryHandler) queryJoinCurrency(s *sql.Selector) error {
    97  	t := sql.Table(entcurrency.Table)
    98  	s.LeftJoin(t).
    99  		On(
   100  			s.C(entcoinbase.FieldEntID),
   101  			t.C(entcurrency.FieldCoinTypeID),
   102  		).
   103  		OnP(
   104  			sql.EQ(t.C(entcurrency.FieldDeletedAt), 0),
   105  		).
   106  		AppendSelect(
   107  			sql.As(t.C(entcurrency.FieldID), "id"),
   108  			sql.As(t.C(entcurrency.FieldEntID), "ent_id"),
   109  			sql.As(t.C(entcurrency.FieldFeedType), "feed_type"),
   110  			sql.As(t.C(entcurrency.FieldMarketValueHigh), "market_value_high"),
   111  			sql.As(t.C(entcurrency.FieldMarketValueLow), "market_value_low"),
   112  			sql.As(t.C(entcurrency.FieldCreatedAt), "created_at"),
   113  			sql.As(t.C(entcurrency.FieldUpdatedAt), "updated_at"),
   114  		)
   115  
   116  	if h.Conds != nil && h.Conds.EntID != nil {
   117  		id, ok := h.Conds.EntID.Val.(uuid.UUID)
   118  		if !ok {
   119  			return fmt.Errorf("invalid entid")
   120  		}
   121  		switch h.Conds.EntID.Op {
   122  		case cruder.EQ:
   123  			s.Where(
   124  				sql.EQ(t.C(entcurrency.FieldEntID), id),
   125  			)
   126  		default:
   127  			return fmt.Errorf("invalid currency field op")
   128  		}
   129  	}
   130  	if h.Conds != nil && h.Conds.FeedType != nil {
   131  		feedType, ok := h.Conds.FeedType.Val.(basetypes.CurrencyFeedType)
   132  		if !ok {
   133  			return fmt.Errorf("invalid feedtype")
   134  		}
   135  		switch h.Conds.FeedType.Op {
   136  		case cruder.EQ:
   137  			s.Where(
   138  				sql.EQ(t.C(entcurrency.FieldFeedType), feedType.String()),
   139  			)
   140  		default:
   141  			return fmt.Errorf("invalid currency field op")
   142  		}
   143  	}
   144  	return nil
   145  }
   146  
   147  func (h *queryHandler) queryJoin() {
   148  	h.stmSelect.Modify(func(s *sql.Selector) {
   149  		h.queryJoinMyself(s)
   150  		h.queryJoinCoinExtra(s)
   151  		if err := h.queryJoinCurrency(s); err != nil {
   152  			return
   153  		}
   154  	})
   155  	if h.stmCount == nil {
   156  		return
   157  	}
   158  	h.stmCount.Modify(func(s *sql.Selector) {
   159  		if err := h.queryJoinCurrency(s); err != nil {
   160  			return
   161  		}
   162  	})
   163  }
   164  
   165  func (h *queryHandler) scan(ctx context.Context) error {
   166  	return h.stmSelect.Scan(ctx, &h.infos)
   167  }
   168  
   169  func (h *queryHandler) formalize() {
   170  	for _, info := range h.infos {
   171  		info.FeedType = basetypes.CurrencyFeedType(basetypes.CurrencyFeedType_value[info.FeedTypeStr])
   172  		if info.StableUSD {
   173  			info.MarketValueHigh = decimal.NewFromInt(1).String()
   174  			info.MarketValueLow = decimal.NewFromInt(1).String()
   175  			info.CreatedAt = uint32(time.Now().Unix())
   176  			info.UpdatedAt = uint32(time.Now().Unix())
   177  			info.FeedType = basetypes.CurrencyFeedType_StableUSDHardCode
   178  		}
   179  		if _, err := decimal.NewFromString(info.MarketValueHigh); err != nil {
   180  			info.MarketValueHigh = decimal.NewFromInt(0).String()
   181  		}
   182  		if _, err := decimal.NewFromString(info.MarketValueLow); err != nil {
   183  			info.MarketValueLow = decimal.NewFromInt(0).String()
   184  		}
   185  		if _, err := uuid.Parse(info.EntID); err != nil {
   186  			info.EntID = uuid.Nil.String()
   187  		}
   188  	}
   189  }
   190  
   191  func (h *Handler) GetCurrency(ctx context.Context) (*npool.Currency, error) {
   192  	if h.EntID == nil {
   193  		return nil, fmt.Errorf("invalid entid")
   194  	}
   195  
   196  	handler := &queryHandler{
   197  		Handler: h,
   198  	}
   199  
   200  	err := db.WithClient(ctx, func(_ctx context.Context, cli *ent.Client) error {
   201  		if err := handler.queryCoinBase(_ctx, cli); err != nil {
   202  			return err
   203  		}
   204  		handler.queryJoin()
   205  		const singleRowLimit = 1
   206  		handler.stmSelect.
   207  			Offset(0).
   208  			Limit(singleRowLimit)
   209  		return handler.scan(_ctx)
   210  	})
   211  	if err != nil {
   212  		return nil, err
   213  	}
   214  	if len(handler.infos) == 0 {
   215  		return nil, nil
   216  	}
   217  
   218  	handler.formalize()
   219  	return handler.infos[0], nil
   220  }
   221  
   222  func (h *Handler) GetCurrencies(ctx context.Context) ([]*npool.Currency, uint32, error) {
   223  	handler := &queryHandler{
   224  		Handler: h,
   225  	}
   226  
   227  	var err error
   228  	err = db.WithClient(ctx, func(_ctx context.Context, cli *ent.Client) error {
   229  		handler.stmSelect, err = handler.queryCoinBases(cli)
   230  		if err != nil {
   231  			return err
   232  		}
   233  		handler.stmCount, err = handler.queryCoinBases(cli)
   234  		if err != nil {
   235  			return err
   236  		}
   237  
   238  		handler.queryJoin()
   239  		_total, err := handler.stmCount.Count(_ctx)
   240  		if err != nil {
   241  			return err
   242  		}
   243  		handler.total = uint32(_total)
   244  
   245  		handler.stmSelect.
   246  			Offset(int(h.Offset)).
   247  			Limit(int(h.Limit))
   248  		return handler.scan(_ctx)
   249  	})
   250  	if err != nil {
   251  		return nil, 0, err
   252  	}
   253  
   254  	handler.formalize()
   255  	return handler.infos, handler.total, nil
   256  }