github.com/matrixorigin/matrixone@v0.7.0/pkg/common/morpc/server.go

// Copyright 2021 - 2022 Matrix Origin
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package morpc

import (
	"context"
	"fmt"
	"sync"
	"time"

	"github.com/fagongzi/goetty/v2"
	"github.com/matrixorigin/matrixone/pkg/common/moerr"
	"github.com/matrixorigin/matrixone/pkg/common/stopper"
	"github.com/matrixorigin/matrixone/pkg/logutil"
	"go.uber.org/zap"
)

// WithServerLogger sets the rpc server logger.
func WithServerLogger(logger *zap.Logger) ServerOption {
	return func(rs *server) {
		rs.logger = logger
	}
}

// WithServerSessionBufferSize sets the buffer size of the write response chan.
// Default is 16.
func WithServerSessionBufferSize(size int) ServerOption {
	return func(s *server) {
		s.options.bufferSize = size
	}
}

// WithServerWriteFilter sets the write filter func. The filter receives messages
// that are ready to be sent and reports whether they should actually be sent.
func WithServerWriteFilter(filter func(Message) bool) ServerOption {
	return func(s *server) {
		s.options.filter = filter
	}
}

// WithServerGoettyOptions sets the goetty options used to create the underlying
// goetty application and sessions.
func WithServerGoettyOptions(options ...goetty.Option) ServerOption {
	return func(s *server) {
		s.options.goettyOptions = options
	}
}

// WithServerBatchSendSize sets the maximum number of messages sent together in
// each batch. Default is 8.
func WithServerBatchSendSize(size int) ServerOption {
	return func(s *server) {
		s.options.batchSendSize = size
	}
}

// WithServerDisableAutoCancelContext disables automatic cancellation of the request
// context. The server receives RPC messages from the client, each carrying a Context,
// and morpc calls the handler to process it; when the handler returns, the Context is
// canceled automatically. In some scenarios the handler is asynchronous, so morpc must
// not cancel the context as soon as the handler returns, otherwise subtle problems occur.
func WithServerDisableAutoCancelContext() ServerOption {
	return func(s *server) {
		s.options.disableAutoCancelContext = true
	}
}

type server struct {
	name        string
	address     string
	logger      *zap.Logger
	codec       Codec
	application goetty.NetApplication
	stopper     *stopper.Stopper
	handler     func(ctx context.Context, request Message, sequence uint64, cs ClientSession) error
	sessions    *sync.Map // session-id => *clientSession
	options     struct {
		goettyOptions            []goetty.Option
		bufferSize               int
		batchSendSize            int
		filter                   func(Message) bool
		disableAutoCancelContext bool
	}
	pool struct {
		futures *sync.Pool
	}
}
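// Hedged usage sketch (not part of the original file): how the functional options above
// compose when constructing a server. Each ServerOption is a closure that mutates the
// *server being built. The codec parameter stands in for any Codec implementation; the
// name and address below are placeholders.
func exampleConfiguredServer(codec Codec) (RPCServer, error) {
	return NewRPCServer("example", "127.0.0.1:9000", codec,
		WithServerLogger(zap.NewExample()),
		WithServerSessionBufferSize(32), // per-session write buffer, default is 16
		WithServerBatchSendSize(16),     // responses flushed together, default is 8
		WithServerWriteFilter(func(m Message) bool {
			// a real filter could drop messages by ID or type; this one sends everything
			return true
		}))
}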
// NewRPCServer creates an rpc server with options. After the rpc server starts, each
// connection is served by two goroutines, one for reading and one for writing. All
// messages to be written are first pushed to a buffered chan and then sent to the
// client by the write goroutine.
func NewRPCServer(name, address string, codec Codec, options ...ServerOption) (RPCServer, error) {
	s := &server{
		name:     name,
		address:  address,
		codec:    codec,
		stopper:  stopper.NewStopper(fmt.Sprintf("rpc-server-%s", name)),
		sessions: &sync.Map{},
	}
	for _, opt := range options {
		opt(s)
	}
	s.adjust()

	s.options.goettyOptions = append(s.options.goettyOptions,
		goetty.WithSessionCodec(codec),
		goetty.WithSessionLogger(s.logger))

	app, err := goetty.NewApplication(
		s.address,
		s.onMessage,
		goetty.WithAppLogger(s.logger),
		goetty.WithAppSessionOptions(s.options.goettyOptions...),
	)
	if err != nil {
		s.logger.Error("create rpc server failed",
			zap.Error(err))
		return nil, err
	}
	s.application = app
	s.pool.futures = &sync.Pool{
		New: func() interface{} {
			return newFuture(s.releaseFuture)
		},
	}
	return s, nil
}

func (s *server) Start() error {
	err := s.application.Start()
	if err != nil {
		s.logger.Fatal("start rpcserver failed",
			zap.Error(err))
		return err
	}
	return nil
}

func (s *server) Close() error {
	s.stopper.Stop()
	err := s.application.Stop()
	if err != nil {
		s.logger.Error("stop rpcserver failed",
			zap.Error(err))
	}

	return err
}

func (s *server) RegisterRequestHandler(handler func(ctx context.Context, request Message, sequence uint64, cs ClientSession) error) {
	s.handler = handler
}

func (s *server) adjust() {
	s.logger = logutil.Adjust(s.logger).With(zap.String("name", s.name))
	if s.options.batchSendSize == 0 {
		s.options.batchSendSize = 8
	}
	if s.options.bufferSize == 0 {
		s.options.bufferSize = 16
	}
	if s.options.filter == nil {
		s.options.filter = func(messages Message) bool {
			return true
		}
	}
}
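// Hedged sketch (not part of the original file): registering a minimal handler and
// starting the server. The handler writes the request back through the per-connection
// ClientSession, which enqueues it on the session's buffered chan for the write loop
// below; echoing the request itself is only for illustration.
func exampleStartEchoServer(s *server) error {
	s.RegisterRequestHandler(func(ctx context.Context, request Message, sequence uint64, cs ClientSession) error {
		return cs.Write(ctx, request)
	})
	return s.Start()
}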
func (s *server) onMessage(rs goetty.IOSession, value any, sequence uint64) error {
	cs, err := s.getSession(rs)
	if err != nil {
		return err
	}
	request := value.(RPCMessage)
	if ce := s.logger.Check(zap.DebugLevel, "received request"); ce != nil {
		ce.Write(zap.Uint64("sequence", sequence),
			zap.String("client", rs.RemoteAddress()),
			zap.Uint64("request-id", request.Message.GetID()),
			zap.String("request", request.Message.DebugString()))
	}

	// If disableAutoCancelContext is set to true we cannot be sure that the Context is
	// properly consumed, so we pessimistically rely on the context being canceled when
	// it times out, which may cause some resources to be released more slowly.
	// FIXME: pass the CancelFunc to the handler and let it decide when to cancel.
	if !s.options.disableAutoCancelContext && request.cancel != nil {
		defer request.cancel()
	}
	// Get the requestID here to avoid a data race, because the request may be released
	// in the handler.
	requestID := request.Message.GetID()

	if request.stream &&
		!cs.validateStreamRequest(requestID, request.streamSequence) {
		s.logger.Error("failed to handle stream request",
			zap.Uint32("last-sequence", cs.receivedStreamSequences[requestID]),
			zap.Uint32("current-sequence", request.streamSequence),
			zap.String("client", rs.RemoteAddress()))
		cs.cancelWrite()
		return moerr.NewStreamClosedNoCtx()
	}

	// handle internal message
	if request.internal {
		if m, ok := request.Message.(*flagOnlyMessage); ok {
			switch m.flag {
			case flagPing:
				return cs.Write(request.Ctx, &flagOnlyMessage{flag: flagPong, id: m.id})
			default:
				panic(fmt.Sprintf("invalid internal message, flag %d", m.flag))
			}
		}
	}

	if err := s.handler(request.Ctx, request.Message, sequence, cs); err != nil {
		s.logger.Error("handle request failed",
			zap.Uint64("sequence", sequence),
			zap.String("client", rs.RemoteAddress()),
			zap.Error(err))
		cs.cancelWrite()
		return err
	}

	if ce := s.logger.Check(zap.DebugLevel, "handle request completed"); ce != nil {
		ce.Write(zap.Uint64("sequence", sequence),
			zap.String("client", rs.RemoteAddress()),
			zap.Uint64("request-id", requestID))
	}
	return nil
}
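// Hedged sketch (not part of the original file): why WithServerDisableAutoCancelContext
// matters for asynchronous handlers. onMessage above defers request.cancel() when
// auto-cancel is enabled, so a handler that hands work to a background goroutine will
// observe a canceled context as soon as it returns. Values needed later (such as the
// request ID) must be captured up front, because the request may be released once the
// handler returns. The timing and response wiring here are illustrative only.
func exampleAsyncHandler(ctx context.Context, request Message, sequence uint64, cs ClientSession) error {
	id := request.GetID() // capture before returning; the request may be released afterwards
	go func() {
		select {
		case <-time.After(10 * time.Millisecond): // simulated asynchronous work
			// With auto-cancel disabled, a response carrying id could be written here via
			// cs.Write; with auto-cancel enabled, ctx is typically already canceled.
			_ = id
		case <-ctx.Done():
			return
		}
	}()
	return nil
}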
func (s *server) startWriteLoop(cs *clientSession) error {
	return s.stopper.RunTask(func(ctx context.Context) {
		defer s.closeClientSession(cs)

		responses := make([]*Future, 0, s.options.batchSendSize)
		fetch := func() {
			for i := 0; i < len(responses); i++ {
				responses[i] = nil
			}
			responses = responses[:0]

			for i := 0; i < s.options.batchSendSize; i++ {
				if len(responses) == 0 {
					select {
					case <-ctx.Done():
						responses = nil
						return
					case <-cs.ctx.Done():
						responses = nil
						return
					case resp := <-cs.c:
						responses = append(responses, resp)
					}
				} else {
					select {
					case <-ctx.Done():
						responses = nil
						return
					case <-cs.ctx.Done():
						responses = nil
						return
					case resp := <-cs.c:
						responses = append(responses, resp)
					default:
						return
					}
				}
			}
		}

		for {
			select {
			case <-ctx.Done():
				return
			case <-cs.ctx.Done():
				return
			default:
				fetch()

				if len(responses) > 0 {
					var fields []zap.Field
					ce := s.logger.Check(zap.DebugLevel, "write responses")
					if ce != nil {
						fields = append(fields, zap.String("client", cs.conn.RemoteAddress()))
					}

					written := 0
					timeout := time.Duration(0)
					for _, f := range responses {
						if !s.options.filter(f.send.Message) {
							f.messageSended(messageSkipped)
							continue
						}

						if f.send.Timeout() {
							f.messageSended(f.send.Ctx.Err())
							continue
						}

						v, err := f.send.GetTimeoutFromContext()
						if err != nil {
							f.messageSended(err)
							continue
						}

						timeout += v
						// Record some response information in advance, because after flush
						// these responses will be released; this avoids a data race.
						if ce != nil {
							fields = append(fields, zap.Uint64("request-id",
								f.send.Message.GetID()))
							fields = append(fields, zap.String("response",
								f.send.Message.DebugString()))
						}
						if err := cs.conn.Write(f.send, goetty.WriteOptions{}); err != nil {
							s.logger.Error("write response failed",
								zap.Uint64("request-id", f.send.Message.GetID()),
								zap.Error(err))
							f.messageSended(err)
							return
						}
						written++
					}

					if written > 0 {
						err := cs.conn.Flush(timeout)
						if err != nil {
							if ce != nil {
								fields = append(fields, zap.Error(err))
							}
							for _, f := range responses {
								if !s.options.filter(f.send.Message) {
									id := f.getSendMessageID()
									s.logger.Error("write response failed",
										zap.Uint64("request-id", id),
										zap.Error(err))
									f.messageSended(err)
								}
							}
						}
						if ce != nil {
							ce.Write(fields...)
						}
					}

					for _, f := range responses {
						f.messageSended(nil)
					}
				}
			}
		}
	})
}

func (s *server) closeClientSession(cs *clientSession) {
	s.sessions.Delete(cs.conn.ID())
	if err := cs.Close(); err != nil {
		s.logger.Error("close client session failed",
			zap.Error(err))
	}
}

func (s *server) getSession(rs goetty.IOSession) (*clientSession, error) {
	if v, ok := s.sessions.Load(rs.ID()); ok {
		return v.(*clientSession), nil
	}

	cs := newClientSession(rs, s.codec, s.newFuture)
	v, loaded := s.sessions.LoadOrStore(rs.ID(), cs)
	if loaded {
		close(cs.c)
		return v.(*clientSession), nil
	}

	rs.Ref()
	if err := s.startWriteLoop(cs); err != nil {
		s.closeClientSession(cs)
		return nil, err
	}
	return cs, nil
}

func (s *server) releaseFuture(f *Future) {
	f.reset()
	s.pool.futures.Put(f)
}

func (s *server) newFuture() *Future {
	return s.pool.futures.Get().(*Future)
}
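// Hedged sketch (not part of the original file): the batching pattern used by
// startWriteLoop's fetch closure, shown in isolation. The first receive blocks until a
// response is available; subsequent receives are non-blocking (the default case), so at
// most batch responses are drained per flush without waiting for the channel to fill.
func exampleDrainBatch(ctx context.Context, c chan *Future, batch int) []*Future {
	out := make([]*Future, 0, batch)
	for i := 0; i < batch; i++ {
		if len(out) == 0 {
			select {
			case <-ctx.Done():
				return nil
			case f := <-c:
				out = append(out, f)
			}
		} else {
			select {
			case f := <-c:
				out = append(out, f)
			default:
				return out // nothing else pending, flush what we have
			}
		}
	}
	return out
}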
type clientSession struct {
	codec         Codec
	conn          goetty.IOSession
	c             chan *Future
	newFutureFunc func() *Future
	// streaming id -> last received sequence; no concurrent access, only used in the io goroutine
	receivedStreamSequences map[uint64]uint32
	// streaming id -> last sent sequence; multiple streams may access this from multiple
	// goroutines if the tcp connection is shared, but there is no concurrency within one stream.
	sentStreamSequences sync.Map
	cancel              context.CancelFunc
	ctx                 context.Context
	mu                  struct {
		sync.RWMutex
		closed bool
	}
}

func newClientSession(
	conn goetty.IOSession,
	codec Codec,
	newFutureFunc func() *Future) *clientSession {
	ctx, cancel := context.WithCancel(context.Background())
	return &clientSession{
		codec:                   codec,
		c:                       make(chan *Future, 32),
		receivedStreamSequences: make(map[uint64]uint32),
		conn:                    conn,
		ctx:                     ctx,
		cancel:                  cancel,
		newFutureFunc:           newFutureFunc,
	}
}

func (cs *clientSession) Close() error {
	cs.mu.Lock()
	defer cs.mu.Unlock()
	if cs.mu.closed {
		return nil
	}

	close(cs.c)
	cs.mu.closed = true
	return cs.conn.Close()
}

func (cs *clientSession) Write(ctx context.Context, response Message) error {
	if err := cs.codec.Valid(response); err != nil {
		return err
	}

	cs.mu.RLock()
	defer cs.mu.RUnlock()

	if cs.mu.closed {
		return moerr.NewClientClosedNoCtx()
	}

	msg := RPCMessage{Ctx: ctx, Message: response}
	id := response.GetID()
	if v, ok := cs.sentStreamSequences.Load(id); ok {
		seq := v.(uint32) + 1
		cs.sentStreamSequences.Store(id, seq)
		msg.stream = true
		msg.streamSequence = seq
	}

	f := cs.newFutureFunc()
	f.ref()
	f.init(msg)
	defer f.Close()

	cs.c <- f
	// for streams, only wait for the send to complete
	return f.waitSendCompleted()
}

func (cs *clientSession) cancelWrite() {
	cs.cancel()
}

func (cs *clientSession) validateStreamRequest(id uint64, sequence uint32) bool {
	expectSequence := cs.receivedStreamSequences[id] + 1
	if sequence != expectSequence {
		return false
	}
	cs.receivedStreamSequences[id] = sequence
	if sequence == 1 {
		cs.sentStreamSequences.Store(id, uint32(0))
	}
	return true
}
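// Hedged sketch (not part of the original file): how validateStreamRequest enforces the
// stream protocol. Sequences for a given stream ID must arrive as 1, 2, 3, ...; the first
// valid sequence (1) also initializes the sent-side counter that clientSession.Write uses
// to number outgoing stream responses. The stream ID below is an arbitrary placeholder.
func exampleStreamSequenceCheck(cs *clientSession) {
	id := uint64(7)                    // illustrative stream ID
	_ = cs.validateStreamRequest(id, 1) // true: first message of the stream
	_ = cs.validateStreamRequest(id, 2) // true: next expected sequence
	_ = cs.validateStreamRequest(id, 4) // false: sequence 3 was skipped
}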