github.com/badrootd/celestia-core@v0.0.0-20240305091328-aa4207a4b25d/abci/client/socket_client.go (about)

     1  package abcicli
     2  
     3  import (
     4  	"bufio"
     5  	"container/list"
     6  	"errors"
     7  	"fmt"
     8  	"io"
     9  	"net"
    10  	"reflect"
    11  	"time"
    12  
    13  	"github.com/badrootd/celestia-core/abci/types"
    14  	cmtnet "github.com/badrootd/celestia-core/libs/net"
    15  	"github.com/badrootd/celestia-core/libs/service"
    16  	cmtsync "github.com/badrootd/celestia-core/libs/sync"
    17  	"github.com/badrootd/celestia-core/libs/timer"
    18  )
    19  
    20  const (
    21  	reqQueueSize    = 256 // TODO make configurable
    22  	flushThrottleMS = 20  // Don't wait longer than...
    23  )
    24  
// socketClient is an ABCI Client that talks to an application over a single
// stream connection dialed in OnStart. Outgoing requests go through reqQueue
// to sendRequestsRoutine. Responses are read by recvResponseRoutine and
// matched, oldest first, against the requests recorded in reqSent.
//
// This is goroutine-safe, but users should beware that the application in
// general is not meant to be interfaced with concurrent callers.
type socketClient struct {
	service.BaseService

	// addr is the address dialed with cmtnet.Connect in OnStart.
	// mustConnect makes OnStart return the first dial error instead of retrying.
	// conn is set by OnStart and closed by OnStop.
	addr        string
	mustConnect bool
	conn        net.Conn

	// reqQueue holds requests waiting to be written (capacity reqQueueSize).
	// flushTimer schedules an automatic Flush after non-flush requests are queued.
	reqQueue   chan *ReqRes
	flushTimer *timer.ThrottleTimer

	// mtx guards err, reqSent and resCb.
	// err is the first error recorded by stopForError.
	mtx     cmtsync.Mutex
	err     error
	reqSent *list.List                            // list of requests sent, waiting for response
	resCb   func(*types.Request, *types.Response) // called on all requests, if set.
}

// Compile-time check that *socketClient implements Client.
var _ Client = (*socketClient)(nil)
    44  
    45  // NewSocketClient creates a new socket client, which connects to a given
    46  // address. If mustConnect is true, the client will return an error upon start
    47  // if it fails to connect.
    48  func NewSocketClient(addr string, mustConnect bool) Client {
    49  	cli := &socketClient{
    50  		reqQueue:    make(chan *ReqRes, reqQueueSize),
    51  		flushTimer:  timer.NewThrottleTimer("socketClient", flushThrottleMS),
    52  		mustConnect: mustConnect,
    53  
    54  		addr:    addr,
    55  		reqSent: list.New(),
    56  		resCb:   nil,
    57  	}
    58  	cli.BaseService = *service.NewBaseService(nil, "socketClient", cli)
    59  	return cli
    60  }
    61  
    62  // OnStart implements Service by connecting to the server and spawning reading
    63  // and writing goroutines.
    64  func (cli *socketClient) OnStart() error {
    65  	var (
    66  		err  error
    67  		conn net.Conn
    68  	)
    69  
    70  	for {
    71  		conn, err = cmtnet.Connect(cli.addr)
    72  		if err != nil {
    73  			if cli.mustConnect {
    74  				return err
    75  			}
    76  			cli.Logger.Error(fmt.Sprintf("abci.socketClient failed to connect to %v.  Retrying after %vs...",
    77  				cli.addr, dialRetryIntervalSeconds), "err", err)
    78  			time.Sleep(time.Second * dialRetryIntervalSeconds)
    79  			continue
    80  		}
    81  		cli.conn = conn
    82  
    83  		go cli.sendRequestsRoutine(conn)
    84  		go cli.recvResponseRoutine(conn)
    85  
    86  		return nil
    87  	}
    88  }
    89  
    90  // OnStop implements Service by closing connection and flushing all queues.
    91  func (cli *socketClient) OnStop() {
    92  	if cli.conn != nil {
    93  		cli.conn.Close()
    94  	}
    95  
    96  	cli.flushQueue()
    97  	cli.flushTimer.Stop()
    98  }
    99  
   100  // Error returns an error if the client was stopped abruptly.
   101  func (cli *socketClient) Error() error {
   102  	cli.mtx.Lock()
   103  	defer cli.mtx.Unlock()
   104  	return cli.err
   105  }
   106  
   107  // SetResponseCallback sets a callback, which will be executed for each
   108  // non-error & non-empty response from the server.
   109  //
   110  // NOTE: callback may get internally generated flush responses.
   111  func (cli *socketClient) SetResponseCallback(resCb Callback) {
   112  	cli.mtx.Lock()
   113  	cli.resCb = resCb
   114  	cli.mtx.Unlock()
   115  }
   116  
   117  //----------------------------------------
   118  
   119  func (cli *socketClient) sendRequestsRoutine(conn io.Writer) {
   120  	w := bufio.NewWriter(conn)
   121  	for {
   122  		select {
   123  		case reqres := <-cli.reqQueue:
   124  			// cli.Logger.Debug("Sent request", "requestType", reflect.TypeOf(reqres.Request), "request", reqres.Request)
   125  
   126  			cli.willSendReq(reqres)
   127  			err := types.WriteMessage(reqres.Request, w)
   128  			if err != nil {
   129  				cli.stopForError(fmt.Errorf("write to buffer: %w", err))
   130  				return
   131  			}
   132  
   133  			// If it's a flush request, flush the current buffer.
   134  			if _, ok := reqres.Request.Value.(*types.Request_Flush); ok {
   135  				err = w.Flush()
   136  				if err != nil {
   137  					cli.stopForError(fmt.Errorf("flush buffer: %w", err))
   138  					return
   139  				}
   140  			}
   141  		case <-cli.flushTimer.Ch: // flush queue
   142  			select {
   143  			case cli.reqQueue <- NewReqRes(types.ToRequestFlush()):
   144  			default:
   145  				// Probably will fill the buffer, or retry later.
   146  			}
   147  		case <-cli.Quit():
   148  			return
   149  		}
   150  	}
   151  }
   152  
   153  func (cli *socketClient) recvResponseRoutine(conn io.Reader) {
   154  	r := bufio.NewReader(conn)
   155  	for {
   156  		var res = &types.Response{}
   157  		err := types.ReadMessage(r, res)
   158  		if err != nil {
   159  			cli.stopForError(fmt.Errorf("read message: %w", err))
   160  			return
   161  		}
   162  
   163  		// cli.Logger.Debug("Received response", "responseType", reflect.TypeOf(res), "response", res)
   164  
   165  		switch r := res.Value.(type) {
   166  		case *types.Response_Exception: // app responded with error
   167  			// XXX After setting cli.err, release waiters (e.g. reqres.Done())
   168  			cli.stopForError(errors.New(r.Exception.Error))
   169  			return
   170  		default:
   171  			err := cli.didRecvResponse(res)
   172  			if err != nil {
   173  				cli.stopForError(err)
   174  				return
   175  			}
   176  		}
   177  	}
   178  }
   179  
   180  func (cli *socketClient) willSendReq(reqres *ReqRes) {
   181  	cli.mtx.Lock()
   182  	defer cli.mtx.Unlock()
   183  	cli.reqSent.PushBack(reqres)
   184  }
   185  
// didRecvResponse pairs res with the oldest in-flight request (the front of
// reqSent). The code assumes the server answers requests in the order they
// were sent.
//
// It returns an error, leaving reqSent untouched, when nothing is in flight
// or when res is not the response type that resMatchesReq expects for that
// request. On success it:
//   - stores res on the ReqRes and calls Done, releasing anyone in Wait;
//   - removes the entry from reqSent;
//   - invokes the global resCb (if set), then the per-request callback.
//
// NOTE(review): both callbacks run while cli.mtx is held. A callback that
// calls back into a method that locks cli.mtx (Error, SetResponseCallback)
// would presumably deadlock — confirm cmtsync.Mutex is non-reentrant.
func (cli *socketClient) didRecvResponse(res *types.Response) error {
	cli.mtx.Lock()
	defer cli.mtx.Unlock()

	// Get the first ReqRes.
	next := cli.reqSent.Front()
	if next == nil {
		return fmt.Errorf("unexpected %v when nothing expected", reflect.TypeOf(res.Value))
	}

	reqres := next.Value.(*ReqRes)
	if !resMatchesReq(reqres.Request, res) {
		return fmt.Errorf("unexpected %v when response to %v expected",
			reflect.TypeOf(res.Value), reflect.TypeOf(reqres.Request.Value))
	}

	reqres.Response = res
	reqres.Done()            // release waiters
	cli.reqSent.Remove(next) // pop first item from linked list

	// Notify client listener if set (global callback).
	if cli.resCb != nil {
		cli.resCb(reqres.Request, res)
	}

	// Notify reqRes listener if set (request specific callback).
	//
	// NOTE: The callback may not be set on the reqres object yet. In that
	// case it will be called later, when it is set.
	reqres.InvokeCallback()

	return nil
}
   219  
   220  //----------------------------------------
   221  
   222  func (cli *socketClient) EchoAsync(msg string) *ReqRes {
   223  	return cli.queueRequest(types.ToRequestEcho(msg))
   224  }
   225  
   226  func (cli *socketClient) FlushAsync() *ReqRes {
   227  	return cli.queueRequest(types.ToRequestFlush())
   228  }
   229  
   230  func (cli *socketClient) InfoAsync(req types.RequestInfo) *ReqRes {
   231  	return cli.queueRequest(types.ToRequestInfo(req))
   232  }
   233  
   234  func (cli *socketClient) DeliverTxAsync(req types.RequestDeliverTx) *ReqRes {
   235  	return cli.queueRequest(types.ToRequestDeliverTx(req))
   236  }
   237  
   238  func (cli *socketClient) CheckTxAsync(req types.RequestCheckTx) *ReqRes {
   239  	return cli.queueRequest(types.ToRequestCheckTx(req))
   240  }
   241  
   242  func (cli *socketClient) QueryAsync(req types.RequestQuery) *ReqRes {
   243  	return cli.queueRequest(types.ToRequestQuery(req))
   244  }
   245  
   246  func (cli *socketClient) CommitAsync() *ReqRes {
   247  	return cli.queueRequest(types.ToRequestCommit())
   248  }
   249  
   250  func (cli *socketClient) InitChainAsync(req types.RequestInitChain) *ReqRes {
   251  	return cli.queueRequest(types.ToRequestInitChain(req))
   252  }
   253  
   254  func (cli *socketClient) BeginBlockAsync(req types.RequestBeginBlock) *ReqRes {
   255  	return cli.queueRequest(types.ToRequestBeginBlock(req))
   256  }
   257  
   258  func (cli *socketClient) EndBlockAsync(req types.RequestEndBlock) *ReqRes {
   259  	return cli.queueRequest(types.ToRequestEndBlock(req))
   260  }
   261  
   262  func (cli *socketClient) ListSnapshotsAsync(req types.RequestListSnapshots) *ReqRes {
   263  	return cli.queueRequest(types.ToRequestListSnapshots(req))
   264  }
   265  
   266  func (cli *socketClient) OfferSnapshotAsync(req types.RequestOfferSnapshot) *ReqRes {
   267  	return cli.queueRequest(types.ToRequestOfferSnapshot(req))
   268  }
   269  
   270  func (cli *socketClient) LoadSnapshotChunkAsync(req types.RequestLoadSnapshotChunk) *ReqRes {
   271  	return cli.queueRequest(types.ToRequestLoadSnapshotChunk(req))
   272  }
   273  
   274  func (cli *socketClient) ApplySnapshotChunkAsync(req types.RequestApplySnapshotChunk) *ReqRes {
   275  	return cli.queueRequest(types.ToRequestApplySnapshotChunk(req))
   276  }
   277  
   278  func (cli *socketClient) PrepareProposalAsync(
   279  	req types.RequestPrepareProposal,
   280  ) *ReqRes {
   281  	return cli.queueRequest(types.ToRequestPrepareProposal(req))
   282  }
   283  
   284  func (cli *socketClient) ProcessProposalAsync(
   285  	req types.RequestProcessProposal,
   286  ) *ReqRes {
   287  	return cli.queueRequest(types.ToRequestProcessProposal(req))
   288  }
   289  
   290  func (cli *socketClient) FlushSync() error {
   291  	reqRes := cli.queueRequest(types.ToRequestFlush())
   292  	if err := cli.Error(); err != nil {
   293  		return err
   294  	}
   295  	reqRes.Wait() // NOTE: if we don't flush the queue, its possible to get stuck here
   296  	return cli.Error()
   297  }
   298  
   299  func (cli *socketClient) EchoSync(msg string) (*types.ResponseEcho, error) {
   300  	reqres := cli.queueRequest(types.ToRequestEcho(msg))
   301  	if err := cli.FlushSync(); err != nil {
   302  		return nil, err
   303  	}
   304  
   305  	return reqres.Response.GetEcho(), cli.Error()
   306  }
   307  
   308  func (cli *socketClient) InfoSync(req types.RequestInfo) (*types.ResponseInfo, error) {
   309  	reqres := cli.queueRequest(types.ToRequestInfo(req))
   310  	if err := cli.FlushSync(); err != nil {
   311  		return nil, err
   312  	}
   313  
   314  	return reqres.Response.GetInfo(), cli.Error()
   315  }
   316  
   317  func (cli *socketClient) DeliverTxSync(req types.RequestDeliverTx) (*types.ResponseDeliverTx, error) {
   318  	reqres := cli.queueRequest(types.ToRequestDeliverTx(req))
   319  	if err := cli.FlushSync(); err != nil {
   320  		return nil, err
   321  	}
   322  
   323  	return reqres.Response.GetDeliverTx(), cli.Error()
   324  }
   325  
   326  func (cli *socketClient) CheckTxSync(req types.RequestCheckTx) (*types.ResponseCheckTx, error) {
   327  	reqres := cli.queueRequest(types.ToRequestCheckTx(req))
   328  	if err := cli.FlushSync(); err != nil {
   329  		return nil, err
   330  	}
   331  
   332  	return reqres.Response.GetCheckTx(), cli.Error()
   333  }
   334  
   335  func (cli *socketClient) QuerySync(req types.RequestQuery) (*types.ResponseQuery, error) {
   336  	reqres := cli.queueRequest(types.ToRequestQuery(req))
   337  	if err := cli.FlushSync(); err != nil {
   338  		return nil, err
   339  	}
   340  
   341  	return reqres.Response.GetQuery(), cli.Error()
   342  }
   343  
   344  func (cli *socketClient) CommitSync() (*types.ResponseCommit, error) {
   345  	reqres := cli.queueRequest(types.ToRequestCommit())
   346  	if err := cli.FlushSync(); err != nil {
   347  		return nil, err
   348  	}
   349  
   350  	return reqres.Response.GetCommit(), cli.Error()
   351  }
   352  
   353  func (cli *socketClient) InitChainSync(req types.RequestInitChain) (*types.ResponseInitChain, error) {
   354  	reqres := cli.queueRequest(types.ToRequestInitChain(req))
   355  	if err := cli.FlushSync(); err != nil {
   356  		return nil, err
   357  	}
   358  
   359  	return reqres.Response.GetInitChain(), cli.Error()
   360  }
   361  
   362  func (cli *socketClient) BeginBlockSync(req types.RequestBeginBlock) (*types.ResponseBeginBlock, error) {
   363  	reqres := cli.queueRequest(types.ToRequestBeginBlock(req))
   364  	if err := cli.FlushSync(); err != nil {
   365  		return nil, err
   366  	}
   367  
   368  	return reqres.Response.GetBeginBlock(), cli.Error()
   369  }
   370  
   371  func (cli *socketClient) EndBlockSync(req types.RequestEndBlock) (*types.ResponseEndBlock, error) {
   372  	reqres := cli.queueRequest(types.ToRequestEndBlock(req))
   373  	if err := cli.FlushSync(); err != nil {
   374  		return nil, err
   375  	}
   376  
   377  	return reqres.Response.GetEndBlock(), cli.Error()
   378  }
   379  
   380  func (cli *socketClient) ListSnapshotsSync(req types.RequestListSnapshots) (*types.ResponseListSnapshots, error) {
   381  	reqres := cli.queueRequest(types.ToRequestListSnapshots(req))
   382  	if err := cli.FlushSync(); err != nil {
   383  		return nil, err
   384  	}
   385  
   386  	return reqres.Response.GetListSnapshots(), cli.Error()
   387  }
   388  
   389  func (cli *socketClient) OfferSnapshotSync(req types.RequestOfferSnapshot) (*types.ResponseOfferSnapshot, error) {
   390  	reqres := cli.queueRequest(types.ToRequestOfferSnapshot(req))
   391  	if err := cli.FlushSync(); err != nil {
   392  		return nil, err
   393  	}
   394  
   395  	return reqres.Response.GetOfferSnapshot(), cli.Error()
   396  }
   397  
   398  func (cli *socketClient) LoadSnapshotChunkSync(
   399  	req types.RequestLoadSnapshotChunk) (*types.ResponseLoadSnapshotChunk, error) {
   400  	reqres := cli.queueRequest(types.ToRequestLoadSnapshotChunk(req))
   401  	if err := cli.FlushSync(); err != nil {
   402  		return nil, err
   403  	}
   404  
   405  	return reqres.Response.GetLoadSnapshotChunk(), cli.Error()
   406  }
   407  
   408  func (cli *socketClient) ApplySnapshotChunkSync(
   409  	req types.RequestApplySnapshotChunk) (*types.ResponseApplySnapshotChunk, error) {
   410  	reqres := cli.queueRequest(types.ToRequestApplySnapshotChunk(req))
   411  	if err := cli.FlushSync(); err != nil {
   412  		return nil, err
   413  	}
   414  	return reqres.Response.GetApplySnapshotChunk(), cli.Error()
   415  }
   416  
   417  func (cli *socketClient) PrepareProposalSync(
   418  	req types.RequestPrepareProposal,
   419  ) (*types.ResponsePrepareProposal, error) {
   420  
   421  	reqres := cli.queueRequest(types.ToRequestPrepareProposal(req))
   422  	if err := cli.FlushSync(); err != nil {
   423  		return nil, err
   424  	}
   425  	return reqres.Response.GetPrepareProposal(), nil
   426  }
   427  
   428  func (cli *socketClient) ProcessProposalSync(
   429  	req types.RequestProcessProposal,
   430  ) (*types.ResponseProcessProposal, error) {
   431  
   432  	reqres := cli.queueRequest(types.ToRequestProcessProposal(req))
   433  	if err := cli.FlushSync(); err != nil {
   434  		return nil, err
   435  	}
   436  	return reqres.Response.GetProcessProposal(), nil
   437  }
   438  
   439  //----------------------------------------
   440  
   441  func (cli *socketClient) queueRequest(req *types.Request) *ReqRes {
   442  	reqres := NewReqRes(req)
   443  
   444  	// TODO: set cli.err if reqQueue times out
   445  	cli.reqQueue <- reqres
   446  
   447  	// Maybe auto-flush, or unset auto-flush
   448  	switch req.Value.(type) {
   449  	case *types.Request_Flush:
   450  		cli.flushTimer.Unset()
   451  	default:
   452  		cli.flushTimer.Set()
   453  	}
   454  
   455  	return reqres
   456  }
   457  
   458  func (cli *socketClient) flushQueue() {
   459  	cli.mtx.Lock()
   460  	defer cli.mtx.Unlock()
   461  
   462  	// mark all in-flight messages as resolved (they will get cli.Error())
   463  	for req := cli.reqSent.Front(); req != nil; req = req.Next() {
   464  		reqres := req.Value.(*ReqRes)
   465  		reqres.Done()
   466  	}
   467  
   468  	// mark all queued messages as resolved
   469  LOOP:
   470  	for {
   471  		select {
   472  		case reqres := <-cli.reqQueue:
   473  			reqres.Done()
   474  		default:
   475  			break LOOP
   476  		}
   477  	}
   478  }
   479  
   480  //----------------------------------------
   481  
   482  func resMatchesReq(req *types.Request, res *types.Response) (ok bool) {
   483  	switch req.Value.(type) {
   484  	case *types.Request_Echo:
   485  		_, ok = res.Value.(*types.Response_Echo)
   486  	case *types.Request_Flush:
   487  		_, ok = res.Value.(*types.Response_Flush)
   488  	case *types.Request_Info:
   489  		_, ok = res.Value.(*types.Response_Info)
   490  	case *types.Request_DeliverTx:
   491  		_, ok = res.Value.(*types.Response_DeliverTx)
   492  	case *types.Request_CheckTx:
   493  		_, ok = res.Value.(*types.Response_CheckTx)
   494  	case *types.Request_Commit:
   495  		_, ok = res.Value.(*types.Response_Commit)
   496  	case *types.Request_Query:
   497  		_, ok = res.Value.(*types.Response_Query)
   498  	case *types.Request_InitChain:
   499  		_, ok = res.Value.(*types.Response_InitChain)
   500  	case *types.Request_BeginBlock:
   501  		_, ok = res.Value.(*types.Response_BeginBlock)
   502  	case *types.Request_EndBlock:
   503  		_, ok = res.Value.(*types.Response_EndBlock)
   504  	case *types.Request_ApplySnapshotChunk:
   505  		_, ok = res.Value.(*types.Response_ApplySnapshotChunk)
   506  	case *types.Request_LoadSnapshotChunk:
   507  		_, ok = res.Value.(*types.Response_LoadSnapshotChunk)
   508  	case *types.Request_ListSnapshots:
   509  		_, ok = res.Value.(*types.Response_ListSnapshots)
   510  	case *types.Request_OfferSnapshot:
   511  		_, ok = res.Value.(*types.Response_OfferSnapshot)
   512  	case *types.Request_PrepareProposal:
   513  		_, ok = res.Value.(*types.Response_PrepareProposal)
   514  	case *types.Request_ProcessProposal:
   515  		_, ok = res.Value.(*types.Response_ProcessProposal)
   516  	}
   517  	return ok
   518  }
   519  
   520  func (cli *socketClient) stopForError(err error) {
   521  	if !cli.IsRunning() {
   522  		return
   523  	}
   524  
   525  	cli.mtx.Lock()
   526  	if cli.err == nil {
   527  		cli.err = err
   528  	}
   529  	cli.mtx.Unlock()
   530  
   531  	cli.Logger.Error(fmt.Sprintf("Stopping abci.socketClient for error: %v", err.Error()))
   532  	if err := cli.Stop(); err != nil {
   533  		cli.Logger.Error("Error stopping abci.socketClient", "err", err)
   534  	}
   535  }