github.com/vmware/transport-go@v1.3.4/stompserver/stomp_connection.go (about) 1 // Copyright 2019-2020 VMware, Inc. 2 // SPDX-License-Identifier: BSD-2-Clause 3 4 package stompserver 5 6 import ( 7 "fmt" 8 "github.com/go-stomp/stomp/v3" 9 "github.com/go-stomp/stomp/v3/frame" 10 "github.com/google/uuid" 11 "log" 12 "strconv" 13 "strings" 14 "sync" 15 "sync/atomic" 16 "time" 17 ) 18 19 type subscription struct { 20 id string 21 destination string 22 } 23 24 type StompConn interface { 25 // Return unique connection Id string 26 GetId() string 27 SendFrameToSubscription(f *frame.Frame, sub *subscription) 28 Close() 29 } 30 31 const ( 32 maxHeartBeatDuration = time.Duration(999999999) * time.Millisecond 33 ) 34 35 const ( 36 connecting int32 = iota 37 connected 38 closed 39 ) 40 41 type stompConn struct { 42 rawConnection RawConnection 43 state int32 44 version stomp.Version 45 inFrames chan *frame.Frame 46 outFrames chan *frame.Frame 47 readTimeoutMs int64 48 writeTimeout time.Duration 49 id string 50 events chan *ConnEvent 51 config StompConfig 52 subscriptions map[string]*subscription 53 currentMessageId uint64 54 closeOnce sync.Once 55 } 56 57 func NewStompConn(rawConnection RawConnection, config StompConfig, events chan *ConnEvent) StompConn { 58 conn := &stompConn{ 59 rawConnection: rawConnection, 60 state: connecting, 61 inFrames: make(chan *frame.Frame, 32), 62 outFrames: make(chan *frame.Frame, 32), 63 config: config, 64 id: uuid.New().String(), 65 events: events, 66 subscriptions: make(map[string]*subscription), 67 } 68 69 go conn.run() 70 go conn.readInFrames() 71 72 return conn 73 } 74 75 func (conn *stompConn) SendFrameToSubscription(f *frame.Frame, sub *subscription) { 76 f.Header.Add(frame.Subscription, sub.id) 77 conn.outFrames <- f 78 } 79 80 func (conn *stompConn) Close() { 81 conn.closeOnce.Do(func() { 82 atomic.StoreInt32(&conn.state, closed) 83 conn.rawConnection.Close() 84 85 conn.events <- &ConnEvent{ 86 ConnId: conn.GetId(), 87 eventType: ConnectionClosed, 88 conn: conn, 89 } 90 }) 91 } 92 93 func (conn *stompConn) GetId() string { 94 return conn.id 95 } 96 97 func (conn *stompConn) run() { 98 defer conn.Close() 99 100 var timerChannel <-chan time.Time 101 var timer *time.Timer 102 103 for { 104 105 if atomic.LoadInt32(&conn.state) == closed { 106 return 107 } 108 109 if timer == nil && conn.writeTimeout > 0 { 110 timer = time.NewTimer(conn.writeTimeout) 111 timerChannel = timer.C 112 } 113 114 select { 115 case f, ok := <-conn.outFrames: 116 if !ok { 117 // close connection 118 return 119 } 120 121 // reset heart-beat timer 122 if timer != nil { 123 timer.Stop() 124 timer = nil 125 } 126 127 conn.populateMessageIdHeader(f) 128 129 // write the frame to the client 130 err := conn.rawConnection.WriteFrame(f) 131 if err != nil || f.Command == frame.ERROR { 132 return 133 } 134 135 case f, ok := <-conn.inFrames: 136 if !ok { 137 return 138 } 139 140 if err := conn.handleIncomingFrame(f); err != nil { 141 conn.sendError(err) 142 return 143 } 144 145 case _ = <-timerChannel: 146 // write a heart-beat 147 err := conn.rawConnection.WriteFrame(nil) 148 if err != nil { 149 return 150 } 151 if timer != nil { 152 timer.Stop() 153 timer = nil 154 } 155 } 156 } 157 } 158 159 func (conn *stompConn) handleIncomingFrame(f *frame.Frame) error { 160 switch f.Command { 161 162 case frame.CONNECT, frame.STOMP: 163 return conn.handleConnect(f) 164 165 case frame.DISCONNECT: 166 return conn.handleDisconnect(f) 167 168 case frame.SEND: 169 return conn.handleSend(f) 170 171 case frame.SUBSCRIBE: 172 return conn.handleSubscribe(f) 173 174 case frame.UNSUBSCRIBE: 175 return conn.handleUnsubscribe(f) 176 } 177 178 return unsupportedStompCommandError 179 } 180 181 // Returns true if the frame contains ANY of the specified 182 // headers 183 func containsHeader(f *frame.Frame, headers ...string) bool { 184 for _, h := range headers { 185 if _, ok := f.Header.Contains(h); ok { 186 return true 187 } 188 } 189 return false 190 } 191 192 func (conn *stompConn) handleConnect(f *frame.Frame) error { 193 if atomic.LoadInt32(&conn.state) == connected { 194 return unexpectedStompCommandError 195 } 196 197 if containsHeader(f, frame.Receipt) { 198 return invalidHeaderError 199 } 200 201 var err error 202 conn.version, err = determineVersion(f) 203 if err != nil { 204 log.Println("cannot determine version") 205 return err 206 } 207 208 if conn.version == stomp.V10 { 209 return unsupportedStompVersionError 210 } 211 212 cxDuration, cyDuration, err := getHeartBeat(f) 213 if err != nil { 214 log.Println("invalid heart-beat") 215 return err 216 } 217 218 min := time.Duration(conn.config.HeartBeat()) * time.Millisecond 219 if min > maxHeartBeatDuration { 220 min = maxHeartBeatDuration 221 } 222 223 // apply a minimum heartbeat 224 if cxDuration > 0 { 225 if min == 0 || cxDuration < min { 226 cxDuration = min 227 } 228 } 229 if cyDuration > 0 { 230 if min == 0 || cyDuration < min { 231 cyDuration = min 232 } 233 } 234 235 conn.writeTimeout = cyDuration 236 237 cx, cy := int64(cxDuration/time.Millisecond), int64(cyDuration/time.Millisecond) 238 atomic.StoreInt64(&conn.readTimeoutMs, cx) 239 240 response := frame.New(frame.CONNECTED, 241 frame.Version, string(conn.version), 242 frame.Server, "stompServer/0.0.1", 243 frame.HeartBeat, fmt.Sprintf("%d,%d", cy, cx)) 244 245 err = conn.rawConnection.WriteFrame(response) 246 if err != nil { 247 return err 248 } 249 250 atomic.StoreInt32(&conn.state, connected) 251 252 conn.events <- &ConnEvent{ 253 ConnId: conn.GetId(), 254 eventType: ConnectionEstablished, 255 conn: conn, 256 } 257 258 return nil 259 } 260 261 func (conn *stompConn) handleDisconnect(f *frame.Frame) error { 262 if atomic.LoadInt32(&conn.state) == connecting { 263 return notConnectedStompError 264 } 265 266 conn.sendReceiptResponse(f) 267 conn.Close() 268 269 return nil 270 } 271 272 func (conn *stompConn) handleSubscribe(f *frame.Frame) error { 273 switch atomic.LoadInt32(&conn.state) { 274 case connecting: 275 return notConnectedStompError 276 case closed: 277 return nil 278 } 279 280 subId, ok := f.Header.Contains(frame.Id) 281 if !ok { 282 return invalidSubscriptionError 283 } 284 285 dest, ok := f.Header.Contains(frame.Destination) 286 if !ok { 287 return invalidFrameError 288 } 289 290 if _, exists := conn.subscriptions[subId]; exists { 291 // subscription already exists 292 return nil 293 } 294 295 conn.subscriptions[subId] = &subscription{ 296 id: subId, 297 destination: dest, 298 } 299 300 conn.events <- &ConnEvent{ 301 ConnId: conn.GetId(), 302 eventType: SubscribeToTopic, 303 destination: dest, 304 conn: conn, 305 sub: conn.subscriptions[subId], 306 frame: f, 307 } 308 309 return nil 310 } 311 312 func (conn *stompConn) handleUnsubscribe(f *frame.Frame) error { 313 switch atomic.LoadInt32(&conn.state) { 314 case connecting: 315 return notConnectedStompError 316 case closed: 317 return nil 318 } 319 320 id, ok := f.Header.Contains(frame.Id) 321 if !ok { 322 return invalidSubscriptionError 323 } 324 325 conn.sendReceiptResponse(f) 326 327 sub, ok := conn.subscriptions[id] 328 if !ok { 329 // subscription already removed 330 return nil 331 } 332 333 // remove the subscription 334 delete(conn.subscriptions, id) 335 336 conn.events <- &ConnEvent{ 337 ConnId: conn.GetId(), 338 eventType: UnsubscribeFromTopic, 339 conn: conn, 340 sub: sub, 341 destination: sub.destination, 342 } 343 344 return nil 345 } 346 347 func (conn *stompConn) handleSend(f *frame.Frame) error { 348 switch atomic.LoadInt32(&conn.state) { 349 case connecting: 350 return notConnectedStompError 351 case closed: 352 return nil 353 } 354 355 // TODO: Remove if we start supporting transactions 356 if containsHeader(f, frame.Transaction) { 357 return unsupportedStompCommandError 358 } 359 360 // no destination triggers an error 361 dest, ok := f.Header.Contains(frame.Destination) 362 if !ok { 363 return invalidFrameError 364 } 365 366 // reject SENDing directly to non-request channels by clients 367 if !conn.config.IsAppRequestDestination(f.Header.Get(frame.Destination)) { 368 return invalidSendDestinationError 369 } 370 371 err := conn.sendReceiptResponse(f) 372 if err != nil { 373 return err 374 } 375 376 f.Command = frame.MESSAGE 377 conn.events <- &ConnEvent{ 378 ConnId: conn.GetId(), 379 eventType: IncomingMessage, 380 destination: dest, 381 frame: f, 382 conn: conn, 383 } 384 385 return nil 386 } 387 388 func (conn *stompConn) sendReceiptResponse(f *frame.Frame) error { 389 if receipt, ok := f.Header.Contains(frame.Receipt); ok { 390 f.Header.Del(frame.Receipt) 391 return conn.rawConnection.WriteFrame(frame.New(frame.RECEIPT, frame.ReceiptId, receipt)) 392 } 393 return nil 394 } 395 396 func (conn *stompConn) readInFrames() { 397 defer func() { 398 close(conn.inFrames) 399 }() 400 401 infiniteTimeout := time.Time{} 402 var readTimeoutMs int64 = 0 403 for { 404 readTimeoutMs = atomic.LoadInt64(&conn.readTimeoutMs) 405 if readTimeoutMs > 0 { 406 conn.rawConnection.SetReadDeadline(time.Now().Add( 407 time.Duration(readTimeoutMs) * time.Millisecond)) 408 } else { 409 conn.rawConnection.SetReadDeadline(infiniteTimeout) 410 } 411 412 f, err := conn.rawConnection.ReadFrame() 413 if err != nil { 414 return 415 } 416 417 if f == nil { 418 // heartbeat frame 419 continue 420 } 421 422 conn.inFrames <- f 423 } 424 } 425 426 func determineVersion(f *frame.Frame) (stomp.Version, error) { 427 if acceptVersion, ok := f.Header.Contains(frame.AcceptVersion); ok { 428 versions := strings.Split(acceptVersion, ",") 429 for _, supportedVersion := range []stomp.Version{stomp.V12, stomp.V11, stomp.V10} { 430 for _, v := range versions { 431 if v == supportedVersion.String() { 432 // return the highest supported version 433 return supportedVersion, nil 434 } 435 } 436 } 437 } else { 438 return stomp.V10, nil 439 } 440 441 var emptyVersion stomp.Version 442 return emptyVersion, unsupportedStompVersionError 443 } 444 445 func getHeartBeat(f *frame.Frame) (cx, cy time.Duration, err error) { 446 if heartBeat, ok := f.Header.Contains(frame.HeartBeat); ok { 447 return frame.ParseHeartBeat(heartBeat) 448 } 449 return 0, 0, nil 450 } 451 452 func (conn *stompConn) sendError(err error) { 453 errorFrame := frame.New(frame.ERROR, 454 frame.Message, err.Error()) 455 456 conn.rawConnection.WriteFrame(errorFrame) 457 } 458 459 func (conn *stompConn) populateMessageIdHeader(f *frame.Frame) { 460 if f.Command == frame.MESSAGE { 461 // allocate the value of message-id for this frame 462 conn.currentMessageId++ 463 messageId := strconv.FormatUint(conn.currentMessageId, 10) 464 f.Header.Set(frame.MessageId, messageId) 465 // remove the Ack header (if any) as we don't support those 466 f.Header.Del(frame.Ack) 467 } 468 }