github.com/vmware/transport-go@v1.3.4/stompserver/server.go (about)

     1  // Copyright 2019-2020 VMware, Inc.
     2  // SPDX-License-Identifier: BSD-2-Clause
     3  
     4  package stompserver
     5  
     6  import (
     7  	"github.com/go-stomp/stomp/v3/frame"
     8  	"log"
     9  	"strconv"
    10  	"sync"
    11  )
    12  
    13  type SubscribeHandlerFunction func(conId string, subId string, destination string, frame *frame.Frame)
    14  
    15  type UnsubscribeHandlerFunction func(conId string, subId string, destination string)
    16  
    17  type ApplicationRequestHandlerFunction func(destination string, message []byte, connectionId string)
    18  
// StompServer is the API exposed by the STOMP server: lifecycle control, message delivery
// and registration of event callbacks.
type StompServer interface {
	// Start starts the server. The stompServer implementation in this file runs its event
	// loop on the calling goroutine, so Start does not return until the server is stopped.
	Start()
	// Stop stops the server.
	Stop()
	// SendMessage sends a message to a given stomp topic destination.
	SendMessage(destination string, messageBody []byte)
	// SendMessageToClient sends a message to a single connection client.
	SendMessageToClient(connectionId string, destination string, messageBody []byte)
	// OnSubscribeEvent registers a callback for stomp subscribe events.
	OnSubscribeEvent(callback SubscribeHandlerFunction)
	// OnUnsubscribeEvent registers a callback for stomp unsubscribe events.
	OnUnsubscribeEvent(callback UnsubscribeHandlerFunction)
	// OnApplicationRequest registers a callback for application requests.
	OnApplicationRequest(callback ApplicationRequestHandlerFunction)
	// SetConnectionEventCallback is used to set up a callback when certain STOMP session events happen
	// such as ConnectionStarting, ConnectionClosed, SubscribeToTopic, UnsubscribeFromTopic and IncomingMessage.
	// Only one callback is kept per event type; setting another one replaces the previous.
	SetConnectionEventCallback(connEventType StompSessionEventType, cb func(connEvent *ConnEvent))
}
    38  
    39  type StompSessionEventType int
    40  
    41  const (
    42  	ConnectionStarting StompSessionEventType = iota
    43  	ConnectionEstablished
    44  	ConnectionClosed
    45  	SubscribeToTopic
    46  	UnsubscribeFromTopic
    47  	IncomingMessage
    48  )
    49  
// ConnEvent describes a single event raised for a STOMP connection. Events are delivered
// over the server's connectionEvents channel and handed to the registered connection
// event callbacks.
type ConnEvent struct {
	// ConnId is the id of the connection the event belongs to.
	ConnId      string
	// eventType selects how the server handles the event.
	eventType   StompSessionEventType
	// conn is the connection that raised the event.
	conn        StompConn
	// destination is the destination the event refers to (subscribe, unsubscribe and
	// incoming message events).
	destination string
	// sub is the subscription involved in subscribe/unsubscribe events.
	sub         *subscription
	// frame is the frame attached to the event, when there is one.
	frame       *frame.Frame
}
    58  
    59  type apiEventType int
    60  
    61  const (
    62  	closeServer apiEventType = iota
    63  	sendMessage
    64  	sendPrivateMessage
    65  )
    66  
// apiEvent is a request posted by the public API methods (Stop, SendMessage,
// SendMessageToClient) and consumed by the server's event loop.
type apiEvent struct {
	// eventType selects the action the event loop performs.
	eventType   apiEventType
	// connId is the target connection; only set for sendPrivateMessage.
	connId      string
	// frame is the frame to deliver; not set for closeServer.
	frame       *frame.Frame
	// destination is the destination whose subscriptions receive the frame.
	destination string
}
    73  
    74  type connSubscriptions struct {
    75  	conn          StompConn
    76  	subscriptions map[string]*subscription
    77  }
    78  
    79  func newConnSubscriptions(conn StompConn) *connSubscriptions {
    80  	return &connSubscriptions{
    81  		conn:          conn,
    82  		subscriptions: make(map[string]*subscription),
    83  	}
    84  }
    85  
    86  type stompServer struct {
    87  	connectionListener          RawConnectionListener
    88  	connectionEvents            chan *ConnEvent
    89  	connectionEventCallbacks    map[StompSessionEventType]func(event *ConnEvent)
    90  	apiEvents                   chan *apiEvent
    91  	running                     bool
    92  	connectionsMap              map[string]StompConn
    93  	subscriptionsMap            map[string]map[string]*connSubscriptions
    94  	config                      StompConfig
    95  	callbackLock                sync.RWMutex
    96  	subscribeCallbacks          []SubscribeHandlerFunction
    97  	unsubscribeCallbacks        []UnsubscribeHandlerFunction
    98  	applicationRequestCallbacks []ApplicationRequestHandlerFunction
    99  }
   100  
   101  func NewStompServer(listener RawConnectionListener, config StompConfig) StompServer {
   102  	server := &stompServer{
   103  		config:                      config,
   104  		connectionListener:          listener,
   105  		apiEvents:                   make(chan *apiEvent, 32),
   106  		connectionsMap:              make(map[string]StompConn),
   107  		connectionEvents:            make(chan *ConnEvent, 64),
   108  		connectionEventCallbacks:    make(map[StompSessionEventType]func(event *ConnEvent)),
   109  		subscriptionsMap:            make(map[string]map[string]*connSubscriptions),
   110  		subscribeCallbacks:          make([]SubscribeHandlerFunction, 0),
   111  		unsubscribeCallbacks:        make([]UnsubscribeHandlerFunction, 0),
   112  		applicationRequestCallbacks: make([]ApplicationRequestHandlerFunction, 0),
   113  	}
   114  
   115  	return server
   116  }
   117  
   118  func (s *stompServer) OnSubscribeEvent(callback SubscribeHandlerFunction) {
   119  	s.callbackLock.Lock()
   120  	defer s.callbackLock.Unlock()
   121  
   122  	s.subscribeCallbacks = append(s.subscribeCallbacks, callback)
   123  }
   124  
   125  func (s *stompServer) OnUnsubscribeEvent(callback UnsubscribeHandlerFunction) {
   126  	s.callbackLock.Lock()
   127  	defer s.callbackLock.Unlock()
   128  
   129  	s.unsubscribeCallbacks = append(s.unsubscribeCallbacks, callback)
   130  }
   131  
   132  func (s *stompServer) OnApplicationRequest(callback ApplicationRequestHandlerFunction) {
   133  	s.callbackLock.Lock()
   134  	defer s.callbackLock.Unlock()
   135  
   136  	s.applicationRequestCallbacks = append(s.applicationRequestCallbacks, callback)
   137  }
   138  
   139  func (s *stompServer) SendMessage(destination string, messageBody []byte) {
   140  
   141  	// create send frame.
   142  	f := frame.New(frame.MESSAGE,
   143  		frame.Destination, destination,
   144  		frame.ContentLength, strconv.Itoa(len(messageBody)),
   145  		frame.ContentType, "application/json;charset=UTF-8")
   146  
   147  	f.Body = messageBody
   148  
   149  	s.apiEvents <- &apiEvent{
   150  		eventType:   sendMessage,
   151  		destination: destination,
   152  		frame:       f,
   153  	}
   154  }
   155  
   156  func (s *stompServer) SendMessageToClient(connectionId string, destination string, messageBody []byte) {
   157  
   158  	// create send frame.
   159  	f := frame.New(frame.MESSAGE,
   160  		frame.Destination, destination,
   161  		frame.ContentLength, strconv.Itoa(len(messageBody)),
   162  		frame.ContentType, "application/json;charset=UTF-8")
   163  
   164  	f.Body = messageBody
   165  
   166  	s.apiEvents <- &apiEvent{
   167  		eventType:   sendPrivateMessage,
   168  		destination: destination,
   169  		frame:       f,
   170  		connId:      connectionId,
   171  	}
   172  }
   173  
   174  func (s *stompServer) SetConnectionEventCallback(connEventType StompSessionEventType, cb func(connEvent *ConnEvent)) {
   175  	s.callbackLock.Lock()
   176  	defer s.callbackLock.Unlock()
   177  	s.connectionEventCallbacks[connEventType] = cb
   178  }
   179  
   180  func (s *stompServer) Start() {
   181  	if s.running {
   182  		return
   183  	}
   184  
   185  	s.running = true
   186  	go s.waitForConnections()
   187  	s.run()
   188  }
   189  
   190  func (s *stompServer) Stop() {
   191  	if s.running {
   192  		s.running = false
   193  		s.apiEvents <- &apiEvent{
   194  			eventType: closeServer,
   195  		}
   196  	}
   197  }
   198  
   199  func (s *stompServer) waitForConnections() {
   200  	for {
   201  		if !s.running {
   202  			return
   203  		}
   204  
   205  		rawConn, err := s.connectionListener.Accept()
   206  		if err != nil {
   207  			if s.running {
   208  				log.Println("Failed to establish client connection:", err)
   209  			}
   210  			continue
   211  		}
   212  
   213  		c := NewStompConn(rawConn, s.config, s.connectionEvents)
   214  
   215  		s.connectionEvents <- &ConnEvent{
   216  			ConnId:    c.GetId(),
   217  			conn:      c,
   218  			eventType: ConnectionStarting,
   219  		}
   220  	}
   221  }
   222  
   223  func (s *stompServer) run() {
   224  	for {
   225  		select {
   226  
   227  		case apiEvent, _ := <-s.apiEvents:
   228  			if apiEvent.eventType == closeServer {
   229  				s.connectionListener.Close()
   230  				// close all open connections
   231  				for _, c := range s.connectionsMap {
   232  					c.Close()
   233  				}
   234  				s.connectionsMap = make(map[string]StompConn)
   235  				return
   236  			} else if apiEvent.eventType == sendMessage {
   237  				s.sendFrame(apiEvent.destination, apiEvent.frame)
   238  			} else if apiEvent.eventType == sendPrivateMessage {
   239  				s.sendFrameToClient(apiEvent.connId, apiEvent.destination, apiEvent.frame)
   240  			}
   241  
   242  		case e, _ := <-s.connectionEvents:
   243  			s.handleConnectionEvent(e)
   244  		}
   245  	}
   246  }
   247  
   248  func (s *stompServer) handleConnectionEvent(e *ConnEvent) {
   249  
   250  	s.callbackLock.RLock()
   251  	defer s.callbackLock.RUnlock()
   252  
   253  	switch e.eventType {
   254  	case ConnectionStarting:
   255  		s.connectionsMap[e.conn.GetId()] = e.conn
   256  		if fn, exists := s.connectionEventCallbacks[ConnectionStarting]; exists {
   257  			fn(e)
   258  		}
   259  
   260  	case ConnectionClosed:
   261  		delete(s.connectionsMap, e.conn.GetId())
   262  		for _, connSubscriptions := range s.subscriptionsMap {
   263  			conSub, ok := connSubscriptions[e.conn.GetId()]
   264  			if ok {
   265  				delete(connSubscriptions, e.conn.GetId())
   266  				for _, sub := range conSub.subscriptions {
   267  					for _, callback := range s.unsubscribeCallbacks {
   268  						callback(e.conn.GetId(), sub.id, sub.destination)
   269  					}
   270  				}
   271  			}
   272  		}
   273  		if fn, exists := s.connectionEventCallbacks[ConnectionClosed]; exists {
   274  			fn(e)
   275  		}
   276  
   277  	case SubscribeToTopic:
   278  		subsMap, ok := s.subscriptionsMap[e.destination]
   279  		if !ok {
   280  			subsMap = make(map[string]*connSubscriptions)
   281  			s.subscriptionsMap[e.destination] = subsMap
   282  		}
   283  		var conSub *connSubscriptions
   284  		conSub, ok = subsMap[e.conn.GetId()]
   285  		if !ok {
   286  			conSub = newConnSubscriptions(e.conn)
   287  			subsMap[e.conn.GetId()] = conSub
   288  		}
   289  		conSub.subscriptions[e.sub.id] = e.sub
   290  
   291  		// notify listeners
   292  		for _, callback := range s.subscribeCallbacks {
   293  			callback(e.conn.GetId(), e.sub.id, e.destination, e.frame)
   294  		}
   295  		if fn, exists := s.connectionEventCallbacks[SubscribeToTopic]; exists {
   296  			fn(e)
   297  		}
   298  
   299  	case UnsubscribeFromTopic:
   300  		subs, ok := s.subscriptionsMap[e.destination]
   301  		if ok {
   302  			var conSub *connSubscriptions
   303  			conSub, ok = subs[e.conn.GetId()]
   304  			if ok {
   305  				_, ok = conSub.subscriptions[e.sub.id]
   306  				if ok {
   307  					delete(conSub.subscriptions, e.sub.id)
   308  					// notify listeners
   309  					for _, callback := range s.unsubscribeCallbacks {
   310  						callback(e.conn.GetId(), e.sub.id, e.destination)
   311  					}
   312  				}
   313  			}
   314  		}
   315  		if fn, exists := s.connectionEventCallbacks[UnsubscribeFromTopic]; exists {
   316  			fn(e)
   317  		}
   318  
   319  	case IncomingMessage:
   320  		if s.config.IsAppRequestDestination(e.destination) && e.conn != nil {
   321  			// notify app listeners
   322  			for _, callback := range s.applicationRequestCallbacks {
   323  				callback(e.destination, e.frame.Body, e.conn.GetId())
   324  			}
   325  		}
   326  		if fn, exists := s.connectionEventCallbacks[IncomingMessage]; exists {
   327  			fn(e)
   328  		}
   329  	}
   330  }
   331  
   332  func (s *stompServer) sendFrame(dest string, f *frame.Frame) {
   333  	subsMap, ok := s.subscriptionsMap[dest]
   334  	if ok {
   335  		for _, connSub := range subsMap {
   336  			for _, sub := range connSub.subscriptions {
   337  				connSub.conn.SendFrameToSubscription(f.Clone(), sub)
   338  			}
   339  		}
   340  	}
   341  }
   342  
   343  func (s *stompServer) sendFrameToClient(conId string, dest string, f *frame.Frame) {
   344  	subsMap, ok := s.subscriptionsMap[dest]
   345  	if ok {
   346  		connSubscriptions, ok := subsMap[conId]
   347  		if ok {
   348  			for _, sub := range connSubscriptions.subscriptions {
   349  				connSubscriptions.conn.SendFrameToSubscription(f.Clone(), sub)
   350  			}
   351  		}
   352  	}
   353  }