github.com/godevsig/adaptiveservice@v0.9.23/streamtransport.go (about) 1 package adaptiveservice 2 3 import ( 4 "encoding/binary" 5 "errors" 6 "fmt" 7 "io" 8 "net" 9 "reflect" 10 "strings" 11 "sync" 12 "sync/atomic" 13 "unsafe" 14 15 "github.com/niubaoshu/gotiny" 16 ) 17 18 type streamTransport struct { 19 svc *service 20 closed chan struct{} 21 lnr net.Listener 22 reverseProxyConn Connection 23 chanNetConn chan net.Conn 24 } 25 26 func makeStreamTransport(svc *service, lnr net.Listener) *streamTransport { 27 return &streamTransport{ 28 svc: svc, 29 closed: make(chan struct{}), 30 lnr: lnr, 31 chanNetConn: make(chan net.Conn, 8), 32 } 33 } 34 35 func (svc *service) newUDSTransport() (*streamTransport, error) { 36 addr := toUDSAddr(svc.publisherName, svc.serviceName) 37 lnr, err := net.Listen("unix", addr) 38 if err != nil { 39 return nil, err 40 } 41 42 st := makeStreamTransport(svc, lnr) 43 go st.receiver() 44 svc.s.lg.Infof("service %s %s listening on %s", svc.publisherName, svc.serviceName, addr) 45 return st, nil 46 } 47 48 func connectReverseProxy(svc *service) Connection { 49 c := NewClient(WithScope(ScopeLAN|ScopeWAN), 50 WithLogger(svc.s.lg), 51 WithRegistryAddr(svc.s.registryAddr), 52 WithProviderID(svc.s.providerID), 53 ).SetDiscoverTimeout(3) 54 connChan := c.Discover(BuiltinPublisher, SrvReverseProxy, "*") 55 defer func() { 56 for conn := range connChan { 57 conn.Close() 58 } 59 }() 60 for conn := range connChan { 61 err := conn.SendRecv(&proxyRegServiceInWAN{svc.publisherName, svc.serviceName, svc.s.providerID}, nil) 62 if err == nil { 63 return conn 64 } 65 conn.Close() 66 } 67 return nil 68 } 69 70 func (svc *service) newTCPTransport(onPort string) (*streamTransport, error) { 71 if len(onPort) == 0 { 72 onPort = "0" 73 } 74 lnr, err := net.Listen("tcp", ":"+onPort) 75 if err != nil { 76 return nil, err 77 } 78 79 st := makeStreamTransport(svc, lnr) 80 go st.receiver() 81 addr := lnr.Addr().String() 82 _, port, _ := net.SplitHostPort(addr) // from [::]:43807 83 svc.s.lg.Infof("service %s %s listening on %s", svc.publisherName, svc.serviceName, addr) 84 85 if svc.scope&ScopeLAN == ScopeLAN { 86 if err := svc.regServiceLAN(port); err != nil { 87 svc.s.lg.Warnf("service %s %s register to LAN failed: %v", svc.publisherName, svc.serviceName, err) 88 st.close() 89 return nil, err 90 } 91 svc.s.lg.Infof("service %s %s registered to LAN", svc.publisherName, svc.serviceName) 92 } 93 94 if svc.scope&ScopeWAN == ScopeWAN { 95 if err := svc.regServiceWAN(port); err != nil { 96 svc.s.lg.Infof("service %s %s can not register to WAN directly: %v", svc.publisherName, svc.serviceName, err) 97 st.reverseProxyConn = connectReverseProxy(svc) 98 if st.reverseProxyConn == nil { 99 svc.s.lg.Warnf("service %s %s register to proxy failed", svc.publisherName, svc.serviceName) 100 st.close() 101 return nil, err 102 } 103 svc.s.lg.Infof("reverse proxy connected") 104 go st.reverseReceiver() 105 } 106 svc.s.lg.Infof("service %s %s registered to WAN", svc.publisherName, svc.serviceName) 107 } 108 return st, nil 109 } 110 111 func (st *streamTransport) close() { 112 closed := st.closed 113 if st.closed == nil { 114 return 115 } 116 st.closed = nil 117 close(closed) 118 svc := st.svc 119 svc.s.lg.Debugf("stream transport %s closing", st.lnr.Addr().String()) 120 st.lnr.Close() 121 if st.reverseProxyConn != nil { 122 st.reverseProxyConn.Close() 123 } 124 if st.lnr.Addr().Network() == "unix" { 125 return 126 } 127 if svc.scope&ScopeLAN == ScopeLAN { 128 if err := svc.delServiceLAN(); err != nil { 129 svc.s.lg.Warnf("del service in lan failed: %v", err) 130 } 131 } 132 if svc.scope&ScopeWAN == ScopeWAN { 133 if err := svc.delServiceWAN(); err != nil { 134 svc.s.lg.Warnf("del service in wan failed: %v", err) 135 } 136 } 137 } 138 139 type streamTransportMsg struct { 140 chanID uint64 // client stream channel ID 141 msg interface{} 142 } 143 144 type streamServerStream struct { 145 Context 146 mtx *sync.Mutex 147 lg Logger 148 netconn net.Conn 149 connClose *chan struct{} 150 privateChan chan interface{} // dedicated to the client 151 chanID uint64 // client stream channel ID, taken from transport msg 152 enc *gotiny.Encoder 153 encMainCopy int32 154 timeouter 155 } 156 157 func (ss *streamServerStream) GetNetconn() Netconn { 158 return ss.netconn 159 } 160 161 func (ss *streamServerStream) send(tm *streamTransportMsg) error { 162 buf := net.Buffers{} 163 mainCopy := false 164 if atomic.CompareAndSwapInt32(&ss.encMainCopy, 0, 1) { 165 ss.lg.Debugf("enc main copy") 166 mainCopy = true 167 } 168 enc := ss.enc 169 if !mainCopy { 170 enc = enc.Copy() 171 } 172 bufMsg := enc.Encode(tm) 173 bufSize := make([]byte, 4) 174 binary.BigEndian.PutUint32(bufSize, uint32(len(bufMsg))) 175 buf = append(buf, bufSize, bufMsg) 176 ss.lg.Debugf("stream server send: tm: %#v ==> size %d, buf %v <%s>", tm, len(bufMsg), bufMsg, bufMsg) 177 ss.mtx.Lock() 178 defer func() { 179 if mainCopy { 180 atomic.StoreInt32(&ss.encMainCopy, 0) 181 } 182 ss.mtx.Unlock() 183 }() 184 if _, err := buf.WriteTo(ss.netconn); err != nil { 185 return err 186 } 187 return nil 188 } 189 190 func (ss *streamServerStream) Send(msg interface{}) error { 191 if *ss.connClose == nil { 192 return io.EOF 193 } 194 tm := streamTransportMsg{chanID: ss.chanID, msg: msg} 195 return ss.send(&tm) 196 } 197 198 func (ss *streamServerStream) Recv(msgPtr interface{}) (err error) { 199 connClose := *ss.connClose 200 if *ss.connClose == nil { 201 return io.EOF 202 } 203 rptr := reflect.ValueOf(msgPtr) 204 if msgPtr != nil && (rptr.Kind() != reflect.Ptr || rptr.IsNil()) { 205 panic("not a pointer or nil pointer") 206 } 207 208 select { 209 case <-connClose: 210 return io.EOF 211 case <-ss.timeouter.timeoutChan(): 212 return ErrRecvTimeout 213 case msg := <-ss.privateChan: 214 if err, ok := msg.(error); ok { 215 return err 216 } 217 if msgPtr == nil { // msgPtr is nil 218 return nil // user just looks at error, no error here 219 } 220 221 rv := rptr.Elem() 222 mrv := reflect.ValueOf(msg) 223 if rv.Kind() != reflect.Ptr && mrv.Kind() == reflect.Ptr { 224 mrv = mrv.Elem() 225 } 226 defer func() { 227 if e := recover(); e != nil { 228 err = fmt.Errorf("message type mismatch: %v", e) 229 } 230 }() 231 rv.Set(mrv) 232 } 233 234 return 235 } 236 237 func (ss *streamServerStream) SendRecv(msgSnd interface{}, msgRcvPtr interface{}) error { 238 if err := ss.Send(msgSnd); err != nil { 239 return err 240 } 241 if err := ss.Recv(msgRcvPtr); err != nil { 242 return err 243 } 244 return nil 245 } 246 247 func (st *streamTransport) reverseReceiver() { 248 lg := st.svc.s.lg 249 cmdConn := st.reverseProxyConn 250 defer cmdConn.Close() 251 conn := cmdConn.(*streamConnection) 252 host, _, _ := net.SplitHostPort(conn.netconn.RemoteAddr().String()) // from [::]:43807 253 254 for { 255 var port string 256 if err := cmdConn.Recv(&port); err != nil { 257 lg.Warnf("reverseReceiver: cmd connection broken: %v", err) 258 break 259 } 260 lg.Debugf("reverseReceiver: new reverse connection request") 261 addr := host + ":" + port 262 netconn, err := net.Dial("tcp", addr) 263 if err != nil { 264 lg.Warnf("reverseReceiver: reverse connection failed: %v", err) 265 break 266 } 267 st.chanNetConn <- netconn 268 } 269 lg.Infof("reverse proxy lost, reconnecting") 270 st.reverseProxyConn = connectReverseProxy(st.svc) 271 if st.reverseProxyConn == nil { 272 lg.Errorf("service %s %s lost connection to reverse proxy", st.svc.publisherName, st.svc.serviceName) 273 return 274 } 275 lg.Infof("reverse proxy reconnected") 276 go st.reverseReceiver() 277 } 278 279 func (st *streamTransport) receiver() { 280 lg := st.svc.s.lg 281 mq := st.svc.s.mq 282 lnr := st.lnr 283 284 go func() { 285 rootRegistryIP, _, _ := net.SplitHostPort(st.svc.s.registryAddr) 286 pinged := false 287 if lnr.Addr().Network() == "unix" { 288 pinged = true 289 } 290 if st.svc.publisherName == BuiltinPublisher && st.svc.serviceName == "rootRegistry" { 291 pinged = true 292 } 293 for { 294 netconn, err := lnr.Accept() 295 if err != nil { 296 lg.Warnf("stream transport listener: %v", err) 297 // the streamTransport has been closed 298 if st.closed == nil { 299 return 300 } 301 continue 302 } 303 if !pinged { 304 host, _, _ := net.SplitHostPort(netconn.RemoteAddr().String()) 305 if host == rootRegistryIP { 306 netconn.Close() 307 pinged = true 308 lg.Debugf("ping from root registry") 309 continue 310 } 311 } 312 st.chanNetConn <- netconn 313 } 314 }() 315 316 handleConn := func(netconn net.Conn) { 317 lg.Debugf("%s %s new stream connection from: %s", st.svc.publisherName, st.svc.serviceName, netconn.RemoteAddr().String()) 318 if st.svc.fnOnConnect != nil { 319 lg.Debugf("%s %s on connect", st.svc.publisherName, st.svc.serviceName) 320 if st.svc.fnOnConnect(netconn) { 321 return 322 } 323 } 324 325 connClose := make(chan struct{}) 326 defer func() { 327 if st.svc.fnOnDisconnect != nil { 328 lg.Debugf("%s %s on disconnect", st.svc.publisherName, st.svc.serviceName) 329 st.svc.fnOnDisconnect(netconn) 330 } 331 lg.Debugf("%s %s stream connection disconnected: %s", st.svc.publisherName, st.svc.serviceName, netconn.RemoteAddr().String()) 332 close(connClose) 333 connClose = nil 334 netconn.Close() 335 }() 336 var mtx sync.Mutex 337 ssMap := make(map[uint64]*streamServerStream) 338 dec := gotiny.NewDecoderWithPtr((*streamTransportMsg)(nil)) 339 dec.SetCopyMode() 340 bufSize := make([]byte, 4) 341 bufMsg := make([]byte, 512) 342 for { 343 if st.closed == nil { 344 return 345 } 346 if _, err := io.ReadFull(netconn, bufSize); err != nil { 347 if errors.Is(err, io.EOF) || strings.Contains(err.Error(), "use of closed network connection") { 348 lg.Debugf("stream sever receiver: connection closed: %v", err) 349 } else { 350 lg.Warnf("stream sever receiver: from %s read size error: %v", netconn.RemoteAddr().String(), err) 351 } 352 return 353 } 354 355 size := binary.BigEndian.Uint32(bufSize) 356 bufCap := uint32(cap(bufMsg)) 357 if size <= bufCap { 358 bufMsg = bufMsg[:size] 359 } else { 360 bufMsg = make([]byte, size) 361 } 362 if _, err := io.ReadFull(netconn, bufMsg); err != nil { 363 lg.Warnf("stream sever receiver: from %s read buf error: %v", netconn.RemoteAddr().String(), err) 364 return 365 } 366 367 var decErr error 368 var tm streamTransportMsg 369 func() { 370 defer func() { 371 if e := recover(); e != nil { 372 decErr = fmt.Errorf("unknown message: %v", e) 373 lg.Errorf("%v", decErr) 374 tm.msg = decErr 375 } 376 }() 377 dec.Decode(bufMsg, &tm) 378 }() 379 lg.Debugf("stream server receiver: tm: %#v", &tm) 380 381 ss := ssMap[tm.chanID] 382 if ss == nil { 383 ss = &streamServerStream{ 384 Context: &contextImpl{}, 385 mtx: &mtx, 386 lg: lg, 387 netconn: netconn, 388 connClose: &connClose, 389 privateChan: make(chan interface{}, st.svc.s.qsize), 390 chanID: tm.chanID, 391 enc: gotiny.NewEncoderWithPtr((*streamTransportMsg)(nil)), 392 } 393 ssMap[tm.chanID] = ss 394 if st.svc.fnOnNewStream != nil { 395 lg.Debugf("%s %s on new stream", st.svc.publisherName, st.svc.serviceName) 396 st.svc.fnOnNewStream(ss) 397 } 398 } 399 400 if decErr != nil { 401 if err := ss.Send(decErr); err != nil { 402 lg.Errorf("send decode error failed: %v", err) 403 } 404 continue 405 } 406 407 if st.svc.canHandle(tm.msg) { 408 mm := &metaKnownMsg{ 409 stream: ss, 410 msg: tm.msg.(KnownMessage), 411 } 412 mq.putMetaMsg(mm) 413 } else { 414 ss.privateChan <- tm.msg 415 } 416 } 417 } 418 419 closed := st.closed 420 for { 421 select { 422 case <-closed: 423 return 424 case netconn := <-st.chanNetConn: 425 go handleConn(netconn) 426 } 427 } 428 } 429 430 // below for client side 431 432 type streamClientStream struct { 433 conn *streamConnection 434 msgChan chan interface{} 435 encMainCopy int32 436 enc *gotiny.Encoder 437 timeouter 438 } 439 440 // stream connection for client. 441 type streamConnection struct { 442 Stream 443 sync.Mutex 444 owner *Client 445 netconn net.Conn 446 closed chan struct{} 447 } 448 449 func (c *Client) newStreamConnection(network string, addr string) (*streamConnection, error) { 450 proxied := false 451 if addr[len(addr)-1] == 'P' { 452 c.lg.Debugf("%s is proxied", addr) 453 addr = addr[:len(addr)-1] 454 proxied = true 455 } 456 netconn, err := net.Dial(network, addr) 457 if err != nil { 458 return nil, err 459 } 460 if proxied { 461 if _, err := netconn.Read([]byte{0}); err != nil { 462 return nil, err 463 } 464 } 465 c.lg.Debugf("stream connection established: %s -> %s", netconn.LocalAddr().String(), addr) 466 467 conn := &streamConnection{ 468 owner: c, 469 netconn: netconn, 470 closed: make(chan struct{}), 471 } 472 473 conn.Stream = conn.NewStream() 474 go conn.receiver() 475 return conn, nil 476 } 477 478 func (c *Client) newUDSConnection(addr string) (*streamConnection, error) { 479 return c.newStreamConnection("unix", addr) 480 } 481 func (c *Client) newTCPConnection(addr string) (*streamConnection, error) { 482 return c.newStreamConnection("tcp", addr) 483 } 484 485 func (conn *streamConnection) receiver() { 486 defer func() { 487 close(conn.closed) 488 conn.closed = nil 489 }() 490 lg := conn.owner.lg 491 netconn := conn.netconn 492 dec := gotiny.NewDecoderWithPtr((*streamTransportMsg)(nil)) 493 dec.SetCopyMode() 494 bufSize := make([]byte, 4) 495 bufMsg := make([]byte, 512) 496 for { 497 if _, err := io.ReadFull(netconn, bufSize); err != nil { 498 if errors.Is(err, io.EOF) || strings.Contains(err.Error(), "use of closed network connection") { 499 lg.Debugf("stream client receiver: connection closed: %v", err) 500 } else { 501 lg.Warnf("stream client receiver: read size error: %v", err) 502 } 503 return 504 } 505 506 size := binary.BigEndian.Uint32(bufSize) 507 bufCap := uint32(cap(bufMsg)) 508 if size <= bufCap { 509 bufMsg = bufMsg[:size] 510 } else { 511 bufMsg = make([]byte, size) 512 } 513 if _, err := io.ReadFull(netconn, bufMsg); err != nil { 514 lg.Warnf("stream client receiver: read buf error: %v", err) 515 return 516 } 517 518 var tm streamTransportMsg 519 //escapes(&tm) 520 func() { 521 defer func() { 522 if e := recover(); e != nil { 523 err := fmt.Errorf("unknown message: %v", e) 524 lg.Errorf("%v", err) 525 tm.msg = err 526 } 527 }() 528 dec.Decode(bufMsg, &tm) 529 }() 530 lg.Debugf("stream client receiver: tm: %#v", &tm) 531 532 if tm.chanID != 0 { 533 func() { 534 defer func() { 535 if err := recover(); err != nil { 536 lg.Errorf("broken stream chan: %v", err) 537 } 538 }() 539 msgChan := *(*chan interface{})(unsafe.Pointer(uintptr(tm.chanID))) 540 msgChan <- tm.msg 541 }() 542 } else { 543 panic("msg channel not specified") 544 } 545 } 546 } 547 548 func (conn *streamConnection) NewStream() Stream { 549 return &streamClientStream{ 550 conn: conn, 551 msgChan: make(chan interface{}, conn.owner.qsize), 552 enc: gotiny.NewEncoderWithPtr((*streamTransportMsg)(nil)), 553 } 554 } 555 func (conn *streamConnection) Close() { 556 conn.netconn.Close() 557 } 558 559 func (cs *streamClientStream) GetNetconn() Netconn { 560 return cs.conn.netconn 561 } 562 563 func (cs *streamClientStream) Send(msg interface{}) error { 564 if cs.msgChan == nil || cs.conn.closed == nil { 565 return io.EOF 566 } 567 568 cid := uint64(uintptr(unsafe.Pointer(&cs.msgChan))) 569 tm := streamTransportMsg{chanID: cid, msg: msg} 570 571 lg := cs.conn.owner.lg 572 buf := net.Buffers{} 573 mainCopy := false 574 if atomic.CompareAndSwapInt32(&cs.encMainCopy, 0, 1) { 575 lg.Debugf("enc main copy") 576 mainCopy = true 577 } 578 enc := cs.enc 579 if !mainCopy { 580 enc = enc.Copy() 581 } 582 bufMsg := enc.Encode(&tm) 583 bufSize := make([]byte, 4) 584 binary.BigEndian.PutUint32(bufSize, uint32(len(bufMsg))) 585 buf = append(buf, bufSize, bufMsg) 586 lg.Debugf("stream client send: tm: %#v ==> size %d, buf %v <%s>", &tm, len(bufMsg), bufMsg, bufMsg) 587 cs.conn.Lock() 588 defer func() { 589 if mainCopy { 590 atomic.StoreInt32(&cs.encMainCopy, 0) 591 } 592 cs.conn.Unlock() 593 }() 594 if _, err := buf.WriteTo(cs.conn.netconn); err != nil { 595 return err 596 } 597 return nil 598 } 599 600 func (cs *streamClientStream) Recv(msgPtr interface{}) (err error) { 601 connClosed := cs.conn.closed 602 if cs.msgChan == nil || cs.conn.closed == nil { 603 return io.EOF 604 } 605 606 rptr := reflect.ValueOf(msgPtr) 607 if msgPtr != nil && (rptr.Kind() != reflect.Ptr || rptr.IsNil()) { 608 panic("not a pointer or nil pointer") 609 } 610 611 select { 612 case <-connClosed: 613 return ErrConnReset 614 case <-cs.timeouter.timeoutChan(): 615 return ErrRecvTimeout 616 case msg := <-cs.msgChan: 617 if err, ok := msg.(error); ok { // message handler returned error 618 if err == io.EOF { 619 cs.msgChan = nil 620 } 621 return err 622 } 623 624 if msgPtr == nil { // msgPtr is nil 625 return nil // user just looks at error, no error here 626 } 627 628 rv := rptr.Elem() 629 mrv := reflect.ValueOf(msg) 630 if rv.Kind() != reflect.Ptr && mrv.Kind() == reflect.Ptr { 631 mrv = mrv.Elem() 632 } 633 defer func() { 634 if e := recover(); e != nil { 635 err = fmt.Errorf("message type mismatch: %v", e) 636 } 637 }() 638 rv.Set(mrv) 639 } 640 641 return 642 } 643 644 func (cs *streamClientStream) SendRecv(msgSnd interface{}, msgRcvPtr interface{}) error { 645 if err := cs.Send(msgSnd); err != nil { 646 return err 647 } 648 if err := cs.Recv(msgRcvPtr); err != nil { 649 return err 650 } 651 return nil 652 }