github.com/badrootd/nibiru-cometbft@v0.37.5-0.20240307173500-2a75559eee9b/abci/client/socket_client.go (about)

     1  package abcicli
     2  
     3  import (
     4  	"bufio"
     5  	"container/list"
     6  	"errors"
     7  	"fmt"
     8  	"io"
     9  	"net"
    10  	"reflect"
    11  	"time"
    12  
    13  	"github.com/badrootd/nibiru-cometbft/abci/types"
    14  	cmtnet "github.com/badrootd/nibiru-cometbft/libs/net"
    15  	"github.com/badrootd/nibiru-cometbft/libs/service"
    16  	cmtsync "github.com/badrootd/nibiru-cometbft/libs/sync"
    17  	"github.com/badrootd/nibiru-cometbft/libs/timer"
    18  )
    19  
    20  const (
    21  	reqQueueSize    = 256 // TODO make configurable
    22  	flushThrottleMS = 20  // Don't wait longer than...
    23  )
    24  
    25  // This is goroutine-safe, but users should beware that the application in
    26  // general is not meant to be interfaced with concurrent callers.
    27  type socketClient struct {
    28  	service.BaseService
    29  
    30  	addr        string
    31  	mustConnect bool
    32  	conn        net.Conn
    33  
    34  	reqQueue   chan *ReqRes
    35  	flushTimer *timer.ThrottleTimer
    36  
    37  	mtx     cmtsync.Mutex
    38  	err     error
    39  	reqSent *list.List                            // list of requests sent, waiting for response
    40  	resCb   func(*types.Request, *types.Response) // called on all requests, if set.
    41  }
    42  
    43  var _ Client = (*socketClient)(nil)
    44  
    45  // NewSocketClient creates a new socket client, which connects to a given
    46  // address. If mustConnect is true, the client will return an error upon start
    47  // if it fails to connect.
    48  func NewSocketClient(addr string, mustConnect bool) Client {
    49  	cli := &socketClient{
    50  		reqQueue:    make(chan *ReqRes, reqQueueSize),
    51  		flushTimer:  timer.NewThrottleTimer("socketClient", flushThrottleMS),
    52  		mustConnect: mustConnect,
    53  
    54  		addr:    addr,
    55  		reqSent: list.New(),
    56  		resCb:   nil,
    57  	}
    58  	cli.BaseService = *service.NewBaseService(nil, "socketClient", cli)
    59  	return cli
    60  }
    61  
    62  // OnStart implements Service by connecting to the server and spawning reading
    63  // and writing goroutines.
    64  func (cli *socketClient) OnStart() error {
    65  	var (
    66  		err  error
    67  		conn net.Conn
    68  	)
    69  
    70  	for {
    71  		conn, err = cmtnet.Connect(cli.addr)
    72  		if err != nil {
    73  			if cli.mustConnect {
    74  				return err
    75  			}
    76  			cli.Logger.Error(fmt.Sprintf("abci.socketClient failed to connect to %v.  Retrying after %vs...",
    77  				cli.addr, dialRetryIntervalSeconds), "err", err)
    78  			time.Sleep(time.Second * dialRetryIntervalSeconds)
    79  			continue
    80  		}
    81  		cli.conn = conn
    82  
    83  		go cli.sendRequestsRoutine(conn)
    84  		go cli.recvResponseRoutine(conn)
    85  
    86  		return nil
    87  	}
    88  }
    89  
    90  // OnStop implements Service by closing connection and flushing all queues.
    91  func (cli *socketClient) OnStop() {
    92  	if cli.conn != nil {
    93  		cli.conn.Close()
    94  	}
    95  
    96  	cli.flushQueue()
    97  	cli.flushTimer.Stop()
    98  }
    99  
   100  // Error returns an error if the client was stopped abruptly.
   101  func (cli *socketClient) Error() error {
   102  	cli.mtx.Lock()
   103  	defer cli.mtx.Unlock()
   104  	return cli.err
   105  }
   106  
   107  // SetResponseCallback sets a callback, which will be executed for each
   108  // non-error & non-empty response from the server.
   109  //
   110  // NOTE: callback may get internally generated flush responses.
   111  func (cli *socketClient) SetResponseCallback(resCb Callback) {
   112  	cli.mtx.Lock()
   113  	cli.resCb = resCb
   114  	cli.mtx.Unlock()
   115  }
   116  
   117  //----------------------------------------
   118  
   119  func (cli *socketClient) sendRequestsRoutine(conn io.Writer) {
   120  	w := bufio.NewWriter(conn)
   121  	for {
   122  		select {
   123  		case reqres := <-cli.reqQueue:
   124  			// cli.Logger.Debug("Sent request", "requestType", reflect.TypeOf(reqres.Request), "request", reqres.Request)
   125  
   126  			cli.willSendReq(reqres)
   127  			err := types.WriteMessage(reqres.Request, w)
   128  			if err != nil {
   129  				cli.stopForError(fmt.Errorf("write to buffer: %w", err))
   130  				return
   131  			}
   132  
   133  			// If it's a flush request, flush the current buffer.
   134  			if _, ok := reqres.Request.Value.(*types.Request_Flush); ok {
   135  				err = w.Flush()
   136  				if err != nil {
   137  					cli.stopForError(fmt.Errorf("flush buffer: %w", err))
   138  					return
   139  				}
   140  			}
   141  		case <-cli.flushTimer.Ch: // flush queue
   142  			select {
   143  			case cli.reqQueue <- NewReqRes(types.ToRequestFlush()):
   144  			default:
   145  				// Probably will fill the buffer, or retry later.
   146  			}
   147  		case <-cli.Quit():
   148  			return
   149  		}
   150  	}
   151  }
   152  
   153  func (cli *socketClient) recvResponseRoutine(conn io.Reader) {
   154  	r := bufio.NewReader(conn)
   155  	for {
   156  		var res = &types.Response{}
   157  		err := types.ReadMessage(r, res)
   158  		if err != nil {
   159  			cli.stopForError(fmt.Errorf("read message: %w", err))
   160  			return
   161  		}
   162  
   163  		// cli.Logger.Debug("Received response", "responseType", reflect.TypeOf(res), "response", res)
   164  
   165  		switch r := res.Value.(type) {
   166  		case *types.Response_Exception: // app responded with error
   167  			// XXX After setting cli.err, release waiters (e.g. reqres.Done())
   168  			cli.stopForError(errors.New(r.Exception.Error))
   169  			return
   170  		default:
   171  			err := cli.didRecvResponse(res)
   172  			if err != nil {
   173  				cli.stopForError(err)
   174  				return
   175  			}
   176  		}
   177  	}
   178  }
   179  
   180  func (cli *socketClient) willSendReq(reqres *ReqRes) {
   181  	cli.mtx.Lock()
   182  	defer cli.mtx.Unlock()
   183  	cli.reqSent.PushBack(reqres)
   184  }
   185  
   186  func (cli *socketClient) didRecvResponse(res *types.Response) error {
   187  	cli.mtx.Lock()
   188  	defer cli.mtx.Unlock()
   189  
   190  	// Get the first ReqRes.
   191  	next := cli.reqSent.Front()
   192  	if next == nil {
   193  		return fmt.Errorf("unexpected %v when nothing expected", reflect.TypeOf(res.Value))
   194  	}
   195  
   196  	reqres := next.Value.(*ReqRes)
   197  	if !resMatchesReq(reqres.Request, res) {
   198  		return fmt.Errorf("unexpected %v when response to %v expected",
   199  			reflect.TypeOf(res.Value), reflect.TypeOf(reqres.Request.Value))
   200  	}
   201  
   202  	reqres.Response = res
   203  	reqres.Done()            // release waiters
   204  	cli.reqSent.Remove(next) // pop first item from linked list
   205  
   206  	// Notify client listener if set (global callback).
   207  	if cli.resCb != nil {
   208  		cli.resCb(reqres.Request, res)
   209  	}
   210  
   211  	// Notify reqRes listener if set (request specific callback).
   212  	//
   213  	// NOTE: It is possible this callback isn't set on the reqres object. At this
   214  	// point, in which case it will be called after, when it is set.
   215  	reqres.InvokeCallback()
   216  
   217  	return nil
   218  }
   219  
   220  //----------------------------------------
   221  
   222  func (cli *socketClient) EchoAsync(msg string) *ReqRes {
   223  	return cli.queueRequest(types.ToRequestEcho(msg))
   224  }
   225  
   226  func (cli *socketClient) FlushAsync() *ReqRes {
   227  	return cli.queueRequest(types.ToRequestFlush())
   228  }
   229  
   230  func (cli *socketClient) InfoAsync(req types.RequestInfo) *ReqRes {
   231  	return cli.queueRequest(types.ToRequestInfo(req))
   232  }
   233  
   234  func (cli *socketClient) DeliverTxAsync(req types.RequestDeliverTx) *ReqRes {
   235  	return cli.queueRequest(types.ToRequestDeliverTx(req))
   236  }
   237  
   238  func (cli *socketClient) CheckTxAsync(req types.RequestCheckTx) *ReqRes {
   239  	return cli.queueRequest(types.ToRequestCheckTx(req))
   240  }
   241  
   242  func (cli *socketClient) QueryAsync(req types.RequestQuery) *ReqRes {
   243  	return cli.queueRequest(types.ToRequestQuery(req))
   244  }
   245  
   246  func (cli *socketClient) CommitAsync() *ReqRes {
   247  	return cli.queueRequest(types.ToRequestCommit())
   248  }
   249  
   250  func (cli *socketClient) InitChainAsync(req types.RequestInitChain) *ReqRes {
   251  	return cli.queueRequest(types.ToRequestInitChain(req))
   252  }
   253  
   254  func (cli *socketClient) BeginBlockAsync(req types.RequestBeginBlock) *ReqRes {
   255  	return cli.queueRequest(types.ToRequestBeginBlock(req))
   256  }
   257  
   258  func (cli *socketClient) EndBlockAsync(req types.RequestEndBlock) *ReqRes {
   259  	return cli.queueRequest(types.ToRequestEndBlock(req))
   260  }
   261  
   262  func (cli *socketClient) ListSnapshotsAsync(req types.RequestListSnapshots) *ReqRes {
   263  	return cli.queueRequest(types.ToRequestListSnapshots(req))
   264  }
   265  
   266  func (cli *socketClient) OfferSnapshotAsync(req types.RequestOfferSnapshot) *ReqRes {
   267  	return cli.queueRequest(types.ToRequestOfferSnapshot(req))
   268  }
   269  
   270  func (cli *socketClient) LoadSnapshotChunkAsync(req types.RequestLoadSnapshotChunk) *ReqRes {
   271  	return cli.queueRequest(types.ToRequestLoadSnapshotChunk(req))
   272  }
   273  
   274  func (cli *socketClient) ApplySnapshotChunkAsync(req types.RequestApplySnapshotChunk) *ReqRes {
   275  	return cli.queueRequest(types.ToRequestApplySnapshotChunk(req))
   276  }
   277  
   278  func (cli *socketClient) PrepareProposalAsync(req types.RequestPrepareProposal) *ReqRes {
   279  	return cli.queueRequest(types.ToRequestPrepareProposal(req))
   280  }
   281  
   282  func (cli *socketClient) ProcessProposalAsync(req types.RequestProcessProposal) *ReqRes {
   283  	return cli.queueRequest(types.ToRequestProcessProposal(req))
   284  }
   285  
   286  //----------------------------------------
   287  
   288  func (cli *socketClient) FlushSync() error {
   289  	reqRes := cli.queueRequest(types.ToRequestFlush())
   290  	if err := cli.Error(); err != nil {
   291  		return err
   292  	}
   293  	reqRes.Wait() // NOTE: if we don't flush the queue, its possible to get stuck here
   294  	return cli.Error()
   295  }
   296  
   297  func (cli *socketClient) EchoSync(msg string) (*types.ResponseEcho, error) {
   298  	reqres := cli.queueRequest(types.ToRequestEcho(msg))
   299  	if err := cli.FlushSync(); err != nil {
   300  		return nil, err
   301  	}
   302  
   303  	return reqres.Response.GetEcho(), cli.Error()
   304  }
   305  
   306  func (cli *socketClient) InfoSync(req types.RequestInfo) (*types.ResponseInfo, error) {
   307  	reqres := cli.queueRequest(types.ToRequestInfo(req))
   308  	if err := cli.FlushSync(); err != nil {
   309  		return nil, err
   310  	}
   311  
   312  	return reqres.Response.GetInfo(), cli.Error()
   313  }
   314  
   315  func (cli *socketClient) DeliverTxSync(req types.RequestDeliverTx) (*types.ResponseDeliverTx, error) {
   316  	reqres := cli.queueRequest(types.ToRequestDeliverTx(req))
   317  	if err := cli.FlushSync(); err != nil {
   318  		return nil, err
   319  	}
   320  
   321  	return reqres.Response.GetDeliverTx(), cli.Error()
   322  }
   323  
   324  func (cli *socketClient) CheckTxSync(req types.RequestCheckTx) (*types.ResponseCheckTx, error) {
   325  	reqres := cli.queueRequest(types.ToRequestCheckTx(req))
   326  	if err := cli.FlushSync(); err != nil {
   327  		return nil, err
   328  	}
   329  
   330  	return reqres.Response.GetCheckTx(), cli.Error()
   331  }
   332  
   333  func (cli *socketClient) QuerySync(req types.RequestQuery) (*types.ResponseQuery, error) {
   334  	reqres := cli.queueRequest(types.ToRequestQuery(req))
   335  	if err := cli.FlushSync(); err != nil {
   336  		return nil, err
   337  	}
   338  
   339  	return reqres.Response.GetQuery(), cli.Error()
   340  }
   341  
   342  func (cli *socketClient) CommitSync() (*types.ResponseCommit, error) {
   343  	reqres := cli.queueRequest(types.ToRequestCommit())
   344  	if err := cli.FlushSync(); err != nil {
   345  		return nil, err
   346  	}
   347  
   348  	return reqres.Response.GetCommit(), cli.Error()
   349  }
   350  
   351  func (cli *socketClient) InitChainSync(req types.RequestInitChain) (*types.ResponseInitChain, error) {
   352  	reqres := cli.queueRequest(types.ToRequestInitChain(req))
   353  	if err := cli.FlushSync(); err != nil {
   354  		return nil, err
   355  	}
   356  
   357  	return reqres.Response.GetInitChain(), cli.Error()
   358  }
   359  
   360  func (cli *socketClient) BeginBlockSync(req types.RequestBeginBlock) (*types.ResponseBeginBlock, error) {
   361  	reqres := cli.queueRequest(types.ToRequestBeginBlock(req))
   362  	if err := cli.FlushSync(); err != nil {
   363  		return nil, err
   364  	}
   365  
   366  	return reqres.Response.GetBeginBlock(), cli.Error()
   367  }
   368  
   369  func (cli *socketClient) EndBlockSync(req types.RequestEndBlock) (*types.ResponseEndBlock, error) {
   370  	reqres := cli.queueRequest(types.ToRequestEndBlock(req))
   371  	if err := cli.FlushSync(); err != nil {
   372  		return nil, err
   373  	}
   374  
   375  	return reqres.Response.GetEndBlock(), cli.Error()
   376  }
   377  
   378  func (cli *socketClient) ListSnapshotsSync(req types.RequestListSnapshots) (*types.ResponseListSnapshots, error) {
   379  	reqres := cli.queueRequest(types.ToRequestListSnapshots(req))
   380  	if err := cli.FlushSync(); err != nil {
   381  		return nil, err
   382  	}
   383  
   384  	return reqres.Response.GetListSnapshots(), cli.Error()
   385  }
   386  
   387  func (cli *socketClient) OfferSnapshotSync(req types.RequestOfferSnapshot) (*types.ResponseOfferSnapshot, error) {
   388  	reqres := cli.queueRequest(types.ToRequestOfferSnapshot(req))
   389  	if err := cli.FlushSync(); err != nil {
   390  		return nil, err
   391  	}
   392  
   393  	return reqres.Response.GetOfferSnapshot(), cli.Error()
   394  }
   395  
   396  func (cli *socketClient) LoadSnapshotChunkSync(
   397  	req types.RequestLoadSnapshotChunk) (*types.ResponseLoadSnapshotChunk, error) {
   398  	reqres := cli.queueRequest(types.ToRequestLoadSnapshotChunk(req))
   399  	if err := cli.FlushSync(); err != nil {
   400  		return nil, err
   401  	}
   402  
   403  	return reqres.Response.GetLoadSnapshotChunk(), cli.Error()
   404  }
   405  
   406  func (cli *socketClient) ApplySnapshotChunkSync(
   407  	req types.RequestApplySnapshotChunk) (*types.ResponseApplySnapshotChunk, error) {
   408  	reqres := cli.queueRequest(types.ToRequestApplySnapshotChunk(req))
   409  	if err := cli.FlushSync(); err != nil {
   410  		return nil, err
   411  	}
   412  	return reqres.Response.GetApplySnapshotChunk(), cli.Error()
   413  }
   414  
   415  func (cli *socketClient) PrepareProposalSync(req types.RequestPrepareProposal) (*types.ResponsePrepareProposal, error) {
   416  	reqres := cli.queueRequest(types.ToRequestPrepareProposal(req))
   417  	if err := cli.FlushSync(); err != nil {
   418  		return nil, err
   419  	}
   420  
   421  	return reqres.Response.GetPrepareProposal(), cli.Error()
   422  }
   423  
   424  func (cli *socketClient) ProcessProposalSync(req types.RequestProcessProposal) (*types.ResponseProcessProposal, error) {
   425  	reqres := cli.queueRequest(types.ToRequestProcessProposal(req))
   426  	if err := cli.FlushSync(); err != nil {
   427  		return nil, err
   428  	}
   429  
   430  	return reqres.Response.GetProcessProposal(), cli.Error()
   431  }
   432  
   433  //----------------------------------------
   434  
   435  func (cli *socketClient) queueRequest(req *types.Request) *ReqRes {
   436  	reqres := NewReqRes(req)
   437  
   438  	// TODO: set cli.err if reqQueue times out
   439  	cli.reqQueue <- reqres
   440  
   441  	// Maybe auto-flush, or unset auto-flush
   442  	switch req.Value.(type) {
   443  	case *types.Request_Flush:
   444  		cli.flushTimer.Unset()
   445  	default:
   446  		cli.flushTimer.Set()
   447  	}
   448  
   449  	return reqres
   450  }
   451  
   452  func (cli *socketClient) flushQueue() {
   453  	cli.mtx.Lock()
   454  	defer cli.mtx.Unlock()
   455  
   456  	// mark all in-flight messages as resolved (they will get cli.Error())
   457  	for req := cli.reqSent.Front(); req != nil; req = req.Next() {
   458  		reqres := req.Value.(*ReqRes)
   459  		reqres.Done()
   460  	}
   461  
   462  	// mark all queued messages as resolved
   463  LOOP:
   464  	for {
   465  		select {
   466  		case reqres := <-cli.reqQueue:
   467  			reqres.Done()
   468  		default:
   469  			break LOOP
   470  		}
   471  	}
   472  }
   473  
   474  //----------------------------------------
   475  
   476  func resMatchesReq(req *types.Request, res *types.Response) (ok bool) {
   477  	switch req.Value.(type) {
   478  	case *types.Request_Echo:
   479  		_, ok = res.Value.(*types.Response_Echo)
   480  	case *types.Request_Flush:
   481  		_, ok = res.Value.(*types.Response_Flush)
   482  	case *types.Request_Info:
   483  		_, ok = res.Value.(*types.Response_Info)
   484  	case *types.Request_DeliverTx:
   485  		_, ok = res.Value.(*types.Response_DeliverTx)
   486  	case *types.Request_CheckTx:
   487  		_, ok = res.Value.(*types.Response_CheckTx)
   488  	case *types.Request_Commit:
   489  		_, ok = res.Value.(*types.Response_Commit)
   490  	case *types.Request_Query:
   491  		_, ok = res.Value.(*types.Response_Query)
   492  	case *types.Request_InitChain:
   493  		_, ok = res.Value.(*types.Response_InitChain)
   494  	case *types.Request_BeginBlock:
   495  		_, ok = res.Value.(*types.Response_BeginBlock)
   496  	case *types.Request_EndBlock:
   497  		_, ok = res.Value.(*types.Response_EndBlock)
   498  	case *types.Request_ApplySnapshotChunk:
   499  		_, ok = res.Value.(*types.Response_ApplySnapshotChunk)
   500  	case *types.Request_LoadSnapshotChunk:
   501  		_, ok = res.Value.(*types.Response_LoadSnapshotChunk)
   502  	case *types.Request_ListSnapshots:
   503  		_, ok = res.Value.(*types.Response_ListSnapshots)
   504  	case *types.Request_OfferSnapshot:
   505  		_, ok = res.Value.(*types.Response_OfferSnapshot)
   506  	case *types.Request_PrepareProposal:
   507  		_, ok = res.Value.(*types.Response_PrepareProposal)
   508  	case *types.Request_ProcessProposal:
   509  		_, ok = res.Value.(*types.Response_ProcessProposal)
   510  	}
   511  	return ok
   512  }
   513  
   514  func (cli *socketClient) stopForError(err error) {
   515  	if !cli.IsRunning() {
   516  		return
   517  	}
   518  
   519  	cli.mtx.Lock()
   520  	if cli.err == nil {
   521  		cli.err = err
   522  	}
   523  	cli.mtx.Unlock()
   524  
   525  	cli.Logger.Error(fmt.Sprintf("Stopping abci.socketClient for error: %v", err.Error()))
   526  	if err := cli.Stop(); err != nil {
   527  		cli.Logger.Error("Error stopping abci.socketClient", "err", err)
   528  	}
   529  }