github.com/vmware/transport-go@v1.3.4/stompserver/stomp_connection.go (about)

     1  // Copyright 2019-2020 VMware, Inc.
     2  // SPDX-License-Identifier: BSD-2-Clause
     3  
     4  package stompserver
     5  
     6  import (
     7  	"fmt"
     8  	"github.com/go-stomp/stomp/v3"
     9  	"github.com/go-stomp/stomp/v3/frame"
    10  	"github.com/google/uuid"
    11  	"log"
    12  	"strconv"
    13  	"strings"
    14  	"sync"
    15  	"sync/atomic"
    16  	"time"
    17  )
    18  
    19  type subscription struct {
    20  	id          string
    21  	destination string
    22  }
    23  
    24  type StompConn interface {
    25  	// Return unique connection Id string
    26  	GetId() string
    27  	SendFrameToSubscription(f *frame.Frame, sub *subscription)
    28  	Close()
    29  }
    30  
    31  const (
    32  	maxHeartBeatDuration = time.Duration(999999999) * time.Millisecond
    33  )
    34  
    35  const (
    36  	connecting int32 = iota
    37  	connected
    38  	closed
    39  )
    40  
    41  type stompConn struct {
    42  	rawConnection    RawConnection
    43  	state            int32
    44  	version          stomp.Version
    45  	inFrames         chan *frame.Frame
    46  	outFrames        chan *frame.Frame
    47  	readTimeoutMs    int64
    48  	writeTimeout     time.Duration
    49  	id               string
    50  	events           chan *ConnEvent
    51  	config           StompConfig
    52  	subscriptions    map[string]*subscription
    53  	currentMessageId uint64
    54  	closeOnce        sync.Once
    55  }
    56  
    57  func NewStompConn(rawConnection RawConnection, config StompConfig, events chan *ConnEvent) StompConn {
    58  	conn := &stompConn{
    59  		rawConnection: rawConnection,
    60  		state:         connecting,
    61  		inFrames:      make(chan *frame.Frame, 32),
    62  		outFrames:     make(chan *frame.Frame, 32),
    63  		config:        config,
    64  		id:            uuid.New().String(),
    65  		events:        events,
    66  		subscriptions: make(map[string]*subscription),
    67  	}
    68  
    69  	go conn.run()
    70  	go conn.readInFrames()
    71  
    72  	return conn
    73  }
    74  
    75  func (conn *stompConn) SendFrameToSubscription(f *frame.Frame, sub *subscription) {
    76  	f.Header.Add(frame.Subscription, sub.id)
    77  	conn.outFrames <- f
    78  }
    79  
    80  func (conn *stompConn) Close() {
    81  	conn.closeOnce.Do(func() {
    82  		atomic.StoreInt32(&conn.state, closed)
    83  		conn.rawConnection.Close()
    84  
    85  		conn.events <- &ConnEvent{
    86  			ConnId:    conn.GetId(),
    87  			eventType: ConnectionClosed,
    88  			conn:      conn,
    89  		}
    90  	})
    91  }
    92  
    93  func (conn *stompConn) GetId() string {
    94  	return conn.id
    95  }
    96  
    97  func (conn *stompConn) run() {
    98  	defer conn.Close()
    99  
   100  	var timerChannel <-chan time.Time
   101  	var timer *time.Timer
   102  
   103  	for {
   104  
   105  		if atomic.LoadInt32(&conn.state) == closed {
   106  			return
   107  		}
   108  
   109  		if timer == nil && conn.writeTimeout > 0 {
   110  			timer = time.NewTimer(conn.writeTimeout)
   111  			timerChannel = timer.C
   112  		}
   113  
   114  		select {
   115  		case f, ok := <-conn.outFrames:
   116  			if !ok {
   117  				// close connection
   118  				return
   119  			}
   120  
   121  			// reset heart-beat timer
   122  			if timer != nil {
   123  				timer.Stop()
   124  				timer = nil
   125  			}
   126  
   127  			conn.populateMessageIdHeader(f)
   128  
   129  			// write the frame to the client
   130  			err := conn.rawConnection.WriteFrame(f)
   131  			if err != nil || f.Command == frame.ERROR {
   132  				return
   133  			}
   134  
   135  		case f, ok := <-conn.inFrames:
   136  			if !ok {
   137  				return
   138  			}
   139  
   140  			if err := conn.handleIncomingFrame(f); err != nil {
   141  				conn.sendError(err)
   142  				return
   143  			}
   144  
   145  		case _ = <-timerChannel:
   146  			// write a heart-beat
   147  			err := conn.rawConnection.WriteFrame(nil)
   148  			if err != nil {
   149  				return
   150  			}
   151  			if timer != nil {
   152  				timer.Stop()
   153  				timer = nil
   154  			}
   155  		}
   156  	}
   157  }
   158  
   159  func (conn *stompConn) handleIncomingFrame(f *frame.Frame) error {
   160  	switch f.Command {
   161  
   162  	case frame.CONNECT, frame.STOMP:
   163  		return conn.handleConnect(f)
   164  
   165  	case frame.DISCONNECT:
   166  		return conn.handleDisconnect(f)
   167  
   168  	case frame.SEND:
   169  		return conn.handleSend(f)
   170  
   171  	case frame.SUBSCRIBE:
   172  		return conn.handleSubscribe(f)
   173  
   174  	case frame.UNSUBSCRIBE:
   175  		return conn.handleUnsubscribe(f)
   176  	}
   177  
   178  	return unsupportedStompCommandError
   179  }
   180  
   181  // Returns true if the frame contains ANY of the specified
   182  // headers
   183  func containsHeader(f *frame.Frame, headers ...string) bool {
   184  	for _, h := range headers {
   185  		if _, ok := f.Header.Contains(h); ok {
   186  			return true
   187  		}
   188  	}
   189  	return false
   190  }
   191  
   192  func (conn *stompConn) handleConnect(f *frame.Frame) error {
   193  	if atomic.LoadInt32(&conn.state) == connected {
   194  		return unexpectedStompCommandError
   195  	}
   196  
   197  	if containsHeader(f, frame.Receipt) {
   198  		return invalidHeaderError
   199  	}
   200  
   201  	var err error
   202  	conn.version, err = determineVersion(f)
   203  	if err != nil {
   204  		log.Println("cannot determine version")
   205  		return err
   206  	}
   207  
   208  	if conn.version == stomp.V10 {
   209  		return unsupportedStompVersionError
   210  	}
   211  
   212  	cxDuration, cyDuration, err := getHeartBeat(f)
   213  	if err != nil {
   214  		log.Println("invalid heart-beat")
   215  		return err
   216  	}
   217  
   218  	min := time.Duration(conn.config.HeartBeat()) * time.Millisecond
   219  	if min > maxHeartBeatDuration {
   220  		min = maxHeartBeatDuration
   221  	}
   222  
   223  	// apply a minimum heartbeat
   224  	if cxDuration > 0 {
   225  		if min == 0 || cxDuration < min {
   226  			cxDuration = min
   227  		}
   228  	}
   229  	if cyDuration > 0 {
   230  		if min == 0 || cyDuration < min {
   231  			cyDuration = min
   232  		}
   233  	}
   234  
   235  	conn.writeTimeout = cyDuration
   236  
   237  	cx, cy := int64(cxDuration/time.Millisecond), int64(cyDuration/time.Millisecond)
   238  	atomic.StoreInt64(&conn.readTimeoutMs, cx)
   239  
   240  	response := frame.New(frame.CONNECTED,
   241  		frame.Version, string(conn.version),
   242  		frame.Server, "stompServer/0.0.1",
   243  		frame.HeartBeat, fmt.Sprintf("%d,%d", cy, cx))
   244  
   245  	err = conn.rawConnection.WriteFrame(response)
   246  	if err != nil {
   247  		return err
   248  	}
   249  
   250  	atomic.StoreInt32(&conn.state, connected)
   251  
   252  	conn.events <- &ConnEvent{
   253  		ConnId:    conn.GetId(),
   254  		eventType: ConnectionEstablished,
   255  		conn:      conn,
   256  	}
   257  
   258  	return nil
   259  }
   260  
   261  func (conn *stompConn) handleDisconnect(f *frame.Frame) error {
   262  	if atomic.LoadInt32(&conn.state) == connecting {
   263  		return notConnectedStompError
   264  	}
   265  
   266  	conn.sendReceiptResponse(f)
   267  	conn.Close()
   268  
   269  	return nil
   270  }
   271  
   272  func (conn *stompConn) handleSubscribe(f *frame.Frame) error {
   273  	switch atomic.LoadInt32(&conn.state) {
   274  	case connecting:
   275  		return notConnectedStompError
   276  	case closed:
   277  		return nil
   278  	}
   279  
   280  	subId, ok := f.Header.Contains(frame.Id)
   281  	if !ok {
   282  		return invalidSubscriptionError
   283  	}
   284  
   285  	dest, ok := f.Header.Contains(frame.Destination)
   286  	if !ok {
   287  		return invalidFrameError
   288  	}
   289  
   290  	if _, exists := conn.subscriptions[subId]; exists {
   291  		// subscription already exists
   292  		return nil
   293  	}
   294  
   295  	conn.subscriptions[subId] = &subscription{
   296  		id:          subId,
   297  		destination: dest,
   298  	}
   299  
   300  	conn.events <- &ConnEvent{
   301  		ConnId:      conn.GetId(),
   302  		eventType:   SubscribeToTopic,
   303  		destination: dest,
   304  		conn:        conn,
   305  		sub:         conn.subscriptions[subId],
   306  		frame:       f,
   307  	}
   308  
   309  	return nil
   310  }
   311  
   312  func (conn *stompConn) handleUnsubscribe(f *frame.Frame) error {
   313  	switch atomic.LoadInt32(&conn.state) {
   314  	case connecting:
   315  		return notConnectedStompError
   316  	case closed:
   317  		return nil
   318  	}
   319  
   320  	id, ok := f.Header.Contains(frame.Id)
   321  	if !ok {
   322  		return invalidSubscriptionError
   323  	}
   324  
   325  	conn.sendReceiptResponse(f)
   326  
   327  	sub, ok := conn.subscriptions[id]
   328  	if !ok {
   329  		// subscription already removed
   330  		return nil
   331  	}
   332  
   333  	// remove the subscription
   334  	delete(conn.subscriptions, id)
   335  
   336  	conn.events <- &ConnEvent{
   337  		ConnId:      conn.GetId(),
   338  		eventType:   UnsubscribeFromTopic,
   339  		conn:        conn,
   340  		sub:         sub,
   341  		destination: sub.destination,
   342  	}
   343  
   344  	return nil
   345  }
   346  
   347  func (conn *stompConn) handleSend(f *frame.Frame) error {
   348  	switch atomic.LoadInt32(&conn.state) {
   349  	case connecting:
   350  		return notConnectedStompError
   351  	case closed:
   352  		return nil
   353  	}
   354  
   355  	// TODO: Remove if we start supporting transactions
   356  	if containsHeader(f, frame.Transaction) {
   357  		return unsupportedStompCommandError
   358  	}
   359  
   360  	// no destination triggers an error
   361  	dest, ok := f.Header.Contains(frame.Destination)
   362  	if !ok {
   363  		return invalidFrameError
   364  	}
   365  
   366  	// reject SENDing directly to non-request channels by clients
   367  	if !conn.config.IsAppRequestDestination(f.Header.Get(frame.Destination)) {
   368  		return invalidSendDestinationError
   369  	}
   370  
   371  	err := conn.sendReceiptResponse(f)
   372  	if err != nil {
   373  		return err
   374  	}
   375  
   376  	f.Command = frame.MESSAGE
   377  	conn.events <- &ConnEvent{
   378  		ConnId:      conn.GetId(),
   379  		eventType:   IncomingMessage,
   380  		destination: dest,
   381  		frame:       f,
   382  		conn:        conn,
   383  	}
   384  
   385  	return nil
   386  }
   387  
   388  func (conn *stompConn) sendReceiptResponse(f *frame.Frame) error {
   389  	if receipt, ok := f.Header.Contains(frame.Receipt); ok {
   390  		f.Header.Del(frame.Receipt)
   391  		return conn.rawConnection.WriteFrame(frame.New(frame.RECEIPT, frame.ReceiptId, receipt))
   392  	}
   393  	return nil
   394  }
   395  
   396  func (conn *stompConn) readInFrames() {
   397  	defer func() {
   398  		close(conn.inFrames)
   399  	}()
   400  
   401  	infiniteTimeout := time.Time{}
   402  	var readTimeoutMs int64 = 0
   403  	for {
   404  		readTimeoutMs = atomic.LoadInt64(&conn.readTimeoutMs)
   405  		if readTimeoutMs > 0 {
   406  			conn.rawConnection.SetReadDeadline(time.Now().Add(
   407  				time.Duration(readTimeoutMs) * time.Millisecond))
   408  		} else {
   409  			conn.rawConnection.SetReadDeadline(infiniteTimeout)
   410  		}
   411  
   412  		f, err := conn.rawConnection.ReadFrame()
   413  		if err != nil {
   414  			return
   415  		}
   416  
   417  		if f == nil {
   418  			// heartbeat frame
   419  			continue
   420  		}
   421  
   422  		conn.inFrames <- f
   423  	}
   424  }
   425  
   426  func determineVersion(f *frame.Frame) (stomp.Version, error) {
   427  	if acceptVersion, ok := f.Header.Contains(frame.AcceptVersion); ok {
   428  		versions := strings.Split(acceptVersion, ",")
   429  		for _, supportedVersion := range []stomp.Version{stomp.V12, stomp.V11, stomp.V10} {
   430  			for _, v := range versions {
   431  				if v == supportedVersion.String() {
   432  					// return the highest supported version
   433  					return supportedVersion, nil
   434  				}
   435  			}
   436  		}
   437  	} else {
   438  		return stomp.V10, nil
   439  	}
   440  
   441  	var emptyVersion stomp.Version
   442  	return emptyVersion, unsupportedStompVersionError
   443  }
   444  
   445  func getHeartBeat(f *frame.Frame) (cx, cy time.Duration, err error) {
   446  	if heartBeat, ok := f.Header.Contains(frame.HeartBeat); ok {
   447  		return frame.ParseHeartBeat(heartBeat)
   448  	}
   449  	return 0, 0, nil
   450  }
   451  
   452  func (conn *stompConn) sendError(err error) {
   453  	errorFrame := frame.New(frame.ERROR,
   454  		frame.Message, err.Error())
   455  
   456  	conn.rawConnection.WriteFrame(errorFrame)
   457  }
   458  
   459  func (conn *stompConn) populateMessageIdHeader(f *frame.Frame) {
   460  	if f.Command == frame.MESSAGE {
   461  		// allocate the value of message-id for this frame
   462  		conn.currentMessageId++
   463  		messageId := strconv.FormatUint(conn.currentMessageId, 10)
   464  		f.Header.Set(frame.MessageId, messageId)
   465  		// remove the Ack header (if any) as we don't support those
   466  		f.Header.Del(frame.Ack)
   467  	}
   468  }