github.com/godevsig/adaptiveservice@v0.9.23/streamtransport.go (about)

     1  package adaptiveservice
     2  
     3  import (
     4  	"encoding/binary"
     5  	"errors"
     6  	"fmt"
     7  	"io"
     8  	"net"
     9  	"reflect"
    10  	"strings"
    11  	"sync"
    12  	"sync/atomic"
    13  	"unsafe"
    14  
    15  	"github.com/niubaoshu/gotiny"
    16  )
    17  
    18  type streamTransport struct {
    19  	svc              *service
    20  	closed           chan struct{}
    21  	lnr              net.Listener
    22  	reverseProxyConn Connection
    23  	chanNetConn      chan net.Conn
    24  }
    25  
    26  func makeStreamTransport(svc *service, lnr net.Listener) *streamTransport {
    27  	return &streamTransport{
    28  		svc:         svc,
    29  		closed:      make(chan struct{}),
    30  		lnr:         lnr,
    31  		chanNetConn: make(chan net.Conn, 8),
    32  	}
    33  }
    34  
    35  func (svc *service) newUDSTransport() (*streamTransport, error) {
    36  	addr := toUDSAddr(svc.publisherName, svc.serviceName)
    37  	lnr, err := net.Listen("unix", addr)
    38  	if err != nil {
    39  		return nil, err
    40  	}
    41  
    42  	st := makeStreamTransport(svc, lnr)
    43  	go st.receiver()
    44  	svc.s.lg.Infof("service %s %s listening on %s", svc.publisherName, svc.serviceName, addr)
    45  	return st, nil
    46  }
    47  
    48  func connectReverseProxy(svc *service) Connection {
    49  	c := NewClient(WithScope(ScopeLAN|ScopeWAN),
    50  		WithLogger(svc.s.lg),
    51  		WithRegistryAddr(svc.s.registryAddr),
    52  		WithProviderID(svc.s.providerID),
    53  	).SetDiscoverTimeout(3)
    54  	connChan := c.Discover(BuiltinPublisher, SrvReverseProxy, "*")
    55  	defer func() {
    56  		for conn := range connChan {
    57  			conn.Close()
    58  		}
    59  	}()
    60  	for conn := range connChan {
    61  		err := conn.SendRecv(&proxyRegServiceInWAN{svc.publisherName, svc.serviceName, svc.s.providerID}, nil)
    62  		if err == nil {
    63  			return conn
    64  		}
    65  		conn.Close()
    66  	}
    67  	return nil
    68  }
    69  
    70  func (svc *service) newTCPTransport(onPort string) (*streamTransport, error) {
    71  	if len(onPort) == 0 {
    72  		onPort = "0"
    73  	}
    74  	lnr, err := net.Listen("tcp", ":"+onPort)
    75  	if err != nil {
    76  		return nil, err
    77  	}
    78  
    79  	st := makeStreamTransport(svc, lnr)
    80  	go st.receiver()
    81  	addr := lnr.Addr().String()
    82  	_, port, _ := net.SplitHostPort(addr) // from [::]:43807
    83  	svc.s.lg.Infof("service %s %s listening on %s", svc.publisherName, svc.serviceName, addr)
    84  
    85  	if svc.scope&ScopeLAN == ScopeLAN {
    86  		if err := svc.regServiceLAN(port); err != nil {
    87  			svc.s.lg.Warnf("service %s %s register to LAN failed: %v", svc.publisherName, svc.serviceName, err)
    88  			st.close()
    89  			return nil, err
    90  		}
    91  		svc.s.lg.Infof("service %s %s registered to LAN", svc.publisherName, svc.serviceName)
    92  	}
    93  
    94  	if svc.scope&ScopeWAN == ScopeWAN {
    95  		if err := svc.regServiceWAN(port); err != nil {
    96  			svc.s.lg.Infof("service %s %s can not register to WAN directly: %v", svc.publisherName, svc.serviceName, err)
    97  			st.reverseProxyConn = connectReverseProxy(svc)
    98  			if st.reverseProxyConn == nil {
    99  				svc.s.lg.Warnf("service %s %s register to proxy failed", svc.publisherName, svc.serviceName)
   100  				st.close()
   101  				return nil, err
   102  			}
   103  			svc.s.lg.Infof("reverse proxy connected")
   104  			go st.reverseReceiver()
   105  		}
   106  		svc.s.lg.Infof("service %s %s registered to WAN", svc.publisherName, svc.serviceName)
   107  	}
   108  	return st, nil
   109  }
   110  
   111  func (st *streamTransport) close() {
   112  	closed := st.closed
   113  	if st.closed == nil {
   114  		return
   115  	}
   116  	st.closed = nil
   117  	close(closed)
   118  	svc := st.svc
   119  	svc.s.lg.Debugf("stream transport %s closing", st.lnr.Addr().String())
   120  	st.lnr.Close()
   121  	if st.reverseProxyConn != nil {
   122  		st.reverseProxyConn.Close()
   123  	}
   124  	if st.lnr.Addr().Network() == "unix" {
   125  		return
   126  	}
   127  	if svc.scope&ScopeLAN == ScopeLAN {
   128  		if err := svc.delServiceLAN(); err != nil {
   129  			svc.s.lg.Warnf("del service in lan failed: %v", err)
   130  		}
   131  	}
   132  	if svc.scope&ScopeWAN == ScopeWAN {
   133  		if err := svc.delServiceWAN(); err != nil {
   134  			svc.s.lg.Warnf("del service in wan failed: %v", err)
   135  		}
   136  	}
   137  }
   138  
   139  type streamTransportMsg struct {
   140  	chanID uint64 // client stream channel ID
   141  	msg    interface{}
   142  }
   143  
   144  type streamServerStream struct {
   145  	Context
   146  	mtx         *sync.Mutex
   147  	lg          Logger
   148  	netconn     net.Conn
   149  	connClose   *chan struct{}
   150  	privateChan chan interface{} // dedicated to the client
   151  	chanID      uint64           // client stream channel ID, taken from transport msg
   152  	enc         *gotiny.Encoder
   153  	encMainCopy int32
   154  	timeouter
   155  }
   156  
   157  func (ss *streamServerStream) GetNetconn() Netconn {
   158  	return ss.netconn
   159  }
   160  
   161  func (ss *streamServerStream) send(tm *streamTransportMsg) error {
   162  	buf := net.Buffers{}
   163  	mainCopy := false
   164  	if atomic.CompareAndSwapInt32(&ss.encMainCopy, 0, 1) {
   165  		ss.lg.Debugf("enc main copy")
   166  		mainCopy = true
   167  	}
   168  	enc := ss.enc
   169  	if !mainCopy {
   170  		enc = enc.Copy()
   171  	}
   172  	bufMsg := enc.Encode(tm)
   173  	bufSize := make([]byte, 4)
   174  	binary.BigEndian.PutUint32(bufSize, uint32(len(bufMsg)))
   175  	buf = append(buf, bufSize, bufMsg)
   176  	ss.lg.Debugf("stream server send: tm: %#v ==> size %d, buf %v <%s>", tm, len(bufMsg), bufMsg, bufMsg)
   177  	ss.mtx.Lock()
   178  	defer func() {
   179  		if mainCopy {
   180  			atomic.StoreInt32(&ss.encMainCopy, 0)
   181  		}
   182  		ss.mtx.Unlock()
   183  	}()
   184  	if _, err := buf.WriteTo(ss.netconn); err != nil {
   185  		return err
   186  	}
   187  	return nil
   188  }
   189  
   190  func (ss *streamServerStream) Send(msg interface{}) error {
   191  	if *ss.connClose == nil {
   192  		return io.EOF
   193  	}
   194  	tm := streamTransportMsg{chanID: ss.chanID, msg: msg}
   195  	return ss.send(&tm)
   196  }
   197  
   198  func (ss *streamServerStream) Recv(msgPtr interface{}) (err error) {
   199  	connClose := *ss.connClose
   200  	if *ss.connClose == nil {
   201  		return io.EOF
   202  	}
   203  	rptr := reflect.ValueOf(msgPtr)
   204  	if msgPtr != nil && (rptr.Kind() != reflect.Ptr || rptr.IsNil()) {
   205  		panic("not a pointer or nil pointer")
   206  	}
   207  
   208  	select {
   209  	case <-connClose:
   210  		return io.EOF
   211  	case <-ss.timeouter.timeoutChan():
   212  		return ErrRecvTimeout
   213  	case msg := <-ss.privateChan:
   214  		if err, ok := msg.(error); ok {
   215  			return err
   216  		}
   217  		if msgPtr == nil { // msgPtr is nil
   218  			return nil // user just looks at error, no error here
   219  		}
   220  
   221  		rv := rptr.Elem()
   222  		mrv := reflect.ValueOf(msg)
   223  		if rv.Kind() != reflect.Ptr && mrv.Kind() == reflect.Ptr {
   224  			mrv = mrv.Elem()
   225  		}
   226  		defer func() {
   227  			if e := recover(); e != nil {
   228  				err = fmt.Errorf("message type mismatch: %v", e)
   229  			}
   230  		}()
   231  		rv.Set(mrv)
   232  	}
   233  
   234  	return
   235  }
   236  
   237  func (ss *streamServerStream) SendRecv(msgSnd interface{}, msgRcvPtr interface{}) error {
   238  	if err := ss.Send(msgSnd); err != nil {
   239  		return err
   240  	}
   241  	if err := ss.Recv(msgRcvPtr); err != nil {
   242  		return err
   243  	}
   244  	return nil
   245  }
   246  
   247  func (st *streamTransport) reverseReceiver() {
   248  	lg := st.svc.s.lg
   249  	cmdConn := st.reverseProxyConn
   250  	defer cmdConn.Close()
   251  	conn := cmdConn.(*streamConnection)
   252  	host, _, _ := net.SplitHostPort(conn.netconn.RemoteAddr().String()) // from [::]:43807
   253  
   254  	for {
   255  		var port string
   256  		if err := cmdConn.Recv(&port); err != nil {
   257  			lg.Warnf("reverseReceiver: cmd connection broken: %v", err)
   258  			break
   259  		}
   260  		lg.Debugf("reverseReceiver: new reverse connection request")
   261  		addr := host + ":" + port
   262  		netconn, err := net.Dial("tcp", addr)
   263  		if err != nil {
   264  			lg.Warnf("reverseReceiver: reverse connection failed: %v", err)
   265  			break
   266  		}
   267  		st.chanNetConn <- netconn
   268  	}
   269  	lg.Infof("reverse proxy lost, reconnecting")
   270  	st.reverseProxyConn = connectReverseProxy(st.svc)
   271  	if st.reverseProxyConn == nil {
   272  		lg.Errorf("service %s %s lost connection to reverse proxy", st.svc.publisherName, st.svc.serviceName)
   273  		return
   274  	}
   275  	lg.Infof("reverse proxy reconnected")
   276  	go st.reverseReceiver()
   277  }
   278  
   279  func (st *streamTransport) receiver() {
   280  	lg := st.svc.s.lg
   281  	mq := st.svc.s.mq
   282  	lnr := st.lnr
   283  
   284  	go func() {
   285  		rootRegistryIP, _, _ := net.SplitHostPort(st.svc.s.registryAddr)
   286  		pinged := false
   287  		if lnr.Addr().Network() == "unix" {
   288  			pinged = true
   289  		}
   290  		if st.svc.publisherName == BuiltinPublisher && st.svc.serviceName == "rootRegistry" {
   291  			pinged = true
   292  		}
   293  		for {
   294  			netconn, err := lnr.Accept()
   295  			if err != nil {
   296  				lg.Warnf("stream transport listener: %v", err)
   297  				// the streamTransport has been closed
   298  				if st.closed == nil {
   299  					return
   300  				}
   301  				continue
   302  			}
   303  			if !pinged {
   304  				host, _, _ := net.SplitHostPort(netconn.RemoteAddr().String())
   305  				if host == rootRegistryIP {
   306  					netconn.Close()
   307  					pinged = true
   308  					lg.Debugf("ping from root registry")
   309  					continue
   310  				}
   311  			}
   312  			st.chanNetConn <- netconn
   313  		}
   314  	}()
   315  
   316  	handleConn := func(netconn net.Conn) {
   317  		lg.Debugf("%s %s new stream connection from: %s", st.svc.publisherName, st.svc.serviceName, netconn.RemoteAddr().String())
   318  		if st.svc.fnOnConnect != nil {
   319  			lg.Debugf("%s %s on connect", st.svc.publisherName, st.svc.serviceName)
   320  			if st.svc.fnOnConnect(netconn) {
   321  				return
   322  			}
   323  		}
   324  
   325  		connClose := make(chan struct{})
   326  		defer func() {
   327  			if st.svc.fnOnDisconnect != nil {
   328  				lg.Debugf("%s %s on disconnect", st.svc.publisherName, st.svc.serviceName)
   329  				st.svc.fnOnDisconnect(netconn)
   330  			}
   331  			lg.Debugf("%s %s stream connection disconnected: %s", st.svc.publisherName, st.svc.serviceName, netconn.RemoteAddr().String())
   332  			close(connClose)
   333  			connClose = nil
   334  			netconn.Close()
   335  		}()
   336  		var mtx sync.Mutex
   337  		ssMap := make(map[uint64]*streamServerStream)
   338  		dec := gotiny.NewDecoderWithPtr((*streamTransportMsg)(nil))
   339  		dec.SetCopyMode()
   340  		bufSize := make([]byte, 4)
   341  		bufMsg := make([]byte, 512)
   342  		for {
   343  			if st.closed == nil {
   344  				return
   345  			}
   346  			if _, err := io.ReadFull(netconn, bufSize); err != nil {
   347  				if errors.Is(err, io.EOF) || strings.Contains(err.Error(), "use of closed network connection") {
   348  					lg.Debugf("stream sever receiver: connection closed: %v", err)
   349  				} else {
   350  					lg.Warnf("stream sever receiver: from %s read size error: %v", netconn.RemoteAddr().String(), err)
   351  				}
   352  				return
   353  			}
   354  
   355  			size := binary.BigEndian.Uint32(bufSize)
   356  			bufCap := uint32(cap(bufMsg))
   357  			if size <= bufCap {
   358  				bufMsg = bufMsg[:size]
   359  			} else {
   360  				bufMsg = make([]byte, size)
   361  			}
   362  			if _, err := io.ReadFull(netconn, bufMsg); err != nil {
   363  				lg.Warnf("stream sever receiver: from %s read buf error: %v", netconn.RemoteAddr().String(), err)
   364  				return
   365  			}
   366  
   367  			var decErr error
   368  			var tm streamTransportMsg
   369  			func() {
   370  				defer func() {
   371  					if e := recover(); e != nil {
   372  						decErr = fmt.Errorf("unknown message: %v", e)
   373  						lg.Errorf("%v", decErr)
   374  						tm.msg = decErr
   375  					}
   376  				}()
   377  				dec.Decode(bufMsg, &tm)
   378  			}()
   379  			lg.Debugf("stream server receiver: tm: %#v", &tm)
   380  
   381  			ss := ssMap[tm.chanID]
   382  			if ss == nil {
   383  				ss = &streamServerStream{
   384  					Context:     &contextImpl{},
   385  					mtx:         &mtx,
   386  					lg:          lg,
   387  					netconn:     netconn,
   388  					connClose:   &connClose,
   389  					privateChan: make(chan interface{}, st.svc.s.qsize),
   390  					chanID:      tm.chanID,
   391  					enc:         gotiny.NewEncoderWithPtr((*streamTransportMsg)(nil)),
   392  				}
   393  				ssMap[tm.chanID] = ss
   394  				if st.svc.fnOnNewStream != nil {
   395  					lg.Debugf("%s %s on new stream", st.svc.publisherName, st.svc.serviceName)
   396  					st.svc.fnOnNewStream(ss)
   397  				}
   398  			}
   399  
   400  			if decErr != nil {
   401  				if err := ss.Send(decErr); err != nil {
   402  					lg.Errorf("send decode error failed: %v", err)
   403  				}
   404  				continue
   405  			}
   406  
   407  			if st.svc.canHandle(tm.msg) {
   408  				mm := &metaKnownMsg{
   409  					stream: ss,
   410  					msg:    tm.msg.(KnownMessage),
   411  				}
   412  				mq.putMetaMsg(mm)
   413  			} else {
   414  				ss.privateChan <- tm.msg
   415  			}
   416  		}
   417  	}
   418  
   419  	closed := st.closed
   420  	for {
   421  		select {
   422  		case <-closed:
   423  			return
   424  		case netconn := <-st.chanNetConn:
   425  			go handleConn(netconn)
   426  		}
   427  	}
   428  }
   429  
   430  // below for client side
   431  
   432  type streamClientStream struct {
   433  	conn        *streamConnection
   434  	msgChan     chan interface{}
   435  	encMainCopy int32
   436  	enc         *gotiny.Encoder
   437  	timeouter
   438  }
   439  
   440  // stream connection for client.
   441  type streamConnection struct {
   442  	Stream
   443  	sync.Mutex
   444  	owner   *Client
   445  	netconn net.Conn
   446  	closed  chan struct{}
   447  }
   448  
   449  func (c *Client) newStreamConnection(network string, addr string) (*streamConnection, error) {
   450  	proxied := false
   451  	if addr[len(addr)-1] == 'P' {
   452  		c.lg.Debugf("%s is proxied", addr)
   453  		addr = addr[:len(addr)-1]
   454  		proxied = true
   455  	}
   456  	netconn, err := net.Dial(network, addr)
   457  	if err != nil {
   458  		return nil, err
   459  	}
   460  	if proxied {
   461  		if _, err := netconn.Read([]byte{0}); err != nil {
   462  			return nil, err
   463  		}
   464  	}
   465  	c.lg.Debugf("stream connection established: %s -> %s", netconn.LocalAddr().String(), addr)
   466  
   467  	conn := &streamConnection{
   468  		owner:   c,
   469  		netconn: netconn,
   470  		closed:  make(chan struct{}),
   471  	}
   472  
   473  	conn.Stream = conn.NewStream()
   474  	go conn.receiver()
   475  	return conn, nil
   476  }
   477  
   478  func (c *Client) newUDSConnection(addr string) (*streamConnection, error) {
   479  	return c.newStreamConnection("unix", addr)
   480  }
   481  func (c *Client) newTCPConnection(addr string) (*streamConnection, error) {
   482  	return c.newStreamConnection("tcp", addr)
   483  }
   484  
   485  func (conn *streamConnection) receiver() {
   486  	defer func() {
   487  		close(conn.closed)
   488  		conn.closed = nil
   489  	}()
   490  	lg := conn.owner.lg
   491  	netconn := conn.netconn
   492  	dec := gotiny.NewDecoderWithPtr((*streamTransportMsg)(nil))
   493  	dec.SetCopyMode()
   494  	bufSize := make([]byte, 4)
   495  	bufMsg := make([]byte, 512)
   496  	for {
   497  		if _, err := io.ReadFull(netconn, bufSize); err != nil {
   498  			if errors.Is(err, io.EOF) || strings.Contains(err.Error(), "use of closed network connection") {
   499  				lg.Debugf("stream client receiver: connection closed: %v", err)
   500  			} else {
   501  				lg.Warnf("stream client receiver: read size error: %v", err)
   502  			}
   503  			return
   504  		}
   505  
   506  		size := binary.BigEndian.Uint32(bufSize)
   507  		bufCap := uint32(cap(bufMsg))
   508  		if size <= bufCap {
   509  			bufMsg = bufMsg[:size]
   510  		} else {
   511  			bufMsg = make([]byte, size)
   512  		}
   513  		if _, err := io.ReadFull(netconn, bufMsg); err != nil {
   514  			lg.Warnf("stream client receiver: read buf error: %v", err)
   515  			return
   516  		}
   517  
   518  		var tm streamTransportMsg
   519  		//escapes(&tm)
   520  		func() {
   521  			defer func() {
   522  				if e := recover(); e != nil {
   523  					err := fmt.Errorf("unknown message: %v", e)
   524  					lg.Errorf("%v", err)
   525  					tm.msg = err
   526  				}
   527  			}()
   528  			dec.Decode(bufMsg, &tm)
   529  		}()
   530  		lg.Debugf("stream client receiver: tm: %#v", &tm)
   531  
   532  		if tm.chanID != 0 {
   533  			func() {
   534  				defer func() {
   535  					if err := recover(); err != nil {
   536  						lg.Errorf("broken stream chan: %v", err)
   537  					}
   538  				}()
   539  				msgChan := *(*chan interface{})(unsafe.Pointer(uintptr(tm.chanID)))
   540  				msgChan <- tm.msg
   541  			}()
   542  		} else {
   543  			panic("msg channel not specified")
   544  		}
   545  	}
   546  }
   547  
   548  func (conn *streamConnection) NewStream() Stream {
   549  	return &streamClientStream{
   550  		conn:    conn,
   551  		msgChan: make(chan interface{}, conn.owner.qsize),
   552  		enc:     gotiny.NewEncoderWithPtr((*streamTransportMsg)(nil)),
   553  	}
   554  }
   555  func (conn *streamConnection) Close() {
   556  	conn.netconn.Close()
   557  }
   558  
   559  func (cs *streamClientStream) GetNetconn() Netconn {
   560  	return cs.conn.netconn
   561  }
   562  
   563  func (cs *streamClientStream) Send(msg interface{}) error {
   564  	if cs.msgChan == nil || cs.conn.closed == nil {
   565  		return io.EOF
   566  	}
   567  
   568  	cid := uint64(uintptr(unsafe.Pointer(&cs.msgChan)))
   569  	tm := streamTransportMsg{chanID: cid, msg: msg}
   570  
   571  	lg := cs.conn.owner.lg
   572  	buf := net.Buffers{}
   573  	mainCopy := false
   574  	if atomic.CompareAndSwapInt32(&cs.encMainCopy, 0, 1) {
   575  		lg.Debugf("enc main copy")
   576  		mainCopy = true
   577  	}
   578  	enc := cs.enc
   579  	if !mainCopy {
   580  		enc = enc.Copy()
   581  	}
   582  	bufMsg := enc.Encode(&tm)
   583  	bufSize := make([]byte, 4)
   584  	binary.BigEndian.PutUint32(bufSize, uint32(len(bufMsg)))
   585  	buf = append(buf, bufSize, bufMsg)
   586  	lg.Debugf("stream client send: tm: %#v ==> size %d, buf %v <%s>", &tm, len(bufMsg), bufMsg, bufMsg)
   587  	cs.conn.Lock()
   588  	defer func() {
   589  		if mainCopy {
   590  			atomic.StoreInt32(&cs.encMainCopy, 0)
   591  		}
   592  		cs.conn.Unlock()
   593  	}()
   594  	if _, err := buf.WriteTo(cs.conn.netconn); err != nil {
   595  		return err
   596  	}
   597  	return nil
   598  }
   599  
   600  func (cs *streamClientStream) Recv(msgPtr interface{}) (err error) {
   601  	connClosed := cs.conn.closed
   602  	if cs.msgChan == nil || cs.conn.closed == nil {
   603  		return io.EOF
   604  	}
   605  
   606  	rptr := reflect.ValueOf(msgPtr)
   607  	if msgPtr != nil && (rptr.Kind() != reflect.Ptr || rptr.IsNil()) {
   608  		panic("not a pointer or nil pointer")
   609  	}
   610  
   611  	select {
   612  	case <-connClosed:
   613  		return ErrConnReset
   614  	case <-cs.timeouter.timeoutChan():
   615  		return ErrRecvTimeout
   616  	case msg := <-cs.msgChan:
   617  		if err, ok := msg.(error); ok { // message handler returned error
   618  			if err == io.EOF {
   619  				cs.msgChan = nil
   620  			}
   621  			return err
   622  		}
   623  
   624  		if msgPtr == nil { // msgPtr is nil
   625  			return nil // user just looks at error, no error here
   626  		}
   627  
   628  		rv := rptr.Elem()
   629  		mrv := reflect.ValueOf(msg)
   630  		if rv.Kind() != reflect.Ptr && mrv.Kind() == reflect.Ptr {
   631  			mrv = mrv.Elem()
   632  		}
   633  		defer func() {
   634  			if e := recover(); e != nil {
   635  				err = fmt.Errorf("message type mismatch: %v", e)
   636  			}
   637  		}()
   638  		rv.Set(mrv)
   639  	}
   640  
   641  	return
   642  }
   643  
   644  func (cs *streamClientStream) SendRecv(msgSnd interface{}, msgRcvPtr interface{}) error {
   645  	if err := cs.Send(msgSnd); err != nil {
   646  		return err
   647  	}
   648  	if err := cs.Recv(msgRcvPtr); err != nil {
   649  		return err
   650  	}
   651  	return nil
   652  }