github.com/NpoolPlatform/chain-middleware@v0.0.0-20240228100535-eb1bcf896eb9/pkg/crud/tx/crud.go (about)

     1  package tx
     2  
     3  import (
     4  	"fmt"
     5  
     6  	"github.com/NpoolPlatform/chain-middleware/pkg/db/ent"
     7  	enttran "github.com/NpoolPlatform/chain-middleware/pkg/db/ent/tran"
     8  	"github.com/NpoolPlatform/libent-cruder/pkg/cruder"
     9  	basetypes "github.com/NpoolPlatform/message/npool/basetypes/v1"
    10  
    11  	"github.com/google/uuid"
    12  	"github.com/shopspring/decimal"
    13  )
    14  
    15  type Req struct {
    16  	EntID         *uuid.UUID
    17  	CoinTypeID    *uuid.UUID
    18  	FromAccountID *uuid.UUID
    19  	ToAccountID   *uuid.UUID
    20  	Amount        *decimal.Decimal
    21  	FeeAmount     *decimal.Decimal
    22  	ChainTxID     *string
    23  	State         *basetypes.TxState
    24  	Extra         *string
    25  	Type          *basetypes.TxType
    26  }
    27  
    28  func CreateSet(c *ent.TranCreate, req *Req) *ent.TranCreate {
    29  	if req.EntID != nil {
    30  		c.SetEntID(*req.EntID)
    31  	}
    32  	if req.CoinTypeID != nil {
    33  		c.SetCoinTypeID(*req.CoinTypeID)
    34  	}
    35  	if req.FromAccountID != nil {
    36  		c.SetFromAccountID(*req.FromAccountID)
    37  	}
    38  	if req.ToAccountID != nil {
    39  		c.SetToAccountID(*req.ToAccountID)
    40  	}
    41  	if req.Amount != nil {
    42  		c.SetAmount(*req.Amount)
    43  	}
    44  	if req.FeeAmount != nil {
    45  		c.SetFeeAmount(*req.FeeAmount)
    46  	}
    47  	if req.ChainTxID != nil {
    48  		c.SetChainTxID(*req.ChainTxID)
    49  	}
    50  	c.SetState(basetypes.TxState_TxStateCreated.String())
    51  	if req.Extra != nil {
    52  		c.SetExtra(*req.Extra)
    53  	}
    54  	if req.Type != nil {
    55  		c.SetType(req.Type.String())
    56  	}
    57  	return c
    58  }
    59  
    60  func UpdateSet(u *ent.TranUpdateOne, req *Req) (*ent.TranUpdateOne, error) {
    61  	if req.State != nil {
    62  		u = u.SetState(req.State.String())
    63  	}
    64  	if req.ChainTxID != nil {
    65  		u = u.SetChainTxID(*req.ChainTxID)
    66  	}
    67  
    68  	return u, nil
    69  }
    70  
    71  type Conds struct {
    72  	EntID      *cruder.Cond
    73  	CoinTypeID *cruder.Cond
    74  	AccountID  *cruder.Cond
    75  	AccountIDs *cruder.Cond
    76  	State      *cruder.Cond
    77  	Type       *cruder.Cond
    78  	EntIDs     *cruder.Cond
    79  	States     *cruder.Cond
    80  }
    81  
    82  func SetQueryConds(q *ent.TranQuery, conds *Conds) (*ent.TranQuery, error) { //nolint
    83  	if conds.EntID != nil {
    84  		id, ok := conds.EntID.Val.(uuid.UUID)
    85  		if !ok {
    86  			return nil, fmt.Errorf("invalid entid")
    87  		}
    88  		switch conds.EntID.Op {
    89  		case cruder.EQ:
    90  			q.Where(enttran.EntID(id))
    91  		default:
    92  			return nil, fmt.Errorf("invalid tx field")
    93  		}
    94  	}
    95  	if conds.CoinTypeID != nil {
    96  		id, ok := conds.CoinTypeID.Val.(uuid.UUID)
    97  		if !ok {
    98  			return nil, fmt.Errorf("invalid cointypeid")
    99  		}
   100  		switch conds.CoinTypeID.Op {
   101  		case cruder.EQ:
   102  			q.Where(enttran.CoinTypeID(id))
   103  		default:
   104  			return nil, fmt.Errorf("invalid tx field")
   105  		}
   106  	}
   107  	if conds.AccountID != nil {
   108  		id, ok := conds.AccountID.Val.(uuid.UUID)
   109  		if !ok {
   110  			return nil, fmt.Errorf("invalid accountid")
   111  		}
   112  		switch conds.AccountID.Op {
   113  		case cruder.EQ:
   114  			q.Where(
   115  				enttran.Or(
   116  					enttran.FromAccountID(id),
   117  					enttran.ToAccountID(id),
   118  				),
   119  			)
   120  		default:
   121  			return nil, fmt.Errorf("invalid tx field")
   122  		}
   123  	}
   124  	if conds.AccountIDs != nil {
   125  		ids, ok := conds.AccountIDs.Val.([]uuid.UUID)
   126  		if !ok {
   127  			return nil, fmt.Errorf("invalid accountids")
   128  		}
   129  		switch conds.AccountIDs.Op {
   130  		case cruder.IN:
   131  			q.Where(
   132  				enttran.Or(
   133  					enttran.FromAccountIDIn(ids...),
   134  					enttran.ToAccountIDIn(ids...),
   135  				),
   136  			)
   137  		default:
   138  			return nil, fmt.Errorf("invalid tx field")
   139  		}
   140  	}
   141  	if conds.State != nil {
   142  		state, ok := conds.State.Val.(basetypes.TxState)
   143  		if !ok {
   144  			return nil, fmt.Errorf("invalid txstate")
   145  		}
   146  		switch conds.State.Op {
   147  		case cruder.EQ:
   148  			q.Where(enttran.State(state.String()))
   149  		case cruder.NEQ:
   150  			q.Where(enttran.StateNEQ(state.String()))
   151  		default:
   152  			return nil, fmt.Errorf("invalid tx field")
   153  		}
   154  	}
   155  	if conds.Type != nil {
   156  		_type, ok := conds.Type.Val.(basetypes.TxType)
   157  		if !ok {
   158  			return nil, fmt.Errorf("invalid txtype")
   159  		}
   160  		switch conds.Type.Op {
   161  		case cruder.EQ:
   162  			q.Where(enttran.Type(_type.String()))
   163  		default:
   164  			return nil, fmt.Errorf("invalid tx field")
   165  		}
   166  	}
   167  	if conds.EntIDs != nil {
   168  		ids, ok := conds.EntIDs.Val.([]uuid.UUID)
   169  		if !ok {
   170  			return nil, fmt.Errorf("invalid entids")
   171  		}
   172  		switch conds.EntIDs.Op {
   173  		case cruder.IN:
   174  			q.Where(enttran.EntIDIn(ids...))
   175  		default:
   176  			return nil, fmt.Errorf("invalid tx field")
   177  		}
   178  	}
   179  	if conds.States != nil {
   180  		states, ok := conds.States.Val.([]basetypes.TxState)
   181  		if !ok {
   182  			return nil, fmt.Errorf("invalid txstates")
   183  		}
   184  		ss := []string{}
   185  		for _, state := range states {
   186  			ss = append(ss, state.String())
   187  		}
   188  		switch conds.States.Op {
   189  		case cruder.IN:
   190  			q.Where(enttran.StateIn(ss...))
   191  		default:
   192  			return nil, fmt.Errorf("invalid tx field")
   193  		}
   194  	}
   195  	return q, nil
   196  }