github.com/vmware/transport-go@v1.3.4/bus/fabric_endpoint.go (about)

     1  // Copyright 2019-2020 VMware, Inc.
     2  // SPDX-License-Identifier: BSD-2-Clause
     3  
     4  package bus
     5  
     6  import (
     7  	"encoding/json"
     8  	"fmt"
     9  	"github.com/go-stomp/stomp/v3/frame"
    10  	"github.com/vmware/transport-go/log"
    11  	"github.com/vmware/transport-go/model"
    12  	"github.com/vmware/transport-go/stompserver"
    13  	"strings"
    14  	"sync"
    15  )
    16  
    17  const (
    18  	STOMP_SESSION_NOTIFY_CHANNEL = TRANSPORT_INTERNAL_CHANNEL_PREFIX + "stomp-session-notify"
    19  )
    20  
    21  type EndpointConfig struct {
    22  	// Prefix for public topics e.g. "/topic"
    23  	TopicPrefix string
    24  	// Prefix for user queues e.g. "/user/queue"
    25  	UserQueuePrefix string
    26  	// Prefix used for public application requests e.g. "/pub"
    27  	AppRequestPrefix string
    28  	// Prefix used for "private" application requests e.g. "/pub/queue"
    29  	// Requests sent to destinations prefixed with the AppRequestQueuePrefix
    30  	// should generate responses sent to single client queue.
    31  	// E.g. if a client sends a request to the "/pub/queue/sample-channel" destination
    32  	// the application should sent the response only to this client on the
    33  	// "/user/queue/sample-channel" destination.
    34  	// This behavior will mimic the Spring SimpleMessageBroker implementation.
    35  	AppRequestQueuePrefix string
    36  	Heartbeat             int64
    37  }
    38  
    39  func (ec *EndpointConfig) validate() error {
    40  	if ec.TopicPrefix == "" || !strings.HasPrefix(ec.TopicPrefix, "/") {
    41  		return fmt.Errorf("invalid TopicPrefix")
    42  	}
    43  
    44  	if ec.AppRequestQueuePrefix != "" && ec.UserQueuePrefix == "" {
    45  		return fmt.Errorf("missing UserQueuePrefix")
    46  	}
    47  
    48  	return nil
    49  }
    50  
// FabricEndpoint is a STOMP endpoint bridged to an event bus. The
// implementation in this file is created by newFabricEndpoint.
type FabricEndpoint interface {
	// Start begins serving STOMP clients.
	Start()
	// Stop shuts the endpoint down.
	Stop()
}
    55  
// channelMapping is the fabric endpoint's bridge state for one bus channel.
type channelMapping struct {
	subs        map[string]bool // active subscriptions, keyed by "<connId>#<subId>"
	handler     MessageHandler  // bus listener forwarding channel messages to clients
	autoCreated bool            // endpoint created the channel; destroyed with the last subscription
}
    61  
// StompSessionEvent is the payload published on STOMP_SESSION_NOTIFY_CHANNEL
// when a STOMP connection starts, closes, or unsubscribes from a topic.
type StompSessionEvent struct {
	Id        string                            // STOMP connection id (ConnEvent.ConnId)
	EventType stompserver.StompSessionEventType // which session event occurred
}
    66  
// fabricEndpoint implements FabricEndpoint by bridging a STOMP server to an
// EventBus. Client SENDs under the app-request prefixes become bus requests,
// and client subscriptions become bus channel listeners.
type fabricEndpoint struct {
	server       stompserver.StompServer    // STOMP server accepting client connections
	bus          EventBus                   // bus that client traffic is bridged to
	config       EndpointConfig             // config with prefixes normalized by newFabricEndpoint
	chanLock     sync.RWMutex               // guards chanMappings (only the write lock is used)
	chanMappings map[string]*channelMapping // bus channel name -> subscription bridge state
}
    74  
    75  func addPrefixIfNotEmpty(s string, prefix string) string {
    76  	if s != "" && !strings.HasSuffix(s, prefix) {
    77  		return s + prefix
    78  	}
    79  	return s
    80  }
    81  
    82  func newFabricEndpoint(bus EventBus,
    83  	conListener stompserver.RawConnectionListener, config EndpointConfig) FabricEndpoint {
    84  
    85  	config.TopicPrefix = addPrefixIfNotEmpty(config.TopicPrefix, "/")
    86  	config.AppRequestPrefix = addPrefixIfNotEmpty(config.AppRequestPrefix, "/")
    87  	config.AppRequestQueuePrefix = addPrefixIfNotEmpty(config.AppRequestQueuePrefix, "/")
    88  	config.UserQueuePrefix = addPrefixIfNotEmpty(config.UserQueuePrefix, "/")
    89  
    90  	stompConf := stompserver.NewStompConfig(config.Heartbeat,
    91  		[]string{config.AppRequestPrefix, config.AppRequestQueuePrefix})
    92  
    93  	fabricEndpoint := &fabricEndpoint{
    94  		server:       stompserver.NewStompServer(conListener, stompConf),
    95  		config:       config,
    96  		bus:          bus,
    97  		chanMappings: make(map[string]*channelMapping),
    98  	}
    99  
   100  	fabricEndpoint.initHandlers()
   101  	return fabricEndpoint
   102  }
   103  
   104  func (fe *fabricEndpoint) Start() {
   105  	fe.server.SetConnectionEventCallback(stompserver.ConnectionStarting, func(connEvent *stompserver.ConnEvent) {
   106  		busInstance.SendResponseMessage(STOMP_SESSION_NOTIFY_CHANNEL, &StompSessionEvent{
   107  			Id:        connEvent.ConnId,
   108  			EventType: stompserver.ConnectionStarting,
   109  		}, nil)
   110  	})
   111  	fe.server.SetConnectionEventCallback(stompserver.ConnectionClosed, func(connEvent *stompserver.ConnEvent) {
   112  		busInstance.SendResponseMessage(STOMP_SESSION_NOTIFY_CHANNEL, &StompSessionEvent{
   113  			Id:        connEvent.ConnId,
   114  			EventType: stompserver.ConnectionClosed,
   115  		}, nil)
   116  	})
   117  	fe.server.SetConnectionEventCallback(stompserver.UnsubscribeFromTopic, func(connEvent *stompserver.ConnEvent) {
   118  		busInstance.SendResponseMessage(STOMP_SESSION_NOTIFY_CHANNEL, &StompSessionEvent{
   119  			Id:        connEvent.ConnId,
   120  			EventType: stompserver.UnsubscribeFromTopic,
   121  		}, nil)
   122  	})
   123  	fe.server.Start()
   124  }
   125  
   126  func (fe *fabricEndpoint) Stop() {
   127  	fe.server.Stop()
   128  }
   129  
   130  func (fe *fabricEndpoint) initHandlers() {
   131  	fe.server.OnApplicationRequest(fe.bridgeMessage)
   132  	fe.server.OnSubscribeEvent(fe.addSubscription)
   133  	fe.server.OnUnsubscribeEvent(fe.removeSubscription)
   134  }
   135  
   136  func (fe *fabricEndpoint) addSubscription(
   137  	conId string, subId string, destination string, frame *frame.Frame) {
   138  
   139  	channelName, ok := fe.getChannelNameFromSubscription(destination)
   140  	if !ok {
   141  		return
   142  	}
   143  
   144  	// if destination is a protected channel do not establish a subscription
   145  	// (we don't want any clients to be sending messages to internal channels)
   146  	if isProtectedDestination(channelName) {
   147  		return
   148  	}
   149  
   150  	fe.chanLock.Lock()
   151  	defer fe.chanLock.Unlock()
   152  
   153  	chanMap, ok := fe.chanMappings[channelName]
   154  	if !ok {
   155  		messageHandler, err := fe.bus.ListenStream(channelName)
   156  		var autoCreated = false
   157  		if messageHandler == nil || err != nil {
   158  			fe.bus.GetChannelManager().CreateChannel(channelName)
   159  			messageHandler, err = fe.bus.ListenStream(channelName)
   160  			if messageHandler == nil || err != nil {
   161  				log.Warn("Unable to auto-create channel for destination: %s", destination)
   162  				return
   163  			}
   164  			autoCreated = true
   165  		}
   166  		messageHandler.Handle(
   167  			func(message *model.Message) {
   168  				data, err := marshalMessagePayload(message)
   169  				if err == nil {
   170  					resp, ok := convertPayloadToResponseObj(message)
   171  					if ok && resp != nil && resp.BrokerDestination != nil {
   172  						fe.server.SendMessageToClient(
   173  							resp.BrokerDestination.ConnectionId,
   174  							resp.BrokerDestination.Destination,
   175  							data)
   176  					} else {
   177  						fe.server.SendMessage(fe.config.TopicPrefix+channelName, data)
   178  					}
   179  				}
   180  			},
   181  			func(e error) {
   182  				fe.server.SendMessage(destination, []byte(e.Error()))
   183  			})
   184  
   185  		chanMap = &channelMapping{
   186  			subs:        make(map[string]bool),
   187  			handler:     messageHandler,
   188  			autoCreated: autoCreated,
   189  		}
   190  
   191  		fe.chanMappings[channelName] = chanMap
   192  	}
   193  	chanMap.subs[conId+"#"+subId] = true
   194  	fe.bus.SendMonitorEvent(FabricEndpointSubscribeEvt, channelName, nil)
   195  }
   196  
   197  func convertPayloadToResponseObj(message *model.Message) (*model.Response, bool) {
   198  	var resp model.Response
   199  	var ok bool
   200  
   201  	resp, ok = message.Payload.(model.Response)
   202  	if ok {
   203  		return &resp, true
   204  	}
   205  
   206  	var respPtr *model.Response
   207  	respPtr, ok = message.Payload.(*model.Response)
   208  	if ok {
   209  		return respPtr, true
   210  	}
   211  
   212  	return nil, false
   213  }
   214  
   215  func marshalMessagePayload(message *model.Message) ([]byte, error) {
   216  	// don't marshal string and []byte payloads
   217  	stringPayload, ok := message.Payload.(string)
   218  	if ok {
   219  		return []byte(stringPayload), nil
   220  	}
   221  	bytePayload, ok := message.Payload.([]byte)
   222  	if ok {
   223  		return bytePayload, nil
   224  	}
   225  	// encode the message payload as JSON
   226  	return json.Marshal(message.Payload)
   227  }
   228  
   229  func (fe *fabricEndpoint) removeSubscription(conId string, subId string, destination string) {
   230  
   231  	channelName, ok := fe.getChannelNameFromSubscription(destination)
   232  	if !ok {
   233  		return
   234  	}
   235  
   236  	fe.chanLock.Lock()
   237  	defer fe.chanLock.Unlock()
   238  
   239  	chanMap, ok := fe.chanMappings[channelName]
   240  	if ok {
   241  		mappingId := conId + "#" + subId
   242  		if chanMap.subs[mappingId] {
   243  			delete(chanMap.subs, mappingId)
   244  			if len(chanMap.subs) == 0 {
   245  				// if this was the last subscription to the channel,
   246  				// close the message handler and remove the channel mapping
   247  				chanMap.handler.Close()
   248  				delete(fe.chanMappings, channelName)
   249  				if chanMap.autoCreated {
   250  					fe.bus.GetChannelManager().DestroyChannel(channelName)
   251  				}
   252  			}
   253  			fe.bus.SendMonitorEvent(FabricEndpointUnsubscribeEvt, channelName, nil)
   254  		}
   255  	}
   256  }
   257  
   258  func (fe *fabricEndpoint) bridgeMessage(destination string, message []byte, connectionId string) {
   259  	var channelName string
   260  	isPrivateRequest := false
   261  
   262  	if fe.config.AppRequestQueuePrefix != "" && strings.HasPrefix(destination, fe.config.AppRequestQueuePrefix) {
   263  		channelName = destination[len(fe.config.AppRequestQueuePrefix):]
   264  		isPrivateRequest = true
   265  	} else if fe.config.AppRequestPrefix != "" && strings.HasPrefix(destination, fe.config.AppRequestPrefix) {
   266  		channelName = destination[len(fe.config.AppRequestPrefix):]
   267  	} else {
   268  		return
   269  	}
   270  
   271  	var req model.Request
   272  	err := json.Unmarshal(message, &req)
   273  	if err != nil {
   274  		log.Warn("Failed to deserialize request for channel %s", channelName)
   275  		return
   276  	}
   277  
   278  	if isPrivateRequest {
   279  		req.BrokerDestination = &model.BrokerDestinationConfig{
   280  			Destination:  fe.config.UserQueuePrefix + channelName,
   281  			ConnectionId: connectionId,
   282  		}
   283  	}
   284  
   285  	fe.bus.SendRequestMessage(channelName, &req, nil)
   286  }
   287  
   288  func (fe *fabricEndpoint) getChannelNameFromSubscription(destination string) (channelName string, ok bool) {
   289  	if strings.HasPrefix(destination, fe.config.TopicPrefix) {
   290  		return destination[len(fe.config.TopicPrefix):], true
   291  	}
   292  
   293  	if fe.config.UserQueuePrefix != "" && strings.HasPrefix(destination, fe.config.UserQueuePrefix) {
   294  		return destination[len(fe.config.UserQueuePrefix):], true
   295  	}
   296  	return "", false
   297  }
   298  
   299  // isProtectedDestination checks if the destination is protected. this utility function is used to
   300  // prevent messages being from clients to the protected destinations. such examples would be
   301  // internal bus channels prefixed with _transportInternal/
   302  func isProtectedDestination(destination string) bool {
   303  	return strings.HasPrefix(destination, TRANSPORT_INTERNAL_CHANNEL_PREFIX)
   304  }