gitee.com/liuxuezhan/go-micro-v1.18.0@v1.0.0/tunnel/session.go (about)

     1  package tunnel
     2  
     3  import (
     4  	"encoding/hex"
     5  	"io"
     6  	"time"
     7  
     8  	"gitee.com/liuxuezhan/go-micro-v1.18.0/transport"
     9  	"gitee.com/liuxuezhan/go-micro-v1.18.0/util/log"
    10  )
    11  
// session is our pseudo session for transport.Socket.
// Outgoing messages are handed to the send channel by sendMsg and incoming
// messages are read from the recv channel by waitFor.
type session struct {
	// the tunnel id
	tunnel string
	// the channel name
	channel string
	// the session id based on Micro.Tunnel-Session
	session string
	// token is the session token; Send and Recv combine it with the channel
	// name and a session id to form the encryption key
	token string
	// closed is closed by Close to signal that the session has shut down;
	// sendMsg, wait and waitFor return io.EOF once it is closed
	closed chan bool
	// remote addr
	remote string
	// local addr
	local string
	// send chan: sendMsg hands outgoing messages to this channel
	send chan *message
	// recv chan: waitFor reads incoming messages from this channel
	recv chan *message
	// if the discovery worked
	discovered bool
	// if the session was accepted
	accepted bool
	// outbound marks the session as outbound dialled connection
	outbound bool
	// loopback marks the session as a loopback on the inbound
	loopback bool
	// mode of the connection
	mode Mode
	// the dial timeout; bounds Discover and the accept wait in Open
	dialTimeout time.Duration
	// the read timeout used by Recv; a negative value means wait indefinitely
	readTimeout time.Duration
	// the link this session is bound to; Open sets it from the accept reply
	link string
	// errChan is copied onto every message built by newMessage; the result of
	// sending such a message is read back from it (see wait)
	errChan chan error
}
    51  
// message is the unit passed over a session's send and recv channels.
type message struct {
	// type of message; this file uses "discover", "open", "accept",
	// "announce", "session" and "close"
	typ string
	// tunnel id
	tunnel string
	// channel name
	channel string
	// the session id ("multicast" on a discover sent by a multicast session)
	session string
	// outbound marks the message as outbound
	outbound bool
	// loopback marks the message intended for loopback
	loopback bool
	// mode of the connection
	mode Mode
	// the link to send the message on; empty when no specific link is set
	link string
	// transport data; Send populates it on "session" messages
	data *transport.Message
	// the error channel on which the send result is reported;
	// nil when no result is wanted (e.g. announce and close)
	errChan chan error
}
    75  
    76  func (s *session) Remote() string {
    77  	return s.remote
    78  }
    79  
    80  func (s *session) Local() string {
    81  	return s.local
    82  }
    83  
    84  func (s *session) Link() string {
    85  	return s.link
    86  }
    87  
    88  func (s *session) Id() string {
    89  	return s.session
    90  }
    91  
    92  func (s *session) Channel() string {
    93  	return s.channel
    94  }
    95  
    96  // newMessage creates a new message based on the session
    97  func (s *session) newMessage(typ string) *message {
    98  	return &message{
    99  		typ:      typ,
   100  		tunnel:   s.tunnel,
   101  		channel:  s.channel,
   102  		session:  s.session,
   103  		outbound: s.outbound,
   104  		loopback: s.loopback,
   105  		mode:     s.mode,
   106  		link:     s.link,
   107  		errChan:  s.errChan,
   108  	}
   109  }
   110  
   111  func (s *session) sendMsg(msg *message) error {
   112  	select {
   113  	case <-s.closed:
   114  		return io.EOF
   115  	case s.send <- msg:
   116  		return nil
   117  	}
   118  }
   119  
   120  func (s *session) wait(msg *message) error {
   121  	// wait for an error response
   122  	select {
   123  	case err := <-msg.errChan:
   124  		if err != nil {
   125  			return err
   126  		}
   127  	case <-s.closed:
   128  		return io.EOF
   129  	}
   130  
   131  	return nil
   132  }
   133  
   134  // waitFor waits for the message type required until the timeout specified
   135  func (s *session) waitFor(msgType string, timeout time.Duration) (*message, error) {
   136  	now := time.Now()
   137  
   138  	after := func(timeout time.Duration) <-chan time.Time {
   139  		if timeout < time.Duration(0) {
   140  			return nil
   141  		}
   142  
   143  		// get the delta
   144  		d := time.Since(now)
   145  
   146  		// dial timeout minus time since
   147  		wait := timeout - d
   148  
   149  		if wait < time.Duration(0) {
   150  			wait = time.Duration(0)
   151  		}
   152  
   153  		return time.After(wait)
   154  	}
   155  
   156  	// wait for the message type
   157  	for {
   158  		select {
   159  		case msg := <-s.recv:
   160  			// there may be no message type
   161  			if len(msgType) == 0 {
   162  				return msg, nil
   163  			}
   164  
   165  			// ignore what we don't want
   166  			if msg.typ != msgType {
   167  				log.Debugf("Tunnel received non %s message in waiting for %s", msg.typ, msgType)
   168  				continue
   169  			}
   170  
   171  			// got the message
   172  			return msg, nil
   173  		case <-after(timeout):
   174  			return nil, ErrReadTimeout
   175  		case <-s.closed:
   176  			return nil, io.EOF
   177  		}
   178  	}
   179  }
   180  
   181  // Discover attempts to discover the link for a specific channel.
   182  // This is only used by the tunnel.Dial when first connecting.
   183  func (s *session) Discover() error {
   184  	// create a new discovery message for this channel
   185  	msg := s.newMessage("discover")
   186  	// broadcast the message to all links
   187  	msg.mode = Broadcast
   188  	// its an outbound connection since we're dialling
   189  	msg.outbound = true
   190  	// don't set the link since we don't know where it is
   191  	msg.link = ""
   192  
   193  	// if multicast then set that as session
   194  	if s.mode == Multicast {
   195  		msg.session = "multicast"
   196  	}
   197  
   198  	// send discover message
   199  	if err := s.sendMsg(msg); err != nil {
   200  		return err
   201  	}
   202  
   203  	// set time now
   204  	now := time.Now()
   205  
   206  	// after strips down the dial timeout
   207  	after := func() time.Duration {
   208  		d := time.Since(now)
   209  		// dial timeout minus time since
   210  		wait := s.dialTimeout - d
   211  		// make sure its always > 0
   212  		if wait < time.Duration(0) {
   213  			return time.Duration(0)
   214  		}
   215  		return wait
   216  	}
   217  
   218  	// the discover message is sent out, now
   219  	// wait to hear back about the sent message
   220  	select {
   221  	case <-time.After(after()):
   222  		return ErrDialTimeout
   223  	case err := <-s.errChan:
   224  		if err != nil {
   225  			return err
   226  		}
   227  	}
   228  
   229  	// bail early if its not unicast
   230  	// we don't need to wait for the announce
   231  	if s.mode != Unicast {
   232  		s.discovered = true
   233  		s.accepted = true
   234  		return nil
   235  	}
   236  
   237  	// wait for announce
   238  	_, err := s.waitFor("announce", after())
   239  	if err != nil {
   240  		return err
   241  	}
   242  
   243  	// set discovered
   244  	s.discovered = true
   245  
   246  	return nil
   247  }
   248  
   249  // Open will fire the open message for the session. This is called by the dialler.
   250  // This is to indicate that we want to create a new session.
   251  func (s *session) Open() error {
   252  	// create a new message
   253  	msg := s.newMessage("open")
   254  
   255  	// send open message
   256  	if err := s.sendMsg(msg); err != nil {
   257  		return err
   258  	}
   259  
   260  	// wait for an error response for send
   261  	if err := s.wait(msg); err != nil {
   262  		return err
   263  	}
   264  
   265  	// now wait for the accept message to be returned
   266  	msg, err := s.waitFor("accept", s.dialTimeout)
   267  	if err != nil {
   268  		return err
   269  	}
   270  
   271  	// set to accepted
   272  	s.accepted = true
   273  	// set link
   274  	s.link = msg.link
   275  
   276  	return nil
   277  }
   278  
   279  // Accept sends the accept response to an open message from a dialled connection
   280  func (s *session) Accept() error {
   281  	msg := s.newMessage("accept")
   282  
   283  	// send the accept message
   284  	if err := s.sendMsg(msg); err != nil {
   285  		return err
   286  	}
   287  
   288  	// wait for send response
   289  	return s.wait(msg)
   290  }
   291  
   292  // Announce sends an announcement to notify that this session exists.
   293  // This is primarily used by the listener.
   294  func (s *session) Announce() error {
   295  	msg := s.newMessage("announce")
   296  	// we don't need an error back
   297  	msg.errChan = nil
   298  	// announce to all
   299  	msg.mode = Broadcast
   300  	// we don't need the link
   301  	msg.link = ""
   302  
   303  	// send announce message
   304  	return s.sendMsg(msg)
   305  }
   306  
   307  // Send is used to send a message
   308  func (s *session) Send(m *transport.Message) error {
   309  	// encrypt the transport message payload
   310  	body, err := Encrypt(m.Body, s.token+s.channel+s.session)
   311  	if err != nil {
   312  		log.Debugf("failed to encrypt message body: %v", err)
   313  		return err
   314  	}
   315  
   316  	// make copy
   317  	data := &transport.Message{
   318  		Header: make(map[string]string),
   319  		Body:   body,
   320  	}
   321  
   322  	// encrypt all the headers
   323  	for k, v := range m.Header {
   324  		// encrypt the transport message payload
   325  		val, err := Encrypt([]byte(v), s.token+s.channel+s.session)
   326  		if err != nil {
   327  			log.Debugf("failed to encrypt message header %s: %v", k, err)
   328  			return err
   329  		}
   330  		// hex encode the encrypted header value
   331  		data.Header[k] = hex.EncodeToString(val)
   332  	}
   333  
   334  	// create a new message
   335  	msg := s.newMessage("session")
   336  	// set the data
   337  	msg.data = data
   338  
   339  	// if multicast don't set the link
   340  	if s.mode != Unicast {
   341  		msg.link = ""
   342  	}
   343  
   344  	log.Tracef("Appending %+v to send backlog", msg)
   345  
   346  	// send the actual message
   347  	if err := s.sendMsg(msg); err != nil {
   348  		return err
   349  	}
   350  
   351  	// wait for an error response
   352  	return s.wait(msg)
   353  }
   354  
   355  // Recv is used to receive a message
   356  func (s *session) Recv(m *transport.Message) error {
   357  	var msg *message
   358  
   359  	msg, err := s.waitFor("", s.readTimeout)
   360  	if err != nil {
   361  		return err
   362  	}
   363  
   364  	// check the error if one exists
   365  	select {
   366  	case err := <-msg.errChan:
   367  		return err
   368  	default:
   369  	}
   370  
   371  	//log.Tracef("Received %+v from recv backlog", msg)
   372  	log.Tracef("Received %+v from recv backlog", msg)
   373  
   374  	// decrypt the received payload using the token
   375  	// we have to used msg.session because multicast has a shared
   376  	// session id of "multicast" in this session struct on
   377  	// the listener side
   378  	body, err := Decrypt(msg.data.Body, s.token+s.channel+msg.session)
   379  	if err != nil {
   380  		log.Debugf("failed to decrypt message body: %v", err)
   381  		return err
   382  	}
   383  	msg.data.Body = body
   384  
   385  	// encrypt all the headers
   386  	for k, v := range msg.data.Header {
   387  		// hex decode the header values
   388  		h, err := hex.DecodeString(v)
   389  		if err != nil {
   390  			log.Debugf("failed to decode message header %s: %v", k, err)
   391  			return err
   392  		}
   393  		// encrypt the transport message payload
   394  		val, err := Decrypt([]byte(h), s.token+s.channel+msg.session)
   395  		if err != nil {
   396  			log.Debugf("failed to decrypt message header %s: %v", k, err)
   397  			return err
   398  		}
   399  		// hex encode the encrypted header value
   400  		msg.data.Header[k] = string(val)
   401  	}
   402  
   403  	// set the link
   404  	// TODO: decruft, this is only for multicast
   405  	// since the session is now a single session
   406  	// likely provide as part of message.Link()
   407  	msg.data.Header["Micro-Link"] = msg.link
   408  
   409  	// set message
   410  	*m = *msg.data
   411  	// return nil
   412  	return nil
   413  }
   414  
   415  // Close closes the session by sending a close message
   416  func (s *session) Close() error {
   417  	select {
   418  	case <-s.closed:
   419  		// no op
   420  	default:
   421  		close(s.closed)
   422  
   423  		// don't send close on multicast or broadcast
   424  		if s.mode != Unicast {
   425  			return nil
   426  		}
   427  
   428  		// append to backlog
   429  		msg := s.newMessage("close")
   430  		// no error response on close
   431  		msg.errChan = nil
   432  
   433  		// send the close message
   434  		select {
   435  		case s.send <- msg:
   436  		case <-time.After(time.Millisecond * 10):
   437  		}
   438  	}
   439  
   440  	return nil
   441  }