github.com/NpoolPlatform/chain-middleware@v0.0.0-20240228100535-eb1bcf896eb9/pkg/mw/fiat/currency/query.go (about)

     1  package currency
     2  
     3  import (
     4  	"context"
     5  	"fmt"
     6  
     7  	basetypes "github.com/NpoolPlatform/message/npool/basetypes/v1"
     8  	npool "github.com/NpoolPlatform/message/npool/chain/mw/v1/fiat/currency"
     9  	"github.com/google/uuid"
    10  
    11  	"github.com/NpoolPlatform/chain-middleware/pkg/db"
    12  	"github.com/NpoolPlatform/chain-middleware/pkg/db/ent"
    13  	"github.com/NpoolPlatform/libent-cruder/pkg/cruder"
    14  
    15  	fiatcrud "github.com/NpoolPlatform/chain-middleware/pkg/crud/fiat"
    16  	currencycrud "github.com/NpoolPlatform/chain-middleware/pkg/crud/fiat/currency"
    17  	entfiat "github.com/NpoolPlatform/chain-middleware/pkg/db/ent/fiat"
    18  	entcurrency "github.com/NpoolPlatform/chain-middleware/pkg/db/ent/fiatcurrency"
    19  
    20  	"entgo.io/ent/dialect/sql"
    21  	"github.com/shopspring/decimal"
    22  )
    23  
    24  type queryHandler struct {
    25  	*Handler
    26  	stmSelect *ent.FiatSelect
    27  	stmCount  *ent.FiatSelect
    28  	infos     []*npool.Currency
    29  	total     uint32
    30  }
    31  
    32  func (h *queryHandler) selectFiat(stm *ent.FiatQuery) *ent.FiatSelect {
    33  	return stm.Select(entfiat.FieldCreatedAt)
    34  }
    35  
    36  func (h *queryHandler) queryFiat(ctx context.Context, cli *ent.Client) error {
    37  	_stm1, err := currencycrud.SetQueryConds(cli.FiatCurrency.Query(), &currencycrud.Conds{
    38  		EntID: &cruder.Cond{Op: cruder.EQ, Val: *h.EntID},
    39  	})
    40  	if err != nil {
    41  		return err
    42  	}
    43  	_info1, err := _stm1.Only(ctx)
    44  	if err != nil {
    45  		return err
    46  	}
    47  
    48  	_stm2, err := fiatcrud.SetQueryConds(cli.Fiat.Query(), &fiatcrud.Conds{
    49  		EntID: &cruder.Cond{Op: cruder.EQ, Val: _info1.FiatID},
    50  	})
    51  	if err != nil {
    52  		return err
    53  	}
    54  
    55  	h.stmSelect = h.selectFiat(_stm2)
    56  	return nil
    57  }
    58  
    59  func (h *queryHandler) queryFiats(cli *ent.Client) (*ent.FiatSelect, error) {
    60  	stm, err := fiatcrud.SetQueryConds(cli.Fiat.Query(), &fiatcrud.Conds{
    61  		EntID:  h.Conds.FiatID,
    62  		EntIDs: h.Conds.FiatIDs,
    63  		Name:   h.Conds.FiatName,
    64  	})
    65  	if err != nil {
    66  		return nil, err
    67  	}
    68  
    69  	return h.selectFiat(stm), nil
    70  }
    71  
    72  func (h *queryHandler) queryJoinMyself(s *sql.Selector) {
    73  	t := sql.Table(entfiat.Table)
    74  	s.AppendSelect(
    75  		sql.As(t.C(entfiat.FieldEntID), "fiat_id"),
    76  		sql.As(t.C(entfiat.FieldName), "fiat_name"),
    77  		sql.As(t.C(entfiat.FieldLogo), "fiat_logo"),
    78  		sql.As(t.C(entfiat.FieldUnit), "fiat_unit"),
    79  	)
    80  }
    81  
    82  func (h *queryHandler) queryJoinCurrency(s *sql.Selector) error {
    83  	t := sql.Table(entcurrency.Table)
    84  	s.LeftJoin(t).
    85  		On(
    86  			s.C(entfiat.FieldEntID),
    87  			t.C(entcurrency.FieldFiatID),
    88  		).
    89  		OnP(
    90  			sql.EQ(t.C(entcurrency.FieldDeletedAt), 0),
    91  		).
    92  		AppendSelect(
    93  			sql.As(t.C(entcurrency.FieldID), "id"),
    94  			sql.As(t.C(entcurrency.FieldEntID), "ent_id"),
    95  			sql.As(t.C(entcurrency.FieldFeedType), "feed_type"),
    96  			sql.As(t.C(entcurrency.FieldMarketValueHigh), "market_value_high"),
    97  			sql.As(t.C(entcurrency.FieldMarketValueLow), "market_value_low"),
    98  			sql.As(t.C(entcurrency.FieldCreatedAt), "created_at"),
    99  			sql.As(t.C(entcurrency.FieldUpdatedAt), "updated_at"),
   100  		)
   101  
   102  	if h.Conds != nil && h.Conds.EntID != nil {
   103  		id, ok := h.Conds.EntID.Val.(uuid.UUID)
   104  		if !ok {
   105  			return fmt.Errorf("invalid entid")
   106  		}
   107  		switch h.Conds.EntID.Op {
   108  		case cruder.EQ:
   109  			s.Where(
   110  				sql.EQ(t.C(entcurrency.FieldEntID), id),
   111  			)
   112  		default:
   113  			return fmt.Errorf("invalid fiat currency field op")
   114  		}
   115  	}
   116  	if h.Conds != nil && h.Conds.FiatID != nil {
   117  		id, ok := h.Conds.FiatID.Val.(string)
   118  		if !ok {
   119  			return fmt.Errorf("invalid fiatid")
   120  		}
   121  		switch h.Conds.FiatID.Op {
   122  		case cruder.EQ:
   123  			s.Where(
   124  				sql.EQ(t.C(entcurrency.FieldFiatID), id),
   125  			)
   126  		default:
   127  			return fmt.Errorf("invalid fiat currency field op")
   128  		}
   129  	}
   130  	if h.Conds != nil && h.Conds.FiatIDs != nil {
   131  		ids, ok := h.Conds.FiatIDs.Val.([]string)
   132  		if !ok {
   133  			return fmt.Errorf("invalid fiatids")
   134  		}
   135  		var valueInterfaces []interface{}
   136  		for _, value := range ids {
   137  			valueInterfaces = append(valueInterfaces, value)
   138  		}
   139  		switch h.Conds.FiatIDs.Op {
   140  		case cruder.IN:
   141  			s.Where(
   142  				sql.In(t.C(entcurrency.FieldFiatID), valueInterfaces...),
   143  			)
   144  		default:
   145  			return fmt.Errorf("invalid fiat currency field op")
   146  		}
   147  	}
   148  
   149  	return nil
   150  }
   151  
   152  func (h *queryHandler) queryJoin() {
   153  	h.stmSelect.Modify(func(s *sql.Selector) {
   154  		h.queryJoinMyself(s)
   155  		if err := h.queryJoinCurrency(s); err != nil {
   156  			return
   157  		}
   158  	})
   159  	if h.stmCount == nil {
   160  		return
   161  	}
   162  	h.stmCount.Modify(func(s *sql.Selector) {
   163  		if err := h.queryJoinCurrency(s); err != nil {
   164  			return
   165  		}
   166  	})
   167  }
   168  
   169  func (h *queryHandler) scan(ctx context.Context) error {
   170  	return h.stmSelect.Scan(ctx, &h.infos)
   171  }
   172  
   173  func (h *queryHandler) formalize() {
   174  	for _, info := range h.infos {
   175  		info.FeedType = basetypes.CurrencyFeedType(basetypes.CurrencyFeedType_value[info.FeedTypeStr])
   176  		if _, err := decimal.NewFromString(info.MarketValueHigh); err != nil {
   177  			info.MarketValueHigh = decimal.NewFromInt(0).String()
   178  		}
   179  		if _, err := decimal.NewFromString(info.MarketValueLow); err != nil {
   180  			info.MarketValueLow = decimal.NewFromInt(0).String()
   181  		}
   182  		if _, err := uuid.Parse(info.EntID); err != nil {
   183  			info.EntID = uuid.Nil.String()
   184  		}
   185  	}
   186  }
   187  
   188  func (h *Handler) GetCurrency(ctx context.Context) (*npool.Currency, error) {
   189  	if h.EntID == nil {
   190  		return nil, fmt.Errorf("invalid entid")
   191  	}
   192  
   193  	handler := &queryHandler{
   194  		Handler: h,
   195  	}
   196  
   197  	err := db.WithClient(ctx, func(_ctx context.Context, cli *ent.Client) error {
   198  		if err := handler.queryFiat(_ctx, cli); err != nil {
   199  			return err
   200  		}
   201  		handler.queryJoin()
   202  		const singleRowLimit = 1
   203  		handler.stmSelect.
   204  			Offset(0).
   205  			Limit(singleRowLimit)
   206  		return handler.scan(_ctx)
   207  	})
   208  	if err != nil {
   209  		return nil, err
   210  	}
   211  	if len(handler.infos) == 0 {
   212  		return nil, nil
   213  	}
   214  
   215  	handler.formalize()
   216  	return handler.infos[0], nil
   217  }
   218  
   219  func (h *Handler) GetCurrencies(ctx context.Context) ([]*npool.Currency, uint32, error) {
   220  	handler := &queryHandler{
   221  		Handler: h,
   222  	}
   223  
   224  	var err error
   225  	err = db.WithClient(ctx, func(_ctx context.Context, cli *ent.Client) error {
   226  		handler.stmSelect, err = handler.queryFiats(cli)
   227  		if err != nil {
   228  			return err
   229  		}
   230  		handler.stmCount, err = handler.queryFiats(cli)
   231  		if err != nil {
   232  			return err
   233  		}
   234  
   235  		handler.queryJoin()
   236  		_total, err := handler.stmCount.Count(_ctx)
   237  		if err != nil {
   238  			return err
   239  		}
   240  		handler.total = uint32(_total)
   241  
   242  		handler.stmSelect.
   243  			Offset(int(h.Offset)).
   244  			Limit(int(h.Limit))
   245  		return handler.scan(_ctx)
   246  	})
   247  	if err != nil {
   248  		return nil, 0, err
   249  	}
   250  
   251  	handler.formalize()
   252  	return handler.infos, handler.total, nil
   253  }