github.com/annwntech/go-micro/v2@v2.9.5/tunnel/session.go (about)

     1  package tunnel
     2  
     3  import (
     4  	"crypto/cipher"
     5  	"encoding/base32"
     6  	"io"
     7  	"sync"
     8  	"time"
     9  
    10  	"github.com/annwntech/go-micro/v2/logger"
    11  	"github.com/annwntech/go-micro/v2/transport"
    12  )
    13  
    14  // session is our pseudo session for transport.Socket
    15  type session struct {
    16  	// the tunnel id
    17  	tunnel string
    18  	// the channel name
    19  	channel string
    20  	// the session id based on Micro.Tunnel-Session
    21  	session string
    22  	// token is the session token
    23  	token string
    24  	// closed
    25  	closed chan bool
    26  	// remote addr
    27  	remote string
    28  	// local addr
    29  	local string
    30  	// send chan
    31  	send chan *message
    32  	// recv chan
    33  	recv chan *message
    34  	// if the discovery worked
    35  	discovered bool
    36  	// if the session was accepted
    37  	accepted bool
    38  	// outbound marks the session as outbound dialled connection
    39  	outbound bool
    40  	// lookback marks the session as a loopback on the inbound
    41  	loopback bool
    42  	// mode of the connection
    43  	mode Mode
    44  	// the dial timeout
    45  	dialTimeout time.Duration
    46  	// the read timeout
    47  	readTimeout time.Duration
    48  	// the link on which this message was received
    49  	link string
    50  	// the error response
    51  	errChan chan error
    52  	// key for session encryption
    53  	key []byte
    54  	// cipher for session
    55  	gcm cipher.AEAD
    56  	sync.RWMutex
    57  }
    58  
    59  // message is sent over the send channel
    60  type message struct {
    61  	// type of message
    62  	typ string
    63  	// tunnel id
    64  	tunnel string
    65  	// channel name
    66  	channel string
    67  	// the session id
    68  	session string
    69  	// outbound marks the message as outbound
    70  	outbound bool
    71  	// loopback marks the message intended for loopback
    72  	loopback bool
    73  	// mode of the connection
    74  	mode Mode
    75  	// the link to send the message on
    76  	link string
    77  	// transport data
    78  	data *transport.Message
    79  	// the error channel
    80  	errChan chan error
    81  }
    82  
    83  func (s *session) Remote() string {
    84  	return s.remote
    85  }
    86  
    87  func (s *session) Local() string {
    88  	return s.local
    89  }
    90  
    91  func (s *session) Link() string {
    92  	return s.link
    93  }
    94  
    95  func (s *session) Id() string {
    96  	return s.session
    97  }
    98  
    99  func (s *session) Channel() string {
   100  	return s.channel
   101  }
   102  
   103  // newMessage creates a new message based on the session
   104  func (s *session) newMessage(typ string) *message {
   105  	return &message{
   106  		typ:      typ,
   107  		tunnel:   s.tunnel,
   108  		channel:  s.channel,
   109  		session:  s.session,
   110  		outbound: s.outbound,
   111  		loopback: s.loopback,
   112  		mode:     s.mode,
   113  		link:     s.link,
   114  		errChan:  s.errChan,
   115  	}
   116  }
   117  
   118  func (s *session) sendMsg(msg *message) error {
   119  	select {
   120  	case <-s.closed:
   121  		return io.EOF
   122  	case s.send <- msg:
   123  		return nil
   124  	}
   125  }
   126  
   127  func (s *session) wait(msg *message) error {
   128  	// wait for an error response
   129  	select {
   130  	case err := <-msg.errChan:
   131  		if err != nil {
   132  			return err
   133  		}
   134  	case <-s.closed:
   135  		return io.EOF
   136  	}
   137  
   138  	return nil
   139  }
   140  
   141  // waitFor waits for the message type required until the timeout specified
   142  func (s *session) waitFor(msgType string, timeout time.Duration) (*message, error) {
   143  	now := time.Now()
   144  
   145  	after := func(timeout time.Duration) <-chan time.Time {
   146  		if timeout < time.Duration(0) {
   147  			return nil
   148  		}
   149  
   150  		// get the delta
   151  		d := time.Since(now)
   152  
   153  		// dial timeout minus time since
   154  		wait := timeout - d
   155  
   156  		if wait < time.Duration(0) {
   157  			wait = time.Duration(0)
   158  		}
   159  
   160  		return time.After(wait)
   161  	}
   162  
   163  	// wait for the message type
   164  	for {
   165  		select {
   166  		case msg := <-s.recv:
   167  			// there may be no message type
   168  			if len(msgType) == 0 {
   169  				return msg, nil
   170  			}
   171  
   172  			// ignore what we don't want
   173  			if msg.typ != msgType {
   174  				if logger.V(logger.DebugLevel, log) {
   175  					log.Debugf("Tunnel received non %s message in waiting for %s", msg.typ, msgType)
   176  				}
   177  				continue
   178  			}
   179  
   180  			// got the message
   181  			return msg, nil
   182  		case <-after(timeout):
   183  			return nil, ErrReadTimeout
   184  		case <-s.closed:
   185  			// check pending message queue
   186  			select {
   187  			case msg := <-s.recv:
   188  				// there may be no message type
   189  				if len(msgType) == 0 {
   190  					return msg, nil
   191  				}
   192  
   193  				// ignore what we don't want
   194  				if msg.typ != msgType {
   195  					if logger.V(logger.DebugLevel, log) {
   196  						log.Debugf("Tunnel received non %s message in waiting for %s", msg.typ, msgType)
   197  					}
   198  					continue
   199  				}
   200  
   201  				// got the message
   202  				return msg, nil
   203  			default:
   204  				// non blocking
   205  			}
   206  			return nil, io.EOF
   207  		}
   208  	}
   209  }
   210  
   211  // Discover attempts to discover the link for a specific channel.
   212  // This is only used by the tunnel.Dial when first connecting.
   213  func (s *session) Discover() error {
   214  	// create a new discovery message for this channel
   215  	msg := s.newMessage("discover")
   216  	// broadcast the message to all links
   217  	msg.mode = Broadcast
   218  	// its an outbound connection since we're dialling
   219  	msg.outbound = true
   220  	// don't set the link since we don't know where it is
   221  	msg.link = ""
   222  
   223  	// if multicast then set that as session
   224  	if s.mode == Multicast {
   225  		msg.session = "multicast"
   226  	}
   227  
   228  	// send discover message
   229  	if err := s.sendMsg(msg); err != nil {
   230  		return err
   231  	}
   232  
   233  	// set time now
   234  	now := time.Now()
   235  
   236  	// after strips down the dial timeout
   237  	after := func() time.Duration {
   238  		d := time.Since(now)
   239  		// dial timeout minus time since
   240  		wait := s.dialTimeout - d
   241  		// make sure its always > 0
   242  		if wait < time.Duration(0) {
   243  			return time.Duration(0)
   244  		}
   245  		return wait
   246  	}
   247  
   248  	// the discover message is sent out, now
   249  	// wait to hear back about the sent message
   250  	select {
   251  	case <-time.After(after()):
   252  		return ErrDialTimeout
   253  	case err := <-s.errChan:
   254  		if err != nil {
   255  			return err
   256  		}
   257  	}
   258  
   259  	// bail early if its not unicast
   260  	// we don't need to wait for the announce
   261  	if s.mode != Unicast {
   262  		s.discovered = true
   263  		s.accepted = true
   264  		return nil
   265  	}
   266  
   267  	// wait for announce
   268  	_, err := s.waitFor("announce", after())
   269  	if err != nil {
   270  		return err
   271  	}
   272  
   273  	// set discovered
   274  	s.discovered = true
   275  
   276  	return nil
   277  }
   278  
   279  // Open will fire the open message for the session. This is called by the dialler.
   280  // This is to indicate that we want to create a new session.
   281  func (s *session) Open() error {
   282  	// create a new message
   283  	msg := s.newMessage("open")
   284  
   285  	// send open message
   286  	if err := s.sendMsg(msg); err != nil {
   287  		return err
   288  	}
   289  
   290  	// wait for an error response for send
   291  	if err := s.wait(msg); err != nil {
   292  		return err
   293  	}
   294  
   295  	// now wait for the accept message to be returned
   296  	msg, err := s.waitFor("accept", s.dialTimeout)
   297  	if err != nil {
   298  		return err
   299  	}
   300  
   301  	// set to accepted
   302  	s.accepted = true
   303  	// set link
   304  	s.link = msg.link
   305  
   306  	return nil
   307  }
   308  
   309  // Accept sends the accept response to an open message from a dialled connection
   310  func (s *session) Accept() error {
   311  	msg := s.newMessage("accept")
   312  
   313  	// send the accept message
   314  	if err := s.sendMsg(msg); err != nil {
   315  		return err
   316  	}
   317  
   318  	// wait for send response
   319  	return s.wait(msg)
   320  }
   321  
   322  // Announce sends an announcement to notify that this session exists.
   323  // This is primarily used by the listener.
   324  func (s *session) Announce() error {
   325  	msg := s.newMessage("announce")
   326  	// we don't need an error back
   327  	msg.errChan = nil
   328  	// announce to all
   329  	msg.mode = Broadcast
   330  	// we don't need the link
   331  	msg.link = ""
   332  
   333  	// send announce message
   334  	return s.sendMsg(msg)
   335  }
   336  
   337  // Send is used to send a message
   338  func (s *session) Send(m *transport.Message) error {
   339  	var err error
   340  
   341  	s.RLock()
   342  	gcm := s.gcm
   343  	s.RUnlock()
   344  
   345  	if gcm == nil {
   346  		gcm, err = newCipher(s.key)
   347  		if err != nil {
   348  			return err
   349  		}
   350  		s.Lock()
   351  		s.gcm = gcm
   352  		s.Unlock()
   353  	}
   354  	// encrypt the transport message payload
   355  	body, err := Encrypt(gcm, m.Body)
   356  	if err != nil {
   357  		log.Debugf("failed to encrypt message body: %v", err)
   358  		return err
   359  	}
   360  
   361  	// make copy, without rehash and realloc
   362  	data := &transport.Message{
   363  		Header: make(map[string]string, len(m.Header)),
   364  		Body:   body,
   365  	}
   366  
   367  	// encrypt all the headers
   368  	for k, v := range m.Header {
   369  		// encrypt the transport message payload
   370  		val, err := Encrypt(s.gcm, []byte(v))
   371  		if err != nil {
   372  			log.Debugf("failed to encrypt message header %s: %v", k, err)
   373  			return err
   374  		}
   375  		// add the encrypted header value
   376  		data.Header[k] = base32.StdEncoding.EncodeToString(val)
   377  	}
   378  
   379  	// create a new message
   380  	msg := s.newMessage("session")
   381  	// set the data
   382  	msg.data = data
   383  
   384  	// if multicast don't set the link
   385  	if s.mode != Unicast {
   386  		msg.link = ""
   387  	}
   388  
   389  	if logger.V(logger.TraceLevel, log) {
   390  		log.Tracef("Appending to send backlog: %v", msg)
   391  	}
   392  	// send the actual message
   393  	if err := s.sendMsg(msg); err != nil {
   394  		return err
   395  	}
   396  
   397  	// wait for an error response
   398  	return s.wait(msg)
   399  }
   400  
   401  // Recv is used to receive a message
   402  func (s *session) Recv(m *transport.Message) error {
   403  	var msg *message
   404  
   405  	msg, err := s.waitFor("", s.readTimeout)
   406  	if err != nil {
   407  		return err
   408  	}
   409  
   410  	// check the error if one exists
   411  	select {
   412  	case err := <-msg.errChan:
   413  		return err
   414  	default:
   415  	}
   416  
   417  	if logger.V(logger.TraceLevel, log) {
   418  		log.Tracef("Received from recv backlog: %v", msg)
   419  	}
   420  
   421  	gcm, err := newCipher([]byte(s.token + s.channel + msg.session))
   422  	if err != nil {
   423  		if logger.V(logger.ErrorLevel, log) {
   424  			log.Errorf("unable to create cipher: %v", err)
   425  		}
   426  		return err
   427  	}
   428  
   429  	// decrypt the received payload using the token
   430  	// we have to used msg.session because multicast has a shared
   431  	// session id of "multicast" in this session struct on
   432  	// the listener side
   433  	msg.data.Body, err = Decrypt(gcm, msg.data.Body)
   434  	if err != nil {
   435  		if logger.V(logger.DebugLevel, log) {
   436  			log.Debugf("failed to decrypt message body: %v", err)
   437  		}
   438  		return err
   439  	}
   440  
   441  	// dencrypt all the headers
   442  	for k, v := range msg.data.Header {
   443  		// decode the header values
   444  		h, err := base32.StdEncoding.DecodeString(v)
   445  		if err != nil {
   446  			if logger.V(logger.DebugLevel, log) {
   447  				log.Debugf("failed to decode message header %s: %v", k, err)
   448  			}
   449  			return err
   450  		}
   451  
   452  		// dencrypt the transport message payload
   453  		val, err := Decrypt(gcm, h)
   454  		if err != nil {
   455  			if logger.V(logger.DebugLevel, log) {
   456  				log.Debugf("failed to decrypt message header %s: %v", k, err)
   457  			}
   458  			return err
   459  		}
   460  		// add decrypted header value
   461  		msg.data.Header[k] = string(val)
   462  	}
   463  
   464  	// set the link
   465  	// TODO: decruft, this is only for multicast
   466  	// since the session is now a single session
   467  	// likely provide as part of message.Link()
   468  	msg.data.Header["Micro-Link"] = msg.link
   469  
   470  	// set message
   471  	*m = *msg.data
   472  	// return nil
   473  	return nil
   474  }
   475  
   476  // Close closes the session by sending a close message
   477  func (s *session) Close() error {
   478  	select {
   479  	case <-s.closed:
   480  		// no op
   481  	default:
   482  		close(s.closed)
   483  
   484  		// don't send close on multicast or broadcast
   485  		if s.mode != Unicast {
   486  			return nil
   487  		}
   488  
   489  		// append to backlog
   490  		msg := s.newMessage("close")
   491  		// no error response on close
   492  		msg.errChan = nil
   493  
   494  		// send the close message
   495  		select {
   496  		case s.send <- msg:
   497  		case <-time.After(time.Millisecond * 10):
   498  		}
   499  	}
   500  
   501  	return nil
   502  }