github.com/NpoolPlatform/chain-middleware@v0.0.0-20240228100535-eb1bcf896eb9/pkg/mw/tx/update.go (about)

     1  package tx
     2  
     3  import (
     4  	"context"
     5  	"fmt"
     6  
     7  	txcrud "github.com/NpoolPlatform/chain-middleware/pkg/crud/tx"
     8  	basetypes "github.com/NpoolPlatform/message/npool/basetypes/v1"
     9  	npool "github.com/NpoolPlatform/message/npool/chain/mw/v1/tx"
    10  
    11  	"github.com/NpoolPlatform/chain-middleware/pkg/db"
    12  	"github.com/NpoolPlatform/chain-middleware/pkg/db/ent"
    13  	enttran "github.com/NpoolPlatform/chain-middleware/pkg/db/ent/tran"
    14  )
    15  
    16  type updateHandler struct {
    17  	*Handler
    18  }
    19  
    20  //nolint:gocyclo
    21  func (h *updateHandler) validateState(info *ent.Tran) error {
    22  	if h.State == nil {
    23  		return nil
    24  	}
    25  
    26  	switch info.State {
    27  	case basetypes.TxState_TxStateCreated.String():
    28  		switch *h.State {
    29  		case basetypes.TxState_TxStateCreatedCheck:
    30  		default:
    31  			return fmt.Errorf("state is invalid: %v -> %v", info.State, h.State)
    32  		}
    33  	case basetypes.TxState_TxStateCreatedCheck.String():
    34  		switch *h.State {
    35  		case basetypes.TxState_TxStateWait:
    36  		default:
    37  			return fmt.Errorf("state is invalid: %v -> %v", info.State, h.State)
    38  		}
    39  	case basetypes.TxState_TxStateWait.String():
    40  		switch *h.State {
    41  		case basetypes.TxState_TxStateWaitCheck:
    42  		default:
    43  			return fmt.Errorf("state is invalid: %v -> %v", info.State, h.State)
    44  		}
    45  	case basetypes.TxState_TxStateWaitCheck.String():
    46  		switch *h.State {
    47  		case basetypes.TxState_TxStateTransferring:
    48  		case basetypes.TxState_TxStateFail:
    49  		default:
    50  			return fmt.Errorf("state is invalid: %v -> %v", info.State, h.State)
    51  		}
    52  	case basetypes.TxState_TxStateTransferring.String():
    53  		switch *h.State {
    54  		case basetypes.TxState_TxStateSuccessful:
    55  		case basetypes.TxState_TxStateFail:
    56  		default:
    57  			return fmt.Errorf("state is invalid: %v -> %v", info.State, h.State)
    58  		}
    59  	case basetypes.TxState_TxStateSuccessful.String():
    60  		fallthrough //nolint
    61  	case basetypes.TxState_TxStateFail.String():
    62  		fallthrough //nolint
    63  	default:
    64  		return fmt.Errorf("state is invalid: %v -> %v", info.State, h.State)
    65  	}
    66  
    67  	return nil
    68  }
    69  
    70  func (h *Handler) UpdateTx(ctx context.Context) (*npool.Tx, error) {
    71  	if h.ID == nil {
    72  		return nil, fmt.Errorf("invalid id")
    73  	}
    74  
    75  	handler := &updateHandler{
    76  		Handler: h,
    77  	}
    78  
    79  	err := db.WithClient(ctx, func(_ctx context.Context, cli *ent.Client) error {
    80  		info, err := cli.
    81  			Tran.
    82  			Query().
    83  			Where(
    84  				enttran.ID(*h.ID),
    85  			).
    86  			Only(_ctx)
    87  		if err != nil {
    88  			return err
    89  		}
    90  
    91  		if err := handler.validateState(info); err != nil {
    92  			return err
    93  		}
    94  
    95  		stm, err := txcrud.UpdateSet(
    96  			info.Update(),
    97  			&txcrud.Req{
    98  				ChainTxID: h.ChainTxID,
    99  				State:     h.State,
   100  				Extra:     h.Extra,
   101  			},
   102  		)
   103  		if err != nil {
   104  			return err
   105  		}
   106  		if _, err := stm.Save(_ctx); err != nil {
   107  			return err
   108  		}
   109  		return nil
   110  	})
   111  	if err != nil {
   112  		return nil, err
   113  	}
   114  
   115  	return h.GetTx(ctx)
   116  }