github.com/ari-anchor/sei-tendermint@v0.0.0-20230519144642-dc826b7b56bb/abci/server/socket_server.go (about) 1 package server 2 3 import ( 4 "bufio" 5 "context" 6 "errors" 7 "fmt" 8 "io" 9 "net" 10 "runtime" 11 "sync" 12 13 "github.com/ari-anchor/sei-tendermint/abci/types" 14 "github.com/ari-anchor/sei-tendermint/libs/log" 15 tmnet "github.com/ari-anchor/sei-tendermint/libs/net" 16 "github.com/ari-anchor/sei-tendermint/libs/service" 17 ) 18 19 // var maxNumberConnections = 2 20 21 type SocketServer struct { 22 service.BaseService 23 logger log.Logger 24 25 proto string 26 addr string 27 listener net.Listener 28 29 connsMtx sync.Mutex 30 connsClose map[int]func() 31 nextConnID int 32 33 app types.Application 34 } 35 36 func NewSocketServer(logger log.Logger, protoAddr string, app types.Application) service.Service { 37 proto, addr := tmnet.ProtocolAndAddress(protoAddr) 38 s := &SocketServer{ 39 logger: logger, 40 proto: proto, 41 addr: addr, 42 listener: nil, 43 app: app, 44 connsClose: make(map[int]func()), 45 } 46 s.BaseService = *service.NewBaseService(logger, "ABCIServer", s) 47 return s 48 } 49 50 func (s *SocketServer) OnStart(ctx context.Context) error { 51 ln, err := net.Listen(s.proto, s.addr) 52 if err != nil { 53 return err 54 } 55 56 s.listener = ln 57 go s.acceptConnectionsRoutine(ctx) 58 59 return nil 60 } 61 62 func (s *SocketServer) OnStop() { 63 if err := s.listener.Close(); err != nil { 64 s.logger.Error("error closing listener", "err", err) 65 } 66 67 s.connsMtx.Lock() 68 defer s.connsMtx.Unlock() 69 70 for _, closer := range s.connsClose { 71 closer() 72 } 73 } 74 75 func (s *SocketServer) addConn(closer func()) int { 76 s.connsMtx.Lock() 77 defer s.connsMtx.Unlock() 78 79 connID := s.nextConnID 80 s.nextConnID++ 81 s.connsClose[connID] = closer 82 return connID 83 } 84 85 // deletes conn even if close errs 86 func (s *SocketServer) rmConn(connID int) { 87 s.connsMtx.Lock() 88 defer s.connsMtx.Unlock() 89 if closer, ok := s.connsClose[connID]; ok { 90 closer() 91 delete(s.connsClose, connID) 92 } 93 } 94 95 func (s *SocketServer) acceptConnectionsRoutine(ctx context.Context) { 96 for { 97 if ctx.Err() != nil { 98 return 99 } 100 101 // Accept a connection 102 s.logger.Info("Waiting for new connection...") 103 conn, err := s.listener.Accept() 104 if err != nil { 105 if !s.IsRunning() { 106 return // Ignore error from listener closing. 107 } 108 s.logger.Error("Failed to accept connection", "err", err) 109 continue 110 } 111 112 cctx, ccancel := context.WithCancel(ctx) 113 connID := s.addConn(ccancel) 114 115 s.logger.Info("Accepted a new connection", "id", connID) 116 117 responses := make(chan *types.Response, 1000) // A channel to buffer responses 118 119 once := &sync.Once{} 120 closer := func(err error) { 121 ccancel() 122 once.Do(func() { 123 if cerr := conn.Close(); err != nil { 124 s.logger.Error("error closing connection", 125 "id", connID, 126 "close_err", cerr, 127 "err", err) 128 } 129 s.rmConn(connID) 130 131 switch { 132 case errors.Is(err, context.Canceled): 133 s.logger.Error("Connection terminated", 134 "id", connID, 135 "err", err) 136 case errors.Is(err, context.DeadlineExceeded): 137 s.logger.Error("Connection encountered timeout", 138 "id", connID, 139 "err", err) 140 case errors.Is(err, io.EOF): 141 s.logger.Error("Connection was closed by client", 142 "id", connID) 143 case err != nil: 144 s.logger.Error("Connection error", 145 "id", connID, 146 "err", err) 147 default: 148 s.logger.Error("Connection was closed", 149 "id", connID) 150 } 151 }) 152 } 153 154 // Read requests from conn and deal with them 155 go s.handleRequests(cctx, closer, conn, responses) 156 // Pull responses from 'responses' and write them to conn. 157 go s.handleResponses(cctx, closer, conn, responses) 158 } 159 } 160 161 // Read requests from conn and deal with them 162 func (s *SocketServer) handleRequests(ctx context.Context, closer func(error), conn io.Reader, responses chan<- *types.Response) { 163 var bufReader = bufio.NewReader(conn) 164 165 defer func() { 166 // make sure to recover from any app-related panics to allow proper socket cleanup 167 if r := recover(); r != nil { 168 const size = 64 << 10 169 buf := make([]byte, size) 170 buf = buf[:runtime.Stack(buf, false)] 171 closer(fmt.Errorf("recovered from panic: %v\n%s", r, buf)) 172 } 173 }() 174 175 for { 176 req := &types.Request{} 177 if err := types.ReadMessage(bufReader, req); err != nil { 178 closer(fmt.Errorf("error reading message: %w", err)) 179 return 180 } 181 182 resp, err := s.processRequest(ctx, req) 183 if err != nil { 184 closer(err) 185 return 186 } 187 188 select { 189 case <-ctx.Done(): 190 closer(ctx.Err()) 191 return 192 case responses <- resp: 193 } 194 } 195 } 196 197 func (s *SocketServer) processRequest(ctx context.Context, req *types.Request) (*types.Response, error) { 198 switch r := req.Value.(type) { 199 case *types.Request_Echo: 200 return types.ToResponseEcho(r.Echo.Message), nil 201 case *types.Request_Flush: 202 return types.ToResponseFlush(), nil 203 case *types.Request_Info: 204 res, err := s.app.Info(ctx, r.Info) 205 if err != nil { 206 return nil, err 207 } 208 209 return types.ToResponseInfo(res), nil 210 case *types.Request_CheckTx: 211 res, err := s.app.CheckTx(ctx, r.CheckTx) 212 if err != nil { 213 return nil, err 214 } 215 return types.ToResponseCheckTx(res), nil 216 case *types.Request_Commit: 217 res, err := s.app.Commit(ctx) 218 if err != nil { 219 return nil, err 220 } 221 return types.ToResponseCommit(res), nil 222 case *types.Request_Query: 223 res, err := s.app.Query(ctx, r.Query) 224 if err != nil { 225 return nil, err 226 } 227 return types.ToResponseQuery(res), nil 228 case *types.Request_InitChain: 229 res, err := s.app.InitChain(ctx, r.InitChain) 230 if err != nil { 231 return nil, err 232 } 233 return types.ToResponseInitChain(res), nil 234 case *types.Request_ListSnapshots: 235 res, err := s.app.ListSnapshots(ctx, r.ListSnapshots) 236 if err != nil { 237 return nil, err 238 } 239 return types.ToResponseListSnapshots(res), nil 240 case *types.Request_OfferSnapshot: 241 res, err := s.app.OfferSnapshot(ctx, r.OfferSnapshot) 242 if err != nil { 243 return nil, err 244 } 245 return types.ToResponseOfferSnapshot(res), nil 246 case *types.Request_PrepareProposal: 247 res, err := s.app.PrepareProposal(ctx, r.PrepareProposal) 248 if err != nil { 249 return nil, err 250 } 251 return types.ToResponsePrepareProposal(res), nil 252 case *types.Request_ProcessProposal: 253 res, err := s.app.ProcessProposal(ctx, r.ProcessProposal) 254 if err != nil { 255 return nil, err 256 } 257 return types.ToResponseProcessProposal(res), nil 258 case *types.Request_LoadSnapshotChunk: 259 res, err := s.app.LoadSnapshotChunk(ctx, r.LoadSnapshotChunk) 260 if err != nil { 261 return nil, err 262 } 263 return types.ToResponseLoadSnapshotChunk(res), nil 264 case *types.Request_ApplySnapshotChunk: 265 res, err := s.app.ApplySnapshotChunk(ctx, r.ApplySnapshotChunk) 266 if err != nil { 267 return nil, err 268 } 269 return types.ToResponseApplySnapshotChunk(res), nil 270 case *types.Request_ExtendVote: 271 res, err := s.app.ExtendVote(ctx, r.ExtendVote) 272 if err != nil { 273 return nil, err 274 } 275 return types.ToResponseExtendVote(res), nil 276 case *types.Request_VerifyVoteExtension: 277 res, err := s.app.VerifyVoteExtension(ctx, r.VerifyVoteExtension) 278 if err != nil { 279 return nil, err 280 } 281 return types.ToResponseVerifyVoteExtension(res), nil 282 case *types.Request_FinalizeBlock: 283 res, err := s.app.FinalizeBlock(ctx, r.FinalizeBlock) 284 if err != nil { 285 return nil, err 286 } 287 return types.ToResponseFinalizeBlock(res), nil 288 default: 289 return types.ToResponseException("Unknown request"), errors.New("unknown request type") 290 } 291 } 292 293 // Pull responses from 'responses' and write them to conn. 294 func (s *SocketServer) handleResponses( 295 ctx context.Context, 296 closer func(error), 297 conn io.Writer, 298 responses <-chan *types.Response, 299 ) { 300 bw := bufio.NewWriter(conn) 301 for { 302 select { 303 case <-ctx.Done(): 304 closer(ctx.Err()) 305 return 306 case res := <-responses: 307 if err := types.WriteMessage(res, bw); err != nil { 308 closer(fmt.Errorf("error writing message: %w", err)) 309 return 310 } 311 if err := bw.Flush(); err != nil { 312 closer(fmt.Errorf("error writing message: %w", err)) 313 return 314 } 315 } 316 } 317 }