github.com/vmware/transport-go@v1.3.4/bridge/broker_connector_tls_test.go (about)

     1  // Copyright 2019-2020 VMware, Inc.
     2  // SPDX-License-Identifier: BSD-2-Clause
     3  
     4  package bridge
     5  
import (
	"crypto/tls"
	"fmt"
	"log"
	"net"
	"net/http"
	"net/http/httptest"
	"net/url"
	"testing"
	"time"

	"github.com/go-stomp/stomp/v3/frame"
	"github.com/go-stomp/stomp/v3/server"
	"github.com/stretchr/testify/assert"
)
    19  
    20  var webSocketURLChanTLS = make(chan string)
    21  var websocketURLTLS string
    22  
    23  //var srv Server
    24  var testTLS = &tls.Config{
    25  	InsecureSkipVerify: true,
    26  	MinVersion:         tls.VersionTLS12,
    27  	CipherSuites: []uint16{
    28  		tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384,
    29  		tls.TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA,
    30  		tls.TLS_RSA_WITH_AES_256_GCM_SHA384,
    31  		tls.TLS_RSA_WITH_AES_256_CBC_SHA,
    32  	},
    33  }
    34  
    35  func runWebSocketEndPointTLS() {
    36  	s := httptest.NewUnstartedServer(http.HandlerFunc(websocketHandler))
    37  	s.TLS = testTLS
    38  	s.StartTLS()
    39  	log.Println("WebSocket listening on", s.Listener.Addr().Network(), s.Listener.Addr().String(), "(TLS)")
    40  	httpServer = s
    41  	webSocketURLChanTLS <- s.URL
    42  }
    43  
    44  func runStompBrokerTLS() {
    45  	l, err := net.Listen("tcp", ":51582")
    46  	if err != nil {
    47  		log.Fatalf("failed to listen: %s", err.Error())
    48  	}
    49  	defer func() { l.Close() }()
    50  
    51  	log.Println("TCP listening on", l.Addr().Network(), l.Addr().String(), "(TLS)")
    52  	server.Serve(l)
    53  	tcpServer = l
    54  }
    55  
    56  func init() {
    57  	go runStompBrokerTLS()
    58  	go runWebSocketEndPointTLS()
    59  
    60  	websocketURLTLS = <-webSocketURLChanTLS
    61  }
    62  
    63  func TestBrokerConnector_ConnectBroker_Invalid_TLS_Cert(t *testing.T) {
    64  	url, _ := url.Parse(websocketURLTLS)
    65  	host, port, _ := net.SplitHostPort(url.Host)
    66  	testHost := host + ":" + port
    67  
    68  	brokerConfig := &BrokerConnectorConfig{
    69  		Username: "guest",
    70  		Password: "guest",
    71  		UseWS:    true,
    72  		WebSocketConfig: &WebSocketConfig{
    73  			WSPath:    "/fabric",
    74  			UseTLS:    true,
    75  			TLSConfig: testTLS,
    76  			CertFile:  "nothing",
    77  			KeyFile:   "nothing",
    78  		},
    79  		ServerAddr: testHost,
    80  	}
    81  	bc := NewBrokerConnector()
    82  	_, err := bc.Connect(brokerConfig, true)
    83  
    84  	assert.NotNil(t, err)
    85  }
    86  
    87  func TestBrokerConnector_ConnectBroker_TLS(t *testing.T) {
    88  	url, _ := url.Parse(websocketURLTLS)
    89  	host, port, _ := net.SplitHostPort(url.Host)
    90  	testHost := host + ":" + port
    91  
    92  	tt := []struct {
    93  		test   string
    94  		config *BrokerConnectorConfig
    95  	}{
    96  		{
    97  			"Connect via websocket with TLS",
    98  			&BrokerConnectorConfig{
    99  				Username: "guest",
   100  				Password: "guest",
   101  				WebSocketConfig: &WebSocketConfig{
   102  					WSPath:    "/",
   103  					UseTLS:    true,
   104  					CertFile:  "test_server.crt",
   105  					KeyFile:   "test_server.key",
   106  					TLSConfig: testTLS,
   107  				},
   108  				UseWS: true,
   109  				STOMPHeader: map[string]string{
   110  					"access-token": "test",
   111  				},
   112  				ServerAddr: testHost},
   113  		},
   114  	}
   115  
   116  	for _, tc := range tt {
   117  		t.Run(tc.test, func(t *testing.T) {
   118  
   119  			// connect
   120  			bc := NewBrokerConnector()
   121  			c, err := bc.Connect(tc.config, true)
   122  
   123  			if err != nil {
   124  				fmt.Printf("unable to connect, error: %e", err)
   125  			}
   126  
   127  			assert.NotNil(t, c)
   128  			assert.Nil(t, err)
   129  			if tc.config.UseWS {
   130  				assert.NotNil(t, c.(*connection).wsConn)
   131  			}
   132  			if !tc.config.UseWS {
   133  				assert.NotNil(t, c.(*connection).conn)
   134  			}
   135  
   136  			m, _ := c.Subscribe("/topic/test-topic")
   137  			go func() {
   138  				err = c.SendJSONMessage("/topic/test-topic", []byte("{}"), func(frame *frame.Frame) error {
   139  					frame.Header.Set("access-token", "test")
   140  					return nil
   141  				})
   142  				assert.Nil(t, err)
   143  			}()
   144  			msg := <-m.GetMsgChannel()
   145  			b := msg.Payload.([]byte)
   146  			assert.EqualValues(t, "happy baby melody!", string(b))
   147  
   148  			// disconnect
   149  			err = c.Disconnect()
   150  			assert.Nil(t, err)
   151  			if tc.config.UseWS {
   152  				assert.Nil(t, c.(*connection).wsConn)
   153  			}
   154  			if !tc.config.UseWS {
   155  				assert.Nil(t, c.(*connection).conn)
   156  			}
   157  		})
   158  	}
   159  }