github.com/ari-anchor/sei-tendermint@v0.0.0-20230519144642-dc826b7b56bb/abci/server/socket_server.go (about)

     1  package server
     2  
     3  import (
     4  	"bufio"
     5  	"context"
     6  	"errors"
     7  	"fmt"
     8  	"io"
     9  	"net"
    10  	"runtime"
    11  	"sync"
    12  
    13  	"github.com/ari-anchor/sei-tendermint/abci/types"
    14  	"github.com/ari-anchor/sei-tendermint/libs/log"
    15  	tmnet "github.com/ari-anchor/sei-tendermint/libs/net"
    16  	"github.com/ari-anchor/sei-tendermint/libs/service"
    17  )
    18  
    19  // var maxNumberConnections = 2
    20  
// SocketServer is an ABCI server that accepts connections on a net.Listener
// and serves every request read from those connections against a single
// types.Application. It embeds service.BaseService; OnStart and OnStop are
// its start/stop hooks (NOTE(review): wiring assumed from
// service.NewBaseService(…, s) in NewSocketServer — confirm in libs/service).
type SocketServer struct {
	service.BaseService
	logger log.Logger

	proto    string       // network protocol handed to net.Listen
	addr     string       // address handed to net.Listen
	listener net.Listener // set by OnStart; closed by OnStop

	connsMtx   sync.Mutex     // guards connsClose and nextConnID
	connsClose map[int]func() // per-connection cancel funcs, keyed by connection ID
	nextConnID int            // next ID handed out by addConn

	app types.Application
}
    35  
    36  func NewSocketServer(logger log.Logger, protoAddr string, app types.Application) service.Service {
    37  	proto, addr := tmnet.ProtocolAndAddress(protoAddr)
    38  	s := &SocketServer{
    39  		logger:     logger,
    40  		proto:      proto,
    41  		addr:       addr,
    42  		listener:   nil,
    43  		app:        app,
    44  		connsClose: make(map[int]func()),
    45  	}
    46  	s.BaseService = *service.NewBaseService(logger, "ABCIServer", s)
    47  	return s
    48  }
    49  
    50  func (s *SocketServer) OnStart(ctx context.Context) error {
    51  	ln, err := net.Listen(s.proto, s.addr)
    52  	if err != nil {
    53  		return err
    54  	}
    55  
    56  	s.listener = ln
    57  	go s.acceptConnectionsRoutine(ctx)
    58  
    59  	return nil
    60  }
    61  
    62  func (s *SocketServer) OnStop() {
    63  	if err := s.listener.Close(); err != nil {
    64  		s.logger.Error("error closing listener", "err", err)
    65  	}
    66  
    67  	s.connsMtx.Lock()
    68  	defer s.connsMtx.Unlock()
    69  
    70  	for _, closer := range s.connsClose {
    71  		closer()
    72  	}
    73  }
    74  
    75  func (s *SocketServer) addConn(closer func()) int {
    76  	s.connsMtx.Lock()
    77  	defer s.connsMtx.Unlock()
    78  
    79  	connID := s.nextConnID
    80  	s.nextConnID++
    81  	s.connsClose[connID] = closer
    82  	return connID
    83  }
    84  
    85  // deletes conn even if close errs
    86  func (s *SocketServer) rmConn(connID int) {
    87  	s.connsMtx.Lock()
    88  	defer s.connsMtx.Unlock()
    89  	if closer, ok := s.connsClose[connID]; ok {
    90  		closer()
    91  		delete(s.connsClose, connID)
    92  	}
    93  }
    94  
    95  func (s *SocketServer) acceptConnectionsRoutine(ctx context.Context) {
    96  	for {
    97  		if ctx.Err() != nil {
    98  			return
    99  		}
   100  
   101  		// Accept a connection
   102  		s.logger.Info("Waiting for new connection...")
   103  		conn, err := s.listener.Accept()
   104  		if err != nil {
   105  			if !s.IsRunning() {
   106  				return // Ignore error from listener closing.
   107  			}
   108  			s.logger.Error("Failed to accept connection", "err", err)
   109  			continue
   110  		}
   111  
   112  		cctx, ccancel := context.WithCancel(ctx)
   113  		connID := s.addConn(ccancel)
   114  
   115  		s.logger.Info("Accepted a new connection", "id", connID)
   116  
   117  		responses := make(chan *types.Response, 1000) // A channel to buffer responses
   118  
   119  		once := &sync.Once{}
   120  		closer := func(err error) {
   121  			ccancel()
   122  			once.Do(func() {
   123  				if cerr := conn.Close(); err != nil {
   124  					s.logger.Error("error closing connection",
   125  						"id", connID,
   126  						"close_err", cerr,
   127  						"err", err)
   128  				}
   129  				s.rmConn(connID)
   130  
   131  				switch {
   132  				case errors.Is(err, context.Canceled):
   133  					s.logger.Error("Connection terminated",
   134  						"id", connID,
   135  						"err", err)
   136  				case errors.Is(err, context.DeadlineExceeded):
   137  					s.logger.Error("Connection encountered timeout",
   138  						"id", connID,
   139  						"err", err)
   140  				case errors.Is(err, io.EOF):
   141  					s.logger.Error("Connection was closed by client",
   142  						"id", connID)
   143  				case err != nil:
   144  					s.logger.Error("Connection error",
   145  						"id", connID,
   146  						"err", err)
   147  				default:
   148  					s.logger.Error("Connection was closed",
   149  						"id", connID)
   150  				}
   151  			})
   152  		}
   153  
   154  		// Read requests from conn and deal with them
   155  		go s.handleRequests(cctx, closer, conn, responses)
   156  		// Pull responses from 'responses' and write them to conn.
   157  		go s.handleResponses(cctx, closer, conn, responses)
   158  	}
   159  }
   160  
   161  // Read requests from conn and deal with them
   162  func (s *SocketServer) handleRequests(ctx context.Context, closer func(error), conn io.Reader, responses chan<- *types.Response) {
   163  	var bufReader = bufio.NewReader(conn)
   164  
   165  	defer func() {
   166  		// make sure to recover from any app-related panics to allow proper socket cleanup
   167  		if r := recover(); r != nil {
   168  			const size = 64 << 10
   169  			buf := make([]byte, size)
   170  			buf = buf[:runtime.Stack(buf, false)]
   171  			closer(fmt.Errorf("recovered from panic: %v\n%s", r, buf))
   172  		}
   173  	}()
   174  
   175  	for {
   176  		req := &types.Request{}
   177  		if err := types.ReadMessage(bufReader, req); err != nil {
   178  			closer(fmt.Errorf("error reading message: %w", err))
   179  			return
   180  		}
   181  
   182  		resp, err := s.processRequest(ctx, req)
   183  		if err != nil {
   184  			closer(err)
   185  			return
   186  		}
   187  
   188  		select {
   189  		case <-ctx.Done():
   190  			closer(ctx.Err())
   191  			return
   192  		case responses <- resp:
   193  		}
   194  	}
   195  }
   196  
   197  func (s *SocketServer) processRequest(ctx context.Context, req *types.Request) (*types.Response, error) {
   198  	switch r := req.Value.(type) {
   199  	case *types.Request_Echo:
   200  		return types.ToResponseEcho(r.Echo.Message), nil
   201  	case *types.Request_Flush:
   202  		return types.ToResponseFlush(), nil
   203  	case *types.Request_Info:
   204  		res, err := s.app.Info(ctx, r.Info)
   205  		if err != nil {
   206  			return nil, err
   207  		}
   208  
   209  		return types.ToResponseInfo(res), nil
   210  	case *types.Request_CheckTx:
   211  		res, err := s.app.CheckTx(ctx, r.CheckTx)
   212  		if err != nil {
   213  			return nil, err
   214  		}
   215  		return types.ToResponseCheckTx(res), nil
   216  	case *types.Request_Commit:
   217  		res, err := s.app.Commit(ctx)
   218  		if err != nil {
   219  			return nil, err
   220  		}
   221  		return types.ToResponseCommit(res), nil
   222  	case *types.Request_Query:
   223  		res, err := s.app.Query(ctx, r.Query)
   224  		if err != nil {
   225  			return nil, err
   226  		}
   227  		return types.ToResponseQuery(res), nil
   228  	case *types.Request_InitChain:
   229  		res, err := s.app.InitChain(ctx, r.InitChain)
   230  		if err != nil {
   231  			return nil, err
   232  		}
   233  		return types.ToResponseInitChain(res), nil
   234  	case *types.Request_ListSnapshots:
   235  		res, err := s.app.ListSnapshots(ctx, r.ListSnapshots)
   236  		if err != nil {
   237  			return nil, err
   238  		}
   239  		return types.ToResponseListSnapshots(res), nil
   240  	case *types.Request_OfferSnapshot:
   241  		res, err := s.app.OfferSnapshot(ctx, r.OfferSnapshot)
   242  		if err != nil {
   243  			return nil, err
   244  		}
   245  		return types.ToResponseOfferSnapshot(res), nil
   246  	case *types.Request_PrepareProposal:
   247  		res, err := s.app.PrepareProposal(ctx, r.PrepareProposal)
   248  		if err != nil {
   249  			return nil, err
   250  		}
   251  		return types.ToResponsePrepareProposal(res), nil
   252  	case *types.Request_ProcessProposal:
   253  		res, err := s.app.ProcessProposal(ctx, r.ProcessProposal)
   254  		if err != nil {
   255  			return nil, err
   256  		}
   257  		return types.ToResponseProcessProposal(res), nil
   258  	case *types.Request_LoadSnapshotChunk:
   259  		res, err := s.app.LoadSnapshotChunk(ctx, r.LoadSnapshotChunk)
   260  		if err != nil {
   261  			return nil, err
   262  		}
   263  		return types.ToResponseLoadSnapshotChunk(res), nil
   264  	case *types.Request_ApplySnapshotChunk:
   265  		res, err := s.app.ApplySnapshotChunk(ctx, r.ApplySnapshotChunk)
   266  		if err != nil {
   267  			return nil, err
   268  		}
   269  		return types.ToResponseApplySnapshotChunk(res), nil
   270  	case *types.Request_ExtendVote:
   271  		res, err := s.app.ExtendVote(ctx, r.ExtendVote)
   272  		if err != nil {
   273  			return nil, err
   274  		}
   275  		return types.ToResponseExtendVote(res), nil
   276  	case *types.Request_VerifyVoteExtension:
   277  		res, err := s.app.VerifyVoteExtension(ctx, r.VerifyVoteExtension)
   278  		if err != nil {
   279  			return nil, err
   280  		}
   281  		return types.ToResponseVerifyVoteExtension(res), nil
   282  	case *types.Request_FinalizeBlock:
   283  		res, err := s.app.FinalizeBlock(ctx, r.FinalizeBlock)
   284  		if err != nil {
   285  			return nil, err
   286  		}
   287  		return types.ToResponseFinalizeBlock(res), nil
   288  	default:
   289  		return types.ToResponseException("Unknown request"), errors.New("unknown request type")
   290  	}
   291  }
   292  
   293  // Pull responses from 'responses' and write them to conn.
   294  func (s *SocketServer) handleResponses(
   295  	ctx context.Context,
   296  	closer func(error),
   297  	conn io.Writer,
   298  	responses <-chan *types.Response,
   299  ) {
   300  	bw := bufio.NewWriter(conn)
   301  	for {
   302  		select {
   303  		case <-ctx.Done():
   304  			closer(ctx.Err())
   305  			return
   306  		case res := <-responses:
   307  			if err := types.WriteMessage(res, bw); err != nil {
   308  				closer(fmt.Errorf("error writing message: %w", err))
   309  				return
   310  			}
   311  			if err := bw.Flush(); err != nil {
   312  				closer(fmt.Errorf("error writing message: %w", err))
   313  				return
   314  			}
   315  		}
   316  	}
   317  }