github.com/NpoolPlatform/chain-middleware@v0.0.0-20240228100535-eb1bcf896eb9/pkg/mw/chain/query.go (about)

     1  package chain
     2  
     3  import (
     4  	"context"
     5  	"fmt"
     6  
     7  	"entgo.io/ent/dialect/sql"
     8  	chaincrud "github.com/NpoolPlatform/chain-middleware/pkg/crud/chain"
     9  	"github.com/NpoolPlatform/chain-middleware/pkg/db"
    10  	"github.com/NpoolPlatform/chain-middleware/pkg/db/ent"
    11  	basetypes "github.com/NpoolPlatform/message/npool/basetypes/v1"
    12  	npool "github.com/NpoolPlatform/message/npool/chain/mw/v1/chain"
    13  
    14  	entchain "github.com/NpoolPlatform/chain-middleware/pkg/db/ent/chainbase"
    15  )
    16  
    17  type queryHandler struct {
    18  	*Handler
    19  	stm   *ent.ChainBaseSelect
    20  	infos []*npool.Chain
    21  	total uint32
    22  }
    23  
    24  func (h *queryHandler) selectChainBase(stm *ent.ChainBaseQuery) {
    25  	h.stm = stm.Select(
    26  		entchain.FieldID,
    27  		entchain.FieldEntID,
    28  		entchain.FieldLogo,
    29  		entchain.FieldNativeUnit,
    30  		entchain.FieldAtomicUnit,
    31  		entchain.FieldUnitExp,
    32  		entchain.FieldEnv,
    33  		entchain.FieldChainID,
    34  		entchain.FieldNickname,
    35  		entchain.FieldGasType,
    36  		entchain.FieldCreatedAt,
    37  		entchain.FieldUpdatedAt,
    38  	).Modify(func(s *sql.Selector) {
    39  		t := sql.Table(entchain.Table)
    40  		s.AppendSelect(
    41  			sql.As(t.C(entchain.FieldName), "chain_type"),
    42  		)
    43  	})
    44  }
    45  
    46  func (h *queryHandler) queryChainBase(cli *ent.Client) error {
    47  	if h.ID == nil && h.EntID == nil {
    48  		return fmt.Errorf("invalid id")
    49  	}
    50  	stm := cli.ChainBase.Query().Where(entchain.DeletedAt(0))
    51  	if h.ID != nil {
    52  		stm.Where(entchain.ID(*h.ID))
    53  	}
    54  	if h.EntID != nil {
    55  		stm.Where(entchain.EntID(*h.EntID))
    56  	}
    57  	h.selectChainBase(stm)
    58  	return nil
    59  }
    60  
    61  func (h *queryHandler) queryChainBases(ctx context.Context, cli *ent.Client) error {
    62  	stm, err := chaincrud.SetQueryConds(cli.ChainBase.Query(), h.Conds)
    63  	if err != nil {
    64  		return err
    65  	}
    66  	total, err := stm.Count(ctx)
    67  	if err != nil {
    68  		return err
    69  	}
    70  	h.total = uint32(total)
    71  	h.selectChainBase(stm)
    72  	return nil
    73  }
    74  
    75  func (h *queryHandler) queryJoin() {
    76  	h.stm.Modify(func(s *sql.Selector) {})
    77  }
    78  
    79  func (h *queryHandler) scan(ctx context.Context) error {
    80  	return h.stm.Scan(ctx, &h.infos)
    81  }
    82  
    83  func (h *queryHandler) formalize() {
    84  	for _, info := range h.infos {
    85  		info.GasType = basetypes.GasType(basetypes.GasType_value[info.GasTypeStr])
    86  	}
    87  }
    88  
    89  func (h *Handler) GetChain(ctx context.Context) (*npool.Chain, error) {
    90  	handler := &queryHandler{
    91  		Handler: h,
    92  	}
    93  
    94  	err := db.WithClient(ctx, func(_ctx context.Context, cli *ent.Client) error {
    95  		if err := handler.queryChainBase(cli); err != nil {
    96  			return err
    97  		}
    98  		handler.queryJoin()
    99  		const singleRowLimit = 2
   100  		handler.stm.Offset(0).Limit(singleRowLimit)
   101  		return handler.scan(_ctx)
   102  	})
   103  	if err != nil {
   104  		return nil, err
   105  	}
   106  	if len(handler.infos) == 0 {
   107  		return nil, nil
   108  	}
   109  	if len(handler.infos) > 1 {
   110  		return nil, fmt.Errorf("too many record")
   111  	}
   112  
   113  	handler.formalize()
   114  	return handler.infos[0], nil
   115  }
   116  
   117  func (h *Handler) GetChains(ctx context.Context) ([]*npool.Chain, uint32, error) {
   118  	handler := &queryHandler{
   119  		Handler: h,
   120  	}
   121  
   122  	err := db.WithClient(ctx, func(_ctx context.Context, cli *ent.Client) error {
   123  		if err := handler.queryChainBases(_ctx, cli); err != nil {
   124  			return err
   125  		}
   126  		handler.queryJoin()
   127  		handler.stm.
   128  			Offset(int(h.Offset)).
   129  			Limit(int(h.Limit))
   130  		return handler.scan(_ctx)
   131  	})
   132  	if err != nil {
   133  		return nil, 0, err
   134  	}
   135  
   136  	handler.formalize()
   137  	return handler.infos, handler.total, nil
   138  }