// Source: github.com/ari-anchor/sei-tendermint@v0.0.0-20230519144642-dc826b7b56bb/internal/p2p/transport_mconn.go

package p2p

import (
	"context"
	"errors"
	"fmt"
	"io"
	"math"
	"net"
	"strconv"
	"sync"

	"golang.org/x/net/netutil"

	"github.com/ari-anchor/sei-tendermint/crypto"
	"github.com/ari-anchor/sei-tendermint/internal/libs/protoio"
	"github.com/ari-anchor/sei-tendermint/internal/p2p/conn"
	"github.com/ari-anchor/sei-tendermint/libs/log"
	p2pproto "github.com/ari-anchor/sei-tendermint/proto/tendermint/p2p"
	"github.com/ari-anchor/sei-tendermint/types"
)

const (
	// MConnProtocol is the protocol name reported by MConnTransport and set
	// on the endpoints it produces.
	MConnProtocol Protocol = "mconn"
	// TCPProtocol is also accepted by MConnTransport, for
	// backwards-compatibility.
	TCPProtocol Protocol = "tcp"
)

// MConnTransportOptions sets options for MConnTransport.
type MConnTransportOptions struct {
	// MaxAcceptedConnections is the maximum number of simultaneous accepted
	// (incoming) connections. Beyond this, new connections will block until
	// a slot is free. 0 means unlimited.
	//
	// FIXME: We may want to replace this with connection accounting in the
	// Router, since it will need to do e.g. rate limiting and such as well.
	// But it might also make sense to have per-transport limits.
	MaxAcceptedConnections uint32
}

// MConnTransport is a Transport implementation using the current multiplexed
// Tendermint protocol ("MConn").
type MConnTransport struct {
	logger       log.Logger            // handed to every connection this transport creates
	options      MConnTransportOptions // transport-level limits, consulted by Listen
	mConnConfig  conn.MConnConfig      // handed to every connection this transport creates
	channelDescs []*ChannelDescriptor  // channel descriptors handed to new connections

	closeOnce sync.Once     // makes Close run its shutdown logic only once
	doneCh    chan struct{} // closed by Close; observed by Endpoint and Accept
	listener  net.Listener  // nil until Listen succeeds
}

// NewMConnTransport sets up a new MConnection transport. This uses the
// proprietary Tendermint MConnection protocol, which is implemented as
// conn.MConnection.
56 func NewMConnTransport( 57 logger log.Logger, 58 mConnConfig conn.MConnConfig, 59 channelDescs []*ChannelDescriptor, 60 options MConnTransportOptions, 61 ) *MConnTransport { 62 return &MConnTransport{ 63 logger: logger, 64 options: options, 65 mConnConfig: mConnConfig, 66 doneCh: make(chan struct{}), 67 channelDescs: channelDescs, 68 } 69 } 70 71 // String implements Transport. 72 func (m *MConnTransport) String() string { 73 return string(MConnProtocol) 74 } 75 76 // Protocols implements Transport. We support tcp for backwards-compatibility. 77 func (m *MConnTransport) Protocols() []Protocol { 78 return []Protocol{MConnProtocol, TCPProtocol} 79 } 80 81 // Endpoint implements Transport. 82 func (m *MConnTransport) Endpoint() (*Endpoint, error) { 83 if m.listener == nil { 84 return nil, errors.New("listenter not defined") 85 } 86 select { 87 case <-m.doneCh: 88 return nil, errors.New("transport closed") 89 default: 90 } 91 92 endpoint := &Endpoint{ 93 Protocol: MConnProtocol, 94 } 95 if addr, ok := m.listener.Addr().(*net.TCPAddr); ok { 96 endpoint.IP = addr.IP 97 endpoint.Port = uint16(addr.Port) 98 } 99 return endpoint, nil 100 } 101 102 // Listen asynchronously listens for inbound connections on the given endpoint. 103 // It must be called exactly once before calling Accept(), and the caller must 104 // call Close() to shut down the listener. 105 // 106 // FIXME: Listen currently only supports listening on a single endpoint, it 107 // might be useful to support listening on multiple addresses (e.g. IPv4 and 108 // IPv6, or a private and public address) via multiple Listen() calls. 
109 func (m *MConnTransport) Listen(endpoint *Endpoint) error { 110 if m.listener != nil { 111 return errors.New("transport is already listening") 112 } 113 if err := m.validateEndpoint(endpoint); err != nil { 114 return err 115 } 116 117 listener, err := net.Listen("tcp", net.JoinHostPort( 118 endpoint.IP.String(), strconv.Itoa(int(endpoint.Port)))) 119 if err != nil { 120 return err 121 } 122 if m.options.MaxAcceptedConnections > 0 { 123 // FIXME: This will establish the inbound connection but simply hang it 124 // until another connection is released. It would probably be better to 125 // return an error to the remote peer or close the connection. This is 126 // also a DoS vector since the connection will take up kernel resources. 127 // This was just carried over from the legacy P2P stack. 128 listener = netutil.LimitListener(listener, int(m.options.MaxAcceptedConnections)) 129 } 130 m.listener = listener 131 132 return nil 133 } 134 135 // Accept implements Transport. 136 func (m *MConnTransport) Accept(ctx context.Context) (Connection, error) { 137 if m.listener == nil { 138 return nil, errors.New("transport is not listening") 139 } 140 141 conCh := make(chan net.Conn) 142 errCh := make(chan error) 143 go func() { 144 tcpConn, err := m.listener.Accept() 145 if err != nil { 146 select { 147 case errCh <- err: 148 case <-ctx.Done(): 149 } 150 } 151 select { 152 case conCh <- tcpConn: 153 case <-ctx.Done(): 154 } 155 }() 156 157 select { 158 case <-ctx.Done(): 159 m.listener.Close() 160 return nil, io.EOF 161 case <-m.doneCh: 162 m.listener.Close() 163 return nil, io.EOF 164 case err := <-errCh: 165 return nil, err 166 case tcpConn := <-conCh: 167 return newMConnConnection(m.logger, tcpConn, m.mConnConfig, m.channelDescs), nil 168 } 169 170 } 171 172 // Dial implements Transport. 
173 func (m *MConnTransport) Dial(ctx context.Context, endpoint *Endpoint) (Connection, error) { 174 if err := m.validateEndpoint(endpoint); err != nil { 175 return nil, err 176 } 177 if endpoint.Port == 0 { 178 endpoint.Port = 26657 179 } 180 181 dialer := net.Dialer{} 182 tcpConn, err := dialer.DialContext(ctx, "tcp", net.JoinHostPort( 183 endpoint.IP.String(), strconv.Itoa(int(endpoint.Port)))) 184 if err != nil { 185 select { 186 case <-ctx.Done(): 187 return nil, ctx.Err() 188 default: 189 return nil, err 190 } 191 } 192 193 return newMConnConnection(m.logger, tcpConn, m.mConnConfig, m.channelDescs), nil 194 } 195 196 // Close implements Transport. 197 func (m *MConnTransport) Close() error { 198 var err error 199 m.closeOnce.Do(func() { 200 close(m.doneCh) 201 if m.listener != nil { 202 err = m.listener.Close() 203 } 204 }) 205 return err 206 } 207 208 // SetChannels sets the channel descriptors to be used when 209 // establishing a connection. 210 // 211 // FIXME: To be removed when the legacy p2p stack is removed. Channel 212 // descriptors should be managed by the router. The underlying transport and 213 // connections should be agnostic to everything but the channel ID's which are 214 // initialized in the handshake. 215 func (m *MConnTransport) AddChannelDescriptors(channelDesc []*ChannelDescriptor) { 216 m.channelDescs = append(m.channelDescs, channelDesc...) 217 } 218 219 // validateEndpoint validates an endpoint. 
220 func (m *MConnTransport) validateEndpoint(endpoint *Endpoint) error { 221 if err := endpoint.Validate(); err != nil { 222 return err 223 } 224 if endpoint.Protocol != MConnProtocol && endpoint.Protocol != TCPProtocol { 225 return fmt.Errorf("unsupported protocol %q", endpoint.Protocol) 226 } 227 if len(endpoint.IP) == 0 { 228 return errors.New("endpoint has no IP address") 229 } 230 if endpoint.Path != "" { 231 return fmt.Errorf("endpoints with path not supported (got %q)", endpoint.Path) 232 } 233 return nil 234 } 235 236 // mConnConnection implements Connection for MConnTransport. 237 type mConnConnection struct { 238 logger log.Logger 239 conn net.Conn 240 mConnConfig conn.MConnConfig 241 channelDescs []*ChannelDescriptor 242 receiveCh chan mConnMessage 243 errorCh chan error 244 doneCh chan struct{} 245 closeOnce sync.Once 246 247 mconn *conn.MConnection // set during Handshake() 248 } 249 250 // mConnMessage passes MConnection messages through internal channels. 251 type mConnMessage struct { 252 channelID ChannelID 253 payload []byte 254 } 255 256 // newMConnConnection creates a new mConnConnection. 257 func newMConnConnection( 258 logger log.Logger, 259 conn net.Conn, 260 mConnConfig conn.MConnConfig, 261 channelDescs []*ChannelDescriptor, 262 ) *mConnConnection { 263 return &mConnConnection{ 264 logger: logger, 265 conn: conn, 266 mConnConfig: mConnConfig, 267 channelDescs: channelDescs, 268 receiveCh: make(chan mConnMessage), 269 errorCh: make(chan error, 1), // buffered to avoid onError leak 270 doneCh: make(chan struct{}), 271 } 272 } 273 274 // Handshake implements Connection. 
275 func (c *mConnConnection) Handshake( 276 ctx context.Context, 277 nodeInfo types.NodeInfo, 278 privKey crypto.PrivKey, 279 ) (types.NodeInfo, crypto.PubKey, error) { 280 var ( 281 mconn *conn.MConnection 282 peerInfo types.NodeInfo 283 peerKey crypto.PubKey 284 errCh = make(chan error, 1) 285 ) 286 // To handle context cancellation, we need to do the handshake in a 287 // goroutine and abort the blocking network calls by closing the connection 288 // when the context is canceled. 289 go func() { 290 // FIXME: Since the MConnection code panics, we need to recover it and turn it 291 // into an error. We should remove panics instead. 292 defer func() { 293 if r := recover(); r != nil { 294 errCh <- fmt.Errorf("recovered from panic: %v", r) 295 } 296 }() 297 var err error 298 mconn, peerInfo, peerKey, err = c.handshake(ctx, nodeInfo, privKey) 299 300 select { 301 case errCh <- err: 302 case <-ctx.Done(): 303 } 304 305 }() 306 307 select { 308 case <-ctx.Done(): 309 _ = c.Close() 310 return types.NodeInfo{}, nil, ctx.Err() 311 312 case err := <-errCh: 313 if err != nil { 314 return types.NodeInfo{}, nil, err 315 } 316 c.mconn = mconn 317 if err = c.mconn.Start(ctx); err != nil { 318 return types.NodeInfo{}, nil, err 319 } 320 return peerInfo, peerKey, nil 321 } 322 } 323 324 // handshake is a helper for Handshake, simplifying error handling so we can 325 // keep context handling and panic recovery in Handshake. It returns an 326 // unstarted but handshaked MConnection, to avoid concurrent field writes. 
327 func (c *mConnConnection) handshake( 328 ctx context.Context, 329 nodeInfo types.NodeInfo, 330 privKey crypto.PrivKey, 331 ) (*conn.MConnection, types.NodeInfo, crypto.PubKey, error) { 332 if c.mconn != nil { 333 return nil, types.NodeInfo{}, nil, errors.New("connection is already handshaked") 334 } 335 336 secretConn, err := conn.MakeSecretConnection(c.conn, privKey) 337 if err != nil { 338 return nil, types.NodeInfo{}, nil, err 339 } 340 341 wg := &sync.WaitGroup{} 342 var pbPeerInfo p2pproto.NodeInfo 343 errCh := make(chan error, 2) 344 wg.Add(1) 345 go func() { 346 defer wg.Done() 347 _, err := protoio.NewDelimitedWriter(secretConn).WriteMsg(nodeInfo.ToProto()) 348 select { 349 case errCh <- err: 350 case <-ctx.Done(): 351 } 352 353 }() 354 wg.Add(1) 355 go func() { 356 defer wg.Done() 357 _, err := protoio.NewDelimitedReader(secretConn, types.MaxNodeInfoSize()).ReadMsg(&pbPeerInfo) 358 select { 359 case errCh <- err: 360 case <-ctx.Done(): 361 } 362 }() 363 364 wg.Wait() 365 366 if err, ok := <-errCh; ok && err != nil { 367 return nil, types.NodeInfo{}, nil, err 368 } 369 370 if err := ctx.Err(); err != nil { 371 return nil, types.NodeInfo{}, nil, err 372 } 373 374 peerInfo, err := types.NodeInfoFromProto(&pbPeerInfo) 375 if err != nil { 376 return nil, types.NodeInfo{}, nil, err 377 } 378 379 c.logger.Debug(fmt.Sprintf("Creating a new MConnection with peerId %s, moniker %s, listenAddr %s", peerInfo.NodeID, peerInfo.Moniker, peerInfo.ListenAddr)) 380 381 mconn := conn.NewMConnection( 382 c.logger.With("peer", c.RemoteEndpoint().NodeAddress(peerInfo.NodeID)), 383 secretConn, 384 c.channelDescs, 385 c.onReceive, 386 c.onError, 387 c.mConnConfig, 388 ) 389 390 return mconn, peerInfo, secretConn.RemotePubKey(), nil 391 } 392 393 // onReceive is a callback for MConnection received messages. 
394 func (c *mConnConnection) onReceive(ctx context.Context, chID ChannelID, payload []byte) { 395 select { 396 case c.receiveCh <- mConnMessage{channelID: chID, payload: payload}: 397 case <-ctx.Done(): 398 } 399 } 400 401 // onError is a callback for MConnection errors. The error is passed via errorCh 402 // to ReceiveMessage (but not SendMessage, for legacy P2P stack behavior). 403 func (c *mConnConnection) onError(ctx context.Context, e interface{}) { 404 err, ok := e.(error) 405 if !ok { 406 err = fmt.Errorf("%v", err) 407 } 408 // We have to close the connection here, since MConnection will have stopped 409 // the service on any errors. 410 _ = c.Close() 411 select { 412 case c.errorCh <- err: 413 c.logger.Error(fmt.Sprintf("mConnection Error %s", err)) 414 case <-ctx.Done(): 415 } 416 } 417 418 // String displays connection information. 419 func (c *mConnConnection) String() string { 420 return c.RemoteEndpoint().String() 421 } 422 423 // SendMessage implements Connection. 424 func (c *mConnConnection) SendMessage(ctx context.Context, chID ChannelID, msg []byte) error { 425 if chID > math.MaxUint8 { 426 return fmt.Errorf("MConnection only supports 1-byte channel IDs (got %v)", chID) 427 } 428 select { 429 case err := <-c.errorCh: 430 return err 431 case <-ctx.Done(): 432 return io.EOF 433 default: 434 if ok := c.mconn.Send(chID, msg); !ok { 435 return errors.New("sending message timed out") 436 } 437 438 return nil 439 } 440 } 441 442 // ReceiveMessage implements Connection. 443 func (c *mConnConnection) ReceiveMessage(ctx context.Context) (ChannelID, []byte, error) { 444 select { 445 case err := <-c.errorCh: 446 return 0, nil, err 447 case <-c.doneCh: 448 return 0, nil, io.EOF 449 case <-ctx.Done(): 450 return 0, nil, io.EOF 451 case msg := <-c.receiveCh: 452 return msg.channelID, msg.payload, nil 453 } 454 } 455 456 // LocalEndpoint implements Connection. 
457 func (c *mConnConnection) LocalEndpoint() Endpoint { 458 endpoint := Endpoint{ 459 Protocol: MConnProtocol, 460 } 461 if addr, ok := c.conn.LocalAddr().(*net.TCPAddr); ok { 462 endpoint.IP = addr.IP 463 endpoint.Port = uint16(addr.Port) 464 } 465 return endpoint 466 } 467 468 // RemoteEndpoint implements Connection. 469 func (c *mConnConnection) RemoteEndpoint() Endpoint { 470 endpoint := Endpoint{ 471 Protocol: MConnProtocol, 472 } 473 if addr, ok := c.conn.RemoteAddr().(*net.TCPAddr); ok { 474 endpoint.IP = addr.IP 475 endpoint.Port = uint16(addr.Port) 476 } 477 return endpoint 478 } 479 480 // Close implements Connection. 481 func (c *mConnConnection) Close() error { 482 var err error 483 c.closeOnce.Do(func() { 484 defer close(c.doneCh) 485 486 if c.mconn != nil && c.mconn.IsRunning() { 487 c.mconn.Stop() 488 } else { 489 err = c.conn.Close() 490 } 491 }) 492 return err 493 }