github.com/hyperledger/aries-framework-go@v0.3.2/pkg/didcomm/transport/ws/inbound_test.go (about)

     1  /*
     2  Copyright SecureKey Technologies Inc. All Rights Reserved.
     3  
     4  SPDX-License-Identifier: Apache-2.0
     5  */
     6  
     7  package ws
     8  
     9  import (
    10  	"context"
    11  	"encoding/json"
    12  	"errors"
    13  	"strconv"
    14  	"testing"
    15  	"time"
    16  
    17  	"github.com/stretchr/testify/require"
    18  	"nhooyr.io/websocket"
    19  
    20  	cryptoapi "github.com/hyperledger/aries-framework-go/pkg/crypto"
    21  	"github.com/hyperledger/aries-framework-go/pkg/didcomm/protocol/decorator"
    22  	"github.com/hyperledger/aries-framework-go/pkg/didcomm/transport"
    23  	"github.com/hyperledger/aries-framework-go/pkg/internal/test/transportutil"
    24  	mockpackager "github.com/hyperledger/aries-framework-go/pkg/mock/didcomm/packager"
    25  )
    26  
    27  const defaultReadLimit = 32768
    28  
    29  func TestInboundTransport(t *testing.T) {
    30  	t.Run("test inbound transport - with host/port", func(t *testing.T) {
    31  		port := ":" + strconv.Itoa(transportutil.GetRandomPort(5))
    32  		externalAddr := "http://example.com" + port
    33  		inbound, err := NewInbound("localhost"+port, externalAddr, "", "")
    34  		require.NoError(t, err)
    35  		require.Equal(t, externalAddr, inbound.Endpoint())
    36  	})
    37  
    38  	t.Run("test inbound transport - with host/port, no external address", func(t *testing.T) {
    39  		internalAddr := "example.com" + ":" + strconv.Itoa(transportutil.GetRandomPort(5))
    40  		inbound, err := NewInbound(internalAddr, "", "", "")
    41  		require.NoError(t, err)
    42  		require.Equal(t, internalAddr, inbound.Endpoint())
    43  	})
    44  
    45  	t.Run("test inbound transport - without host/port", func(t *testing.T) {
    46  		inbound, err := NewInbound(":"+strconv.Itoa(transportutil.GetRandomPort(5)), "", "", "")
    47  		require.NoError(t, err)
    48  		require.NotEmpty(t, inbound)
    49  		mockPackager := &mockpackager.Packager{UnpackValue: &transport.Envelope{Message: []byte("data")}}
    50  		err = inbound.Start(&mockProvider{packagerValue: mockPackager})
    51  		require.NoError(t, err)
    52  
    53  		err = inbound.Stop()
    54  		require.NoError(t, err)
    55  	})
    56  
    57  	t.Run("test inbound transport - nil context", func(t *testing.T) {
    58  		inbound, err := NewInbound(":"+strconv.Itoa(transportutil.GetRandomPort(5)), "", "", "")
    59  		require.NoError(t, err)
    60  		require.NotEmpty(t, inbound)
    61  
    62  		err = inbound.Start(nil)
    63  		require.Error(t, err)
    64  	})
    65  
    66  	t.Run("test inbound transport - invalid TLS", func(t *testing.T) {
    67  		svc, err := NewInbound(":0", "", "invalid", "invalid")
    68  		require.NoError(t, err)
    69  
    70  		err = svc.listenAndServe()
    71  		require.Error(t, err)
    72  		require.Contains(t, err.Error(), "open invalid: no such file or directory")
    73  	})
    74  
    75  	t.Run("test inbound transport - invalid port number", func(t *testing.T) {
    76  		_, err := NewInbound("", "", "", "")
    77  		require.Error(t, err)
    78  		require.Contains(t, err.Error(), "websocket address is mandatory")
    79  	})
    80  }
    81  
    82  func TestInboundDataProcessing(t *testing.T) {
    83  	t.Run("test inbound transport - multiple invocation with same client", func(t *testing.T) {
    84  		port := ":" + strconv.Itoa(transportutil.GetRandomPort(5))
    85  
    86  		// initiate inbound with port
    87  		inbound, err := NewInbound(port, "", "", "")
    88  		require.NoError(t, err)
    89  		require.NotEmpty(t, inbound)
    90  
    91  		// start server
    92  		mockPackager := &mockpackager.Packager{UnpackValue: &transport.Envelope{Message: []byte("valid-data")}}
    93  		err = inbound.Start(&mockProvider{packagerValue: mockPackager})
    94  		require.NoError(t, err)
    95  
    96  		// create ws client
    97  		client, cleanup := websocketClient(t, port)
    98  		defer cleanup()
    99  
   100  		ctx := context.Background()
   101  
   102  		for i := 1; i <= 5; i++ {
   103  			err = client.Write(ctx, websocket.MessageText, []byte("random"))
   104  			require.NoError(t, err)
   105  		}
   106  	})
   107  
   108  	t.Run("test inbound transport - unpacking error", func(t *testing.T) {
   109  		port := ":" + strconv.Itoa(transportutil.GetRandomPort(5))
   110  
   111  		// initiate inbound with port
   112  		inbound, err := NewInbound(port, "", "", "")
   113  		require.NoError(t, err)
   114  		require.NotEmpty(t, inbound)
   115  
   116  		// start server
   117  		mockPackager := &mockpackager.Packager{UnpackErr: errors.New("error unpacking")}
   118  		err = inbound.Start(&mockProvider{packagerValue: mockPackager})
   119  		require.NoError(t, err)
   120  
   121  		// create ws client
   122  		client, cleanup := websocketClient(t, port)
   123  		defer cleanup()
   124  
   125  		ctx := context.Background()
   126  
   127  		err = client.Write(ctx, websocket.MessageText, []byte(""))
   128  		require.NoError(t, err)
   129  	})
   130  
   131  	t.Run("test inbound transport - message handler error", func(t *testing.T) {
   132  		port := ":" + strconv.Itoa(transportutil.GetRandomPort(5))
   133  
   134  		// initiate inbound with port
   135  		inbound, err := NewInbound(port, "", "", "")
   136  		require.NoError(t, err)
   137  		require.NotEmpty(t, inbound)
   138  
   139  		// start server
   140  		mockPackager := &mockpackager.Packager{UnpackValue: &transport.Envelope{Message: []byte("invalid-data")}}
   141  		err = inbound.Start(&mockProvider{packagerValue: mockPackager})
   142  		require.NoError(t, err)
   143  
   144  		// create ws client
   145  		client, cleanup := websocketClient(t, port)
   146  		defer cleanup()
   147  
   148  		ctx := context.Background()
   149  
   150  		err = client.Write(ctx, websocket.MessageText, []byte(""))
   151  		require.NoError(t, err)
   152  	})
   153  
   154  	t.Run("test inbound transport - client close error", func(t *testing.T) {
   155  		port := ":" + strconv.Itoa(transportutil.GetRandomPort(5))
   156  
   157  		// initiate inbound with port
   158  		inbound, err := NewInbound(port, "", "", "")
   159  		require.NoError(t, err)
   160  		require.NotEmpty(t, inbound)
   161  
   162  		// start server
   163  		mockPackager := &mockpackager.Packager{}
   164  		err = inbound.Start(&mockProvider{packagerValue: mockPackager})
   165  		require.NoError(t, err)
   166  
   167  		// create ws client
   168  		client, _ := websocketClient(t, port)
   169  
   170  		err = client.Close(websocket.StatusInternalError, "abnormal closure")
   171  		require.NoError(t, err)
   172  	})
   173  
   174  	t.Run("test inbound transport - custom read limit for a single message", func(t *testing.T) {
   175  		port := ":" + strconv.Itoa(transportutil.GetRandomPort(5))
   176  
   177  		// initiate inbound with a port and a custom read limit
   178  		inbound, err := NewInbound(port, "", "", "", WithInboundReadLimit(defaultReadLimit+1))
   179  		require.NoError(t, err)
   180  		require.NotEmpty(t, inbound)
   181  
   182  		trans := &decorator.Transport{
   183  			ReturnRoute: &decorator.ReturnRoute{
   184  				Value: decorator.TransportReturnRouteNone,
   185  			},
   186  		}
   187  
   188  		unpackMsg, err := json.Marshal(trans)
   189  		require.NoError(t, err)
   190  
   191  		fromKey, err := json.Marshal(&cryptoapi.PublicKey{KID: "keyID"})
   192  		require.NoError(t, err)
   193  
   194  		mockPackager := &mockpackager.Packager{
   195  			UnpackValue: &transport.Envelope{
   196  				Message: unpackMsg,
   197  				FromKey: fromKey,
   198  			},
   199  		}
   200  
   201  		done := make(chan struct{})
   202  
   203  		// start server
   204  		err = inbound.Start(&mockTransportProvider{
   205  			packagerValue: mockPackager,
   206  			executeInbound: func(envelope *transport.Envelope) error {
   207  				done <- struct{}{}
   208  				return nil
   209  			},
   210  		})
   211  		require.NoError(t, err)
   212  
   213  		// create ws client
   214  		client, cleanup := websocketClient(t, port)
   215  		defer cleanup()
   216  
   217  		msg := make([]byte, defaultReadLimit+1)
   218  
   219  		err = client.Write(context.Background(), websocket.MessageText, msg)
   220  		require.NoError(t, err)
   221  
   222  		select {
   223  		case <-done:
   224  		case <-time.After(3 * time.Second):
   225  			require.Fail(t, "inbound message handler was not called within given timeout")
   226  		}
   227  	})
   228  }