github.com/NpoolPlatform/chain-middleware@v0.0.0-20240228100535-eb1bcf896eb9/pkg/mw/tx/query.go (about)

     1  package tx
     2  
     3  import (
     4  	"context"
     5  	"fmt"
     6  
     7  	basetypes "github.com/NpoolPlatform/message/npool/basetypes/v1"
     8  	npool "github.com/NpoolPlatform/message/npool/chain/mw/v1/tx"
     9  
    10  	"github.com/NpoolPlatform/chain-middleware/pkg/db"
    11  	"github.com/NpoolPlatform/chain-middleware/pkg/db/ent"
    12  	enttx "github.com/NpoolPlatform/chain-middleware/pkg/db/ent/tran"
    13  
    14  	txcrud "github.com/NpoolPlatform/chain-middleware/pkg/crud/tx"
    15  	entcoinbase "github.com/NpoolPlatform/chain-middleware/pkg/db/ent/coinbase"
    16  
    17  	"entgo.io/ent/dialect/sql"
    18  	"github.com/shopspring/decimal"
    19  )
    20  
    21  type queryHandler struct {
    22  	*Handler
    23  	stm   *ent.TranSelect
    24  	infos []*npool.Tx
    25  	total uint32
    26  }
    27  
    28  func (h *queryHandler) selectTx(stm *ent.TranQuery) {
    29  	h.stm = stm.Select(
    30  		enttx.FieldID,
    31  		enttx.FieldEntID,
    32  		enttx.FieldCoinTypeID,
    33  		enttx.FieldFromAccountID,
    34  		enttx.FieldToAccountID,
    35  		enttx.FieldAmount,
    36  		enttx.FieldFeeAmount,
    37  		enttx.FieldState,
    38  		enttx.FieldChainTxID,
    39  		enttx.FieldType,
    40  		enttx.FieldExtra,
    41  		enttx.FieldCreatedAt,
    42  		enttx.FieldUpdatedAt,
    43  	)
    44  }
    45  
    46  func (h *queryHandler) queryTx(cli *ent.Client) error {
    47  	if h.ID == nil && h.EntID == nil {
    48  		return fmt.Errorf("invalid id")
    49  	}
    50  	stm := cli.Tran.Query().Where(enttx.DeletedAt(0))
    51  	if h.ID != nil {
    52  		stm.Where(enttx.ID(*h.ID))
    53  	}
    54  	if h.EntID != nil {
    55  		stm.Where(enttx.EntID(*h.EntID))
    56  	}
    57  	h.selectTx(stm)
    58  	return nil
    59  }
    60  
    61  func (h *queryHandler) queryTxs(ctx context.Context, cli *ent.Client) error {
    62  	stm, err := txcrud.SetQueryConds(cli.Tran.Query(), h.Conds)
    63  	if err != nil {
    64  		return err
    65  	}
    66  
    67  	total, err := stm.Count(ctx)
    68  	if err != nil {
    69  		return err
    70  	}
    71  	h.total = uint32(total)
    72  
    73  	h.selectTx(stm)
    74  	return nil
    75  }
    76  
    77  func (h *queryHandler) queryJoinCoin(s *sql.Selector) {
    78  	t := sql.Table(entcoinbase.Table)
    79  	s.
    80  		LeftJoin(t).
    81  		On(
    82  			s.C(enttx.FieldCoinTypeID),
    83  			t.C(entcoinbase.FieldEntID),
    84  		).
    85  		AppendSelect(
    86  			sql.As(t.C(entcoinbase.FieldName), "coin_name"),
    87  			sql.As(t.C(entcoinbase.FieldLogo), "coin_logo"),
    88  			sql.As(t.C(entcoinbase.FieldUnit), "coin_unit"),
    89  			sql.As(t.C(entcoinbase.FieldEnv), "coin_env"),
    90  		)
    91  }
    92  
    93  func (h *queryHandler) queryJoin() {
    94  	h.stm.Modify(func(s *sql.Selector) {
    95  		h.queryJoinCoin(s)
    96  	})
    97  }
    98  
    99  func (h *queryHandler) scan(ctx context.Context) error {
   100  	return h.stm.Scan(ctx, &h.infos)
   101  }
   102  
   103  func (h *queryHandler) formalize() {
   104  	for _, info := range h.infos {
   105  		info.Type = basetypes.TxType(basetypes.TxType_value[info.TypeStr])
   106  		info.State = basetypes.TxState(basetypes.TxState_value[info.StateStr])
   107  		amount, err := decimal.NewFromString(info.Amount)
   108  		if err != nil {
   109  			info.Amount = decimal.NewFromInt(0).String()
   110  		} else {
   111  			info.Amount = amount.String()
   112  		}
   113  		amount, err = decimal.NewFromString(info.FeeAmount)
   114  		if err != nil {
   115  			info.FeeAmount = decimal.NewFromInt(0).String()
   116  		} else {
   117  			info.FeeAmount = amount.String()
   118  		}
   119  	}
   120  }
   121  
   122  func (h *Handler) GetTx(ctx context.Context) (*npool.Tx, error) {
   123  	handler := &queryHandler{
   124  		Handler: h,
   125  	}
   126  
   127  	err := db.WithClient(ctx, func(_ctx context.Context, cli *ent.Client) error {
   128  		if err := handler.queryTx(cli); err != nil {
   129  			return err
   130  		}
   131  		handler.queryJoin()
   132  		const singleRowLimit = 2
   133  		handler.stm.Offset(0).Limit(singleRowLimit)
   134  		return handler.scan(_ctx)
   135  	})
   136  	if err != nil {
   137  		return nil, err
   138  	}
   139  	if len(handler.infos) == 0 {
   140  		return nil, nil
   141  	}
   142  	if len(handler.infos) > 1 {
   143  		return nil, fmt.Errorf("too many record")
   144  	}
   145  
   146  	handler.formalize()
   147  	return handler.infos[0], nil
   148  }
   149  
   150  func (h *Handler) GetTxs(ctx context.Context) ([]*npool.Tx, uint32, error) {
   151  	handler := &queryHandler{
   152  		Handler: h,
   153  	}
   154  
   155  	err := db.WithClient(ctx, func(_ctx context.Context, cli *ent.Client) error {
   156  		if err := handler.queryTxs(ctx, cli); err != nil {
   157  			return err
   158  		}
   159  		handler.queryJoin()
   160  		handler.stm.
   161  			Offset(int(h.Offset)).
   162  			Limit(int(h.Limit)).
   163  			Order(ent.Desc(enttx.FieldUpdatedAt))
   164  		return handler.scan(_ctx)
   165  	})
   166  	if err != nil {
   167  		return nil, 0, err
   168  	}
   169  
   170  	handler.formalize()
   171  	return handler.infos, handler.total, nil
   172  }