github.com/tickoalcantara12/micro/v3@v3.0.0-20221007104245-9d75b9bcbab9/service/network/tunnel/mucp/session.go (about)

     1  // Licensed under the Apache License, Version 2.0 (the "License");
     2  // you may not use this file except in compliance with the License.
     3  // You may obtain a copy of the License at
     4  //
     5  //     https://www.apache.org/licenses/LICENSE-2.0
     6  //
     7  // Unless required by applicable law or agreed to in writing, software
     8  // distributed under the License is distributed on an "AS IS" BASIS,
     9  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    10  // See the License for the specific language governing permissions and
    11  // limitations under the License.
    12  //
    13  // Original source: github.com/micro/go-micro/v3/network/tunnel/mucp/session.go
    14  
    15  package mucp
    16  
    17  import (
    18  	"crypto/cipher"
    19  	"encoding/base32"
    20  	"io"
    21  	"sync"
    22  	"time"
    23  
    24  	"github.com/tickoalcantara12/micro/v3/service/logger"
    25  	"github.com/tickoalcantara12/micro/v3/service/network/transport"
    26  	"github.com/tickoalcantara12/micro/v3/service/network/tunnel"
    27  )
    28  
    29  // session is our pseudo session for transport.Socket
    30  type session struct {
    31  	// the tunnel id
    32  	tunnel string
    33  	// the channel name
    34  	channel string
    35  	// the session id based on Micro.Tunnel-Session
    36  	session string
    37  	// token is the session token
    38  	token string
    39  	// closed
    40  	closed chan bool
    41  	// remote addr
    42  	remote string
    43  	// local addr
    44  	local string
    45  	// send chan
    46  	send chan *message
    47  	// recv chan
    48  	recv chan *message
    49  	// if the discovery worked
    50  	discovered bool
    51  	// if the session was accepted
    52  	accepted bool
    53  	// outbound marks the session as outbound dialled connection
    54  	outbound bool
    55  	// lookback marks the session as a loopback on the inbound
    56  	loopback bool
    57  	// mode of the connection
    58  	mode tunnel.Mode
    59  	// the dial timeout
    60  	dialTimeout time.Duration
    61  	// the read timeout
    62  	readTimeout time.Duration
    63  	// the link on which this message was received
    64  	link string
    65  	// the error response
    66  	errChan chan error
    67  	// key for session encryption
    68  	key []byte
    69  	// cipher for session
    70  	gcm cipher.AEAD
    71  	sync.RWMutex
    72  }
    73  
    74  // message is sent over the send channel
    75  type message struct {
    76  	// type of message
    77  	typ string
    78  	// tunnel id
    79  	tunnel string
    80  	// channel name
    81  	channel string
    82  	// the session id
    83  	session string
    84  	// outbound marks the message as outbound
    85  	outbound bool
    86  	// loopback marks the message intended for loopback
    87  	loopback bool
    88  	// mode of the connection
    89  	mode tunnel.Mode
    90  	// the link to send the message on
    91  	link string
    92  	// transport data
    93  	data *transport.Message
    94  	// the error channel
    95  	errChan chan error
    96  }
    97  
    98  func (s *session) Remote() string {
    99  	return s.remote
   100  }
   101  
   102  func (s *session) Local() string {
   103  	return s.local
   104  }
   105  
   106  func (s *session) Link() string {
   107  	return s.link
   108  }
   109  
   110  func (s *session) Id() string {
   111  	return s.session
   112  }
   113  
   114  func (s *session) Channel() string {
   115  	return s.channel
   116  }
   117  
   118  // newMessage creates a new message based on the session
   119  func (s *session) newMessage(typ string) *message {
   120  	return &message{
   121  		typ:      typ,
   122  		tunnel:   s.tunnel,
   123  		channel:  s.channel,
   124  		session:  s.session,
   125  		outbound: s.outbound,
   126  		loopback: s.loopback,
   127  		mode:     s.mode,
   128  		link:     s.link,
   129  		errChan:  s.errChan,
   130  	}
   131  }
   132  
   133  func (s *session) sendMsg(msg *message) error {
   134  	select {
   135  	case <-s.closed:
   136  		return io.EOF
   137  	case s.send <- msg:
   138  		return nil
   139  	}
   140  }
   141  
   142  func (s *session) wait(msg *message) error {
   143  	// wait for an error response
   144  	select {
   145  	case err := <-msg.errChan:
   146  		if err != nil {
   147  			return err
   148  		}
   149  	case <-s.closed:
   150  		return io.EOF
   151  	}
   152  
   153  	return nil
   154  }
   155  
   156  // waitFor waits for the message type required until the timeout specified
   157  func (s *session) waitFor(msgType string, timeout time.Duration) (*message, error) {
   158  	now := time.Now()
   159  
   160  	after := func(timeout time.Duration) <-chan time.Time {
   161  		if timeout < time.Duration(0) {
   162  			return nil
   163  		}
   164  
   165  		// get the delta
   166  		d := time.Since(now)
   167  
   168  		// dial timeout minus time since
   169  		wait := timeout - d
   170  
   171  		if wait < time.Duration(0) {
   172  			wait = time.Duration(0)
   173  		}
   174  
   175  		return time.After(wait)
   176  	}
   177  
   178  	// wait for the message type
   179  	for {
   180  		select {
   181  		case msg := <-s.recv:
   182  			// there may be no message type
   183  			if len(msgType) == 0 {
   184  				return msg, nil
   185  			}
   186  
   187  			// ignore what we don't want
   188  			if msg.typ != msgType {
   189  				if logger.V(logger.DebugLevel, log) {
   190  					log.Debugf("Tunnel received non %s message in waiting for %s", msg.typ, msgType)
   191  				}
   192  				continue
   193  			}
   194  
   195  			// got the message
   196  			return msg, nil
   197  		case <-after(timeout):
   198  			return nil, tunnel.ErrReadTimeout
   199  		case <-s.closed:
   200  			// check pending message queue
   201  			select {
   202  			case msg := <-s.recv:
   203  				// there may be no message type
   204  				if len(msgType) == 0 {
   205  					return msg, nil
   206  				}
   207  
   208  				// ignore what we don't want
   209  				if msg.typ != msgType {
   210  					if logger.V(logger.DebugLevel, log) {
   211  						log.Debugf("Tunnel received non %s message in waiting for %s", msg.typ, msgType)
   212  					}
   213  					continue
   214  				}
   215  
   216  				// got the message
   217  				return msg, nil
   218  			default:
   219  				// non blocking
   220  			}
   221  			return nil, io.EOF
   222  		}
   223  	}
   224  }
   225  
   226  // Discover attempts to discover the link for a specific channel.
   227  // This is only used by the tunnel.Dial when first connecting.
   228  func (s *session) Discover() error {
   229  	// create a new discovery message for this channel
   230  	msg := s.newMessage("discover")
   231  	// broadcast the message to all links
   232  	msg.mode = tunnel.Broadcast
   233  	// its an outbound connection since we're dialling
   234  	msg.outbound = true
   235  	// don't set the link since we don't know where it is
   236  	msg.link = ""
   237  
   238  	// if multicast then set that as session
   239  	if s.mode == tunnel.Multicast {
   240  		msg.session = "multicast"
   241  	}
   242  
   243  	// send discover message
   244  	if err := s.sendMsg(msg); err != nil {
   245  		return err
   246  	}
   247  
   248  	// set time now
   249  	now := time.Now()
   250  
   251  	// after strips down the dial timeout
   252  	after := func() time.Duration {
   253  		d := time.Since(now)
   254  		// dial timeout minus time since
   255  		wait := s.dialTimeout - d
   256  		// make sure its always > 0
   257  		if wait < time.Duration(0) {
   258  			return time.Duration(0)
   259  		}
   260  		return wait
   261  	}
   262  
   263  	// the discover message is sent out, now
   264  	// wait to hear back about the sent message
   265  	select {
   266  	case <-time.After(after()):
   267  		return tunnel.ErrDialTimeout
   268  	case err := <-s.errChan:
   269  		if err != nil {
   270  			return err
   271  		}
   272  	}
   273  
   274  	// bail early if its not unicast
   275  	// we don't need to wait for the announce
   276  	if s.mode != tunnel.Unicast {
   277  		s.discovered = true
   278  		s.accepted = true
   279  		return nil
   280  	}
   281  
   282  	// wait for announce
   283  	_, err := s.waitFor("announce", after())
   284  	if err != nil {
   285  		return err
   286  	}
   287  
   288  	// set discovered
   289  	s.discovered = true
   290  
   291  	return nil
   292  }
   293  
   294  // Open will fire the open message for the session. This is called by the dialler.
   295  // This is to indicate that we want to create a new session.
   296  func (s *session) Open() error {
   297  	// create a new message
   298  	msg := s.newMessage("open")
   299  
   300  	// send open message
   301  	if err := s.sendMsg(msg); err != nil {
   302  		return err
   303  	}
   304  
   305  	// wait for an error response for send
   306  	if err := s.wait(msg); err != nil {
   307  		return err
   308  	}
   309  
   310  	// now wait for the accept message to be returned
   311  	msg, err := s.waitFor("accept", s.dialTimeout)
   312  	if err != nil {
   313  		return err
   314  	}
   315  
   316  	// set to accepted
   317  	s.accepted = true
   318  	// set link
   319  	s.link = msg.link
   320  
   321  	return nil
   322  }
   323  
   324  // Accept sends the accept response to an open message from a dialled connection
   325  func (s *session) Accept() error {
   326  	msg := s.newMessage("accept")
   327  
   328  	// send the accept message
   329  	if err := s.sendMsg(msg); err != nil {
   330  		return err
   331  	}
   332  
   333  	// wait for send response
   334  	return s.wait(msg)
   335  }
   336  
   337  // Announce sends an announcement to notify that this session exists.
   338  // This is primarily used by the listener.
   339  func (s *session) Announce() error {
   340  	msg := s.newMessage("announce")
   341  	// we don't need an error back
   342  	msg.errChan = nil
   343  	// announce to all
   344  	msg.mode = tunnel.Broadcast
   345  	// we don't need the link
   346  	msg.link = ""
   347  
   348  	// send announce message
   349  	return s.sendMsg(msg)
   350  }
   351  
   352  // Send is used to send a message
   353  func (s *session) Send(m *transport.Message) error {
   354  	var err error
   355  
   356  	s.RLock()
   357  	gcm := s.gcm
   358  	s.RUnlock()
   359  
   360  	if gcm == nil {
   361  		gcm, err = newCipher(s.key)
   362  		if err != nil {
   363  			return err
   364  		}
   365  		s.Lock()
   366  		s.gcm = gcm
   367  		s.Unlock()
   368  	}
   369  	// encrypt the transport message payload
   370  	body, err := Encrypt(gcm, m.Body)
   371  	if err != nil {
   372  		log.Debugf("failed to encrypt message body: %v", err)
   373  		return err
   374  	}
   375  
   376  	// make copy, without rehash and realloc
   377  	data := &transport.Message{
   378  		Header: make(map[string]string, len(m.Header)),
   379  		Body:   body,
   380  	}
   381  
   382  	// encrypt all the headers
   383  	for k, v := range m.Header {
   384  		// encrypt the transport message payload
   385  		val, err := Encrypt(s.gcm, []byte(v))
   386  		if err != nil {
   387  			log.Debugf("failed to encrypt message header %s: %v", k, err)
   388  			return err
   389  		}
   390  		// add the encrypted header value
   391  		data.Header[k] = base32.StdEncoding.EncodeToString(val)
   392  	}
   393  
   394  	// create a new message
   395  	msg := s.newMessage("session")
   396  	// set the data
   397  	msg.data = data
   398  
   399  	// if multicast don't set the link
   400  	if s.mode != tunnel.Unicast {
   401  		msg.link = ""
   402  	}
   403  
   404  	if logger.V(logger.TraceLevel, log) {
   405  		log.Tracef("Appending to send backlog: %v", msg)
   406  	}
   407  	// send the actual message
   408  	if err := s.sendMsg(msg); err != nil {
   409  		return err
   410  	}
   411  
   412  	// wait for an error response
   413  	return s.wait(msg)
   414  }
   415  
   416  // Recv is used to receive a message
   417  func (s *session) Recv(m *transport.Message) error {
   418  	var msg *message
   419  
   420  	msg, err := s.waitFor("", s.readTimeout)
   421  	if err != nil {
   422  		return err
   423  	}
   424  
   425  	// check the error if one exists
   426  	select {
   427  	case err := <-msg.errChan:
   428  		return err
   429  	default:
   430  	}
   431  
   432  	if logger.V(logger.TraceLevel, log) {
   433  		log.Tracef("Received from recv backlog: %v", msg)
   434  	}
   435  
   436  	gcm, err := newCipher([]byte(s.token + s.channel + msg.session))
   437  	if err != nil {
   438  		if logger.V(logger.ErrorLevel, log) {
   439  			log.Errorf("unable to create cipher: %v", err)
   440  		}
   441  		return err
   442  	}
   443  
   444  	// decrypt the received payload using the token
   445  	// we have to used msg.session because multicast has a shared
   446  	// session id of "multicast" in this session struct on
   447  	// the listener side
   448  	msg.data.Body, err = Decrypt(gcm, msg.data.Body)
   449  	if err != nil {
   450  		if logger.V(logger.DebugLevel, log) {
   451  			log.Debugf("failed to decrypt message body: %v", err)
   452  		}
   453  		return err
   454  	}
   455  
   456  	// dencrypt all the headers
   457  	for k, v := range msg.data.Header {
   458  		// decode the header values
   459  		h, err := base32.StdEncoding.DecodeString(v)
   460  		if err != nil {
   461  			if logger.V(logger.DebugLevel, log) {
   462  				log.Debugf("failed to decode message header %s: %v", k, err)
   463  			}
   464  			return err
   465  		}
   466  
   467  		// dencrypt the transport message payload
   468  		val, err := Decrypt(gcm, h)
   469  		if err != nil {
   470  			if logger.V(logger.DebugLevel, log) {
   471  				log.Debugf("failed to decrypt message header %s: %v", k, err)
   472  			}
   473  			return err
   474  		}
   475  		// add decrypted header value
   476  		msg.data.Header[k] = string(val)
   477  	}
   478  
   479  	// set the link
   480  	// TODO: decruft, this is only for multicast
   481  	// since the session is now a single session
   482  	// likely provide as part of message.Link()
   483  	msg.data.Header["Micro-Link"] = msg.link
   484  
   485  	// set message
   486  	*m = *msg.data
   487  	// return nil
   488  	return nil
   489  }
   490  
   491  // Close closes the session by sending a close message
   492  func (s *session) Close() error {
   493  	select {
   494  	case <-s.closed:
   495  		// no op
   496  	default:
   497  		close(s.closed)
   498  
   499  		// don't send close on multicast or broadcast
   500  		if s.mode != tunnel.Unicast {
   501  			return nil
   502  		}
   503  
   504  		// append to backlog
   505  		msg := s.newMessage("close")
   506  		// no error response on close
   507  		msg.errChan = nil
   508  
   509  		// send the close message
   510  		select {
   511  		case s.send <- msg:
   512  		case <-time.After(time.Millisecond * 10):
   513  		}
   514  	}
   515  
   516  	return nil
   517  }