github.com/vmware/transport-go@v1.3.4/stompserver/server.go (about) 1 // Copyright 2019-2020 VMware, Inc. 2 // SPDX-License-Identifier: BSD-2-Clause 3 4 package stompserver 5 6 import ( 7 "github.com/go-stomp/stomp/v3/frame" 8 "log" 9 "strconv" 10 "sync" 11 ) 12 13 type SubscribeHandlerFunction func(conId string, subId string, destination string, frame *frame.Frame) 14 15 type UnsubscribeHandlerFunction func(conId string, subId string, destination string) 16 17 type ApplicationRequestHandlerFunction func(destination string, message []byte, connectionId string) 18 19 type StompServer interface { 20 // starts the server 21 Start() 22 // stops the server 23 Stop() 24 // sends a message to a given stomp topic destination 25 SendMessage(destination string, messageBody []byte) 26 // sends a message to a single connection client 27 SendMessageToClient(connectionId string, destination string, messageBody []byte) 28 // registers a callback for stomp subscribe events 29 OnSubscribeEvent(callback SubscribeHandlerFunction) 30 // registers a callback for stomp unsubscribe events 31 OnUnsubscribeEvent(callback UnsubscribeHandlerFunction) 32 // registers a callback for application requests 33 OnApplicationRequest(callback ApplicationRequestHandlerFunction) 34 // SetConnectionEventCallback is used to set up a callback when certain STOMP session events happen 35 // such as ConnectionStarting, ConnectionClosed, SubscribeToTopic, UnsubscribeFromTopic and IncomingMessage. 36 SetConnectionEventCallback(connEventType StompSessionEventType, cb func(connEvent *ConnEvent)) 37 } 38 39 type StompSessionEventType int 40 41 const ( 42 ConnectionStarting StompSessionEventType = iota 43 ConnectionEstablished 44 ConnectionClosed 45 SubscribeToTopic 46 UnsubscribeFromTopic 47 IncomingMessage 48 ) 49 50 type ConnEvent struct { 51 ConnId string 52 eventType StompSessionEventType 53 conn StompConn 54 destination string 55 sub *subscription 56 frame *frame.Frame 57 } 58 59 type apiEventType int 60 61 const ( 62 closeServer apiEventType = iota 63 sendMessage 64 sendPrivateMessage 65 ) 66 67 type apiEvent struct { 68 eventType apiEventType 69 connId string 70 frame *frame.Frame 71 destination string 72 } 73 74 type connSubscriptions struct { 75 conn StompConn 76 subscriptions map[string]*subscription 77 } 78 79 func newConnSubscriptions(conn StompConn) *connSubscriptions { 80 return &connSubscriptions{ 81 conn: conn, 82 subscriptions: make(map[string]*subscription), 83 } 84 } 85 86 type stompServer struct { 87 connectionListener RawConnectionListener 88 connectionEvents chan *ConnEvent 89 connectionEventCallbacks map[StompSessionEventType]func(event *ConnEvent) 90 apiEvents chan *apiEvent 91 running bool 92 connectionsMap map[string]StompConn 93 subscriptionsMap map[string]map[string]*connSubscriptions 94 config StompConfig 95 callbackLock sync.RWMutex 96 subscribeCallbacks []SubscribeHandlerFunction 97 unsubscribeCallbacks []UnsubscribeHandlerFunction 98 applicationRequestCallbacks []ApplicationRequestHandlerFunction 99 } 100 101 func NewStompServer(listener RawConnectionListener, config StompConfig) StompServer { 102 server := &stompServer{ 103 config: config, 104 connectionListener: listener, 105 apiEvents: make(chan *apiEvent, 32), 106 connectionsMap: make(map[string]StompConn), 107 connectionEvents: make(chan *ConnEvent, 64), 108 connectionEventCallbacks: make(map[StompSessionEventType]func(event *ConnEvent)), 109 subscriptionsMap: make(map[string]map[string]*connSubscriptions), 110 subscribeCallbacks: make([]SubscribeHandlerFunction, 0), 111 unsubscribeCallbacks: make([]UnsubscribeHandlerFunction, 0), 112 applicationRequestCallbacks: make([]ApplicationRequestHandlerFunction, 0), 113 } 114 115 return server 116 } 117 118 func (s *stompServer) OnSubscribeEvent(callback SubscribeHandlerFunction) { 119 s.callbackLock.Lock() 120 defer s.callbackLock.Unlock() 121 122 s.subscribeCallbacks = append(s.subscribeCallbacks, callback) 123 } 124 125 func (s *stompServer) OnUnsubscribeEvent(callback UnsubscribeHandlerFunction) { 126 s.callbackLock.Lock() 127 defer s.callbackLock.Unlock() 128 129 s.unsubscribeCallbacks = append(s.unsubscribeCallbacks, callback) 130 } 131 132 func (s *stompServer) OnApplicationRequest(callback ApplicationRequestHandlerFunction) { 133 s.callbackLock.Lock() 134 defer s.callbackLock.Unlock() 135 136 s.applicationRequestCallbacks = append(s.applicationRequestCallbacks, callback) 137 } 138 139 func (s *stompServer) SendMessage(destination string, messageBody []byte) { 140 141 // create send frame. 142 f := frame.New(frame.MESSAGE, 143 frame.Destination, destination, 144 frame.ContentLength, strconv.Itoa(len(messageBody)), 145 frame.ContentType, "application/json;charset=UTF-8") 146 147 f.Body = messageBody 148 149 s.apiEvents <- &apiEvent{ 150 eventType: sendMessage, 151 destination: destination, 152 frame: f, 153 } 154 } 155 156 func (s *stompServer) SendMessageToClient(connectionId string, destination string, messageBody []byte) { 157 158 // create send frame. 159 f := frame.New(frame.MESSAGE, 160 frame.Destination, destination, 161 frame.ContentLength, strconv.Itoa(len(messageBody)), 162 frame.ContentType, "application/json;charset=UTF-8") 163 164 f.Body = messageBody 165 166 s.apiEvents <- &apiEvent{ 167 eventType: sendPrivateMessage, 168 destination: destination, 169 frame: f, 170 connId: connectionId, 171 } 172 } 173 174 func (s *stompServer) SetConnectionEventCallback(connEventType StompSessionEventType, cb func(connEvent *ConnEvent)) { 175 s.callbackLock.Lock() 176 defer s.callbackLock.Unlock() 177 s.connectionEventCallbacks[connEventType] = cb 178 } 179 180 func (s *stompServer) Start() { 181 if s.running { 182 return 183 } 184 185 s.running = true 186 go s.waitForConnections() 187 s.run() 188 } 189 190 func (s *stompServer) Stop() { 191 if s.running { 192 s.running = false 193 s.apiEvents <- &apiEvent{ 194 eventType: closeServer, 195 } 196 } 197 } 198 199 func (s *stompServer) waitForConnections() { 200 for { 201 if !s.running { 202 return 203 } 204 205 rawConn, err := s.connectionListener.Accept() 206 if err != nil { 207 if s.running { 208 log.Println("Failed to establish client connection:", err) 209 } 210 continue 211 } 212 213 c := NewStompConn(rawConn, s.config, s.connectionEvents) 214 215 s.connectionEvents <- &ConnEvent{ 216 ConnId: c.GetId(), 217 conn: c, 218 eventType: ConnectionStarting, 219 } 220 } 221 } 222 223 func (s *stompServer) run() { 224 for { 225 select { 226 227 case apiEvent, _ := <-s.apiEvents: 228 if apiEvent.eventType == closeServer { 229 s.connectionListener.Close() 230 // close all open connections 231 for _, c := range s.connectionsMap { 232 c.Close() 233 } 234 s.connectionsMap = make(map[string]StompConn) 235 return 236 } else if apiEvent.eventType == sendMessage { 237 s.sendFrame(apiEvent.destination, apiEvent.frame) 238 } else if apiEvent.eventType == sendPrivateMessage { 239 s.sendFrameToClient(apiEvent.connId, apiEvent.destination, apiEvent.frame) 240 } 241 242 case e, _ := <-s.connectionEvents: 243 s.handleConnectionEvent(e) 244 } 245 } 246 } 247 248 func (s *stompServer) handleConnectionEvent(e *ConnEvent) { 249 250 s.callbackLock.RLock() 251 defer s.callbackLock.RUnlock() 252 253 switch e.eventType { 254 case ConnectionStarting: 255 s.connectionsMap[e.conn.GetId()] = e.conn 256 if fn, exists := s.connectionEventCallbacks[ConnectionStarting]; exists { 257 fn(e) 258 } 259 260 case ConnectionClosed: 261 delete(s.connectionsMap, e.conn.GetId()) 262 for _, connSubscriptions := range s.subscriptionsMap { 263 conSub, ok := connSubscriptions[e.conn.GetId()] 264 if ok { 265 delete(connSubscriptions, e.conn.GetId()) 266 for _, sub := range conSub.subscriptions { 267 for _, callback := range s.unsubscribeCallbacks { 268 callback(e.conn.GetId(), sub.id, sub.destination) 269 } 270 } 271 } 272 } 273 if fn, exists := s.connectionEventCallbacks[ConnectionClosed]; exists { 274 fn(e) 275 } 276 277 case SubscribeToTopic: 278 subsMap, ok := s.subscriptionsMap[e.destination] 279 if !ok { 280 subsMap = make(map[string]*connSubscriptions) 281 s.subscriptionsMap[e.destination] = subsMap 282 } 283 var conSub *connSubscriptions 284 conSub, ok = subsMap[e.conn.GetId()] 285 if !ok { 286 conSub = newConnSubscriptions(e.conn) 287 subsMap[e.conn.GetId()] = conSub 288 } 289 conSub.subscriptions[e.sub.id] = e.sub 290 291 // notify listeners 292 for _, callback := range s.subscribeCallbacks { 293 callback(e.conn.GetId(), e.sub.id, e.destination, e.frame) 294 } 295 if fn, exists := s.connectionEventCallbacks[SubscribeToTopic]; exists { 296 fn(e) 297 } 298 299 case UnsubscribeFromTopic: 300 subs, ok := s.subscriptionsMap[e.destination] 301 if ok { 302 var conSub *connSubscriptions 303 conSub, ok = subs[e.conn.GetId()] 304 if ok { 305 _, ok = conSub.subscriptions[e.sub.id] 306 if ok { 307 delete(conSub.subscriptions, e.sub.id) 308 // notify listeners 309 for _, callback := range s.unsubscribeCallbacks { 310 callback(e.conn.GetId(), e.sub.id, e.destination) 311 } 312 } 313 } 314 } 315 if fn, exists := s.connectionEventCallbacks[UnsubscribeFromTopic]; exists { 316 fn(e) 317 } 318 319 case IncomingMessage: 320 if s.config.IsAppRequestDestination(e.destination) && e.conn != nil { 321 // notify app listeners 322 for _, callback := range s.applicationRequestCallbacks { 323 callback(e.destination, e.frame.Body, e.conn.GetId()) 324 } 325 } 326 if fn, exists := s.connectionEventCallbacks[IncomingMessage]; exists { 327 fn(e) 328 } 329 } 330 } 331 332 func (s *stompServer) sendFrame(dest string, f *frame.Frame) { 333 subsMap, ok := s.subscriptionsMap[dest] 334 if ok { 335 for _, connSub := range subsMap { 336 for _, sub := range connSub.subscriptions { 337 connSub.conn.SendFrameToSubscription(f.Clone(), sub) 338 } 339 } 340 } 341 } 342 343 func (s *stompServer) sendFrameToClient(conId string, dest string, f *frame.Frame) { 344 subsMap, ok := s.subscriptionsMap[dest] 345 if ok { 346 connSubscriptions, ok := subsMap[conId] 347 if ok { 348 for _, sub := range connSubscriptions.subscriptions { 349 connSubscriptions.conn.SendFrameToSubscription(f.Clone(), sub) 350 } 351 } 352 } 353 }