github.com/hyperledger/aries-framework-go@v0.3.2/pkg/didcomm/transport/http/inbound_test.go (about)

     1  /*
     2  Copyright SecureKey Technologies Inc. All Rights Reserved.
     3  
     4  SPDX-License-Identifier: Apache-2.0
     5  */
     6  
     7  package http
     8  
     9  import (
    10  	"bytes"
    11  	"crypto/tls"
    12  	"crypto/x509"
    13  	"errors"
    14  	"fmt"
    15  	"io/ioutil"
    16  	"net"
    17  	"net/http"
    18  	"testing"
    19  	"time"
    20  
    21  	"github.com/stretchr/testify/require"
    22  
    23  	"github.com/hyperledger/aries-framework-go/pkg/didcomm/transport"
    24  	mockpackager "github.com/hyperledger/aries-framework-go/pkg/mock/didcomm/packager"
    25  )
    26  
    27  type mockProvider struct {
    28  	packagerValue transport.Packager
    29  }
    30  
    31  func (p *mockProvider) InboundMessageHandler() transport.InboundMessageHandler {
    32  	return func(envelope *transport.Envelope) error {
    33  		logger.Debugf("message received is %s", envelope.Message)
    34  		return nil
    35  	}
    36  }
    37  
    38  func (p *mockProvider) Packager() transport.Packager {
    39  	return p.packagerValue
    40  }
    41  
    42  func (p *mockProvider) AriesFrameworkID() string {
    43  	return "aries-framework-instance-1"
    44  }
    45  
    46  func TestInboundHandler(t *testing.T) {
    47  	// test inboundHandler with empty args should fail
    48  	inHandler, err := NewInboundHandler(nil)
    49  	require.Error(t, err)
    50  	require.Nil(t, inHandler)
    51  
    52  	mockPackager := &mockpackager.Packager{UnpackValue: &transport.Envelope{Message: []byte("data")}}
    53  
    54  	// now create a valid inboundHandler to continue testing..
    55  	inHandler, err = NewInboundHandler(&mockProvider{packagerValue: mockPackager})
    56  	require.NoError(t, err)
    57  	require.NotNil(t, inHandler)
    58  
    59  	server := startMockServer(inHandler)
    60  	port := getServerPort(server)
    61  	serverURL := fmt.Sprintf("https://localhost:%d", port)
    62  
    63  	defer func() {
    64  		e := server.Close()
    65  		if e != nil {
    66  			t.Fatalf("Failed to stop server: %s", e)
    67  		}
    68  	}()
    69  
    70  	// build a mock cert pool
    71  	cp := x509.NewCertPool()
    72  	err = addCertsToCertPool(cp)
    73  	require.NoError(t, err)
    74  
    75  	// build a tls.Config instance to be used by the outbound transport
    76  	tlsConfig := &tls.Config{ //nolint:gosec
    77  		RootCAs:      cp,
    78  		Certificates: nil,
    79  	}
    80  
    81  	// create an http client to communicate with the server that has our inbound handlers set above
    82  	client := http.Client{
    83  		Timeout: clientTimeout,
    84  		Transport: &http.Transport{
    85  			TLSClientConfig: tlsConfig,
    86  		},
    87  	}
    88  
    89  	// test http.Get should should fail (not supported)
    90  	rs, err := client.Get(serverURL + "/")
    91  	require.NoError(t, err)
    92  	err = rs.Body.Close()
    93  	require.NoError(t, err)
    94  	require.Equal(t, http.StatusMethodNotAllowed, rs.StatusCode)
    95  
    96  	// test accepted HTTP method (POST) but with bad content type
    97  	rs, err = client.Post(serverURL+"/", "bad-content-type", bytes.NewBuffer([]byte("Hello World")))
    98  	require.NoError(t, err)
    99  	err = rs.Body.Close()
   100  	require.NoError(t, err)
   101  	require.Equal(t, http.StatusUnsupportedMediaType, rs.StatusCode)
   102  
   103  	contentTypes := []string{commContentType, commContentTypeLegacy}
   104  	data := "success"
   105  
   106  	for _, contentType := range contentTypes {
   107  		// test with nil body ..
   108  		resp, err := client.Post(serverURL+"/", contentType, nil)
   109  		require.NoError(t, err)
   110  		require.NoError(t, err)
   111  		require.Equal(t, http.StatusBadRequest, resp.StatusCode)
   112  		require.NoError(t, resp.Body.Close())
   113  
   114  		// test successful POST requests
   115  		resp, err = client.Post(serverURL+"/", contentType, bytes.NewBuffer([]byte(data)))
   116  		require.NoError(t, err)
   117  		err = resp.Body.Close()
   118  		require.NoError(t, err)
   119  		require.NotNil(t, resp)
   120  		require.Equal(t, http.StatusAccepted, resp.StatusCode)
   121  	}
   122  
   123  	// test unpack error
   124  	mockPackager.UnpackValue = nil
   125  	mockPackager.UnpackErr = fmt.Errorf("unpack error")
   126  
   127  	for _, contentType := range contentTypes {
   128  		resp, err := client.Post(serverURL+"/", contentType, bytes.NewBuffer([]byte(data)))
   129  		require.NoError(t, err)
   130  		require.NotNil(t, resp)
   131  		require.Equal(t, http.StatusInternalServerError, resp.StatusCode)
   132  		body, err := ioutil.ReadAll(resp.Body)
   133  		require.NoError(t, err)
   134  		require.Contains(t, string(body), "failed to unpack msg")
   135  		require.NoError(t, resp.Body.Close())
   136  	}
   137  }
   138  
   139  func TestInboundTransport(t *testing.T) {
   140  	t.Run("test inbound transport - with host/port", func(t *testing.T) {
   141  		port := "26601"
   142  		externalAddr := "http://example.com:" + port
   143  		inbound, err := NewInbound("localhost:"+port, externalAddr, "", "")
   144  		require.NoError(t, err)
   145  		require.Equal(t, externalAddr, inbound.Endpoint())
   146  	})
   147  
   148  	t.Run("test inbound transport - with host/port, no external address", func(t *testing.T) {
   149  		internalAddr := "example.com:26602"
   150  		inbound, err := NewInbound(internalAddr, "", "", "")
   151  		require.NoError(t, err)
   152  		require.Equal(t, internalAddr, inbound.Endpoint())
   153  	})
   154  
   155  	t.Run("test inbound transport - without host/port", func(t *testing.T) {
   156  		inbound, err := NewInbound(":26603", "", "", "")
   157  		require.NoError(t, err)
   158  		require.NotEmpty(t, inbound)
   159  		mockPackager := &mockpackager.Packager{UnpackValue: &transport.Envelope{Message: []byte("data")}}
   160  		err = inbound.Start(&mockProvider{packagerValue: mockPackager})
   161  		require.NoError(t, err)
   162  
   163  		err = inbound.Stop()
   164  		require.NoError(t, err)
   165  	})
   166  
   167  	t.Run("test inbound transport - nil context", func(t *testing.T) {
   168  		inbound, err := NewInbound(":26604", "", "", "")
   169  		require.NoError(t, err)
   170  		require.NotEmpty(t, inbound)
   171  
   172  		err = inbound.Start(nil)
   173  		require.Error(t, err)
   174  	})
   175  
   176  	t.Run("test inbound transport - invalid port number", func(t *testing.T) {
   177  		_, err := NewInbound("", "", "", "")
   178  		require.Error(t, err)
   179  		require.Contains(t, err.Error(), "http address is mandatory")
   180  	})
   181  
   182  	t.Run("test inbound transport - invalid TLS", func(t *testing.T) {
   183  		svc, err := NewInbound(":0", "", "invalid", "invalid")
   184  		require.NoError(t, err)
   185  
   186  		err = svc.listenAndServe()
   187  		require.Error(t, err)
   188  		require.Contains(t, err.Error(), "open invalid: no such file or directory")
   189  	})
   190  
   191  	t.Run("test inbound transport - invoke endpoint", func(t *testing.T) {
   192  		// initiate inbound with port
   193  		inbound, err := NewInbound(":26605", "", "", "")
   194  		require.NoError(t, err)
   195  		require.NotEmpty(t, inbound)
   196  
   197  		// start server
   198  		mockPackager := &mockpackager.Packager{UnpackValue: &transport.Envelope{Message: []byte("data")}}
   199  		err = inbound.Start(&mockProvider{packagerValue: mockPackager})
   200  		require.NoError(t, err)
   201  		require.NoError(t, listenFor("localhost:26605", time.Second))
   202  
   203  		contentTypes := []string{commContentType, commContentTypeLegacy}
   204  		client := http.Client{}
   205  
   206  		for _, contentType := range contentTypes {
   207  			// invoke a endpoint
   208  			var resp *http.Response
   209  			resp, err = client.Post("http://localhost:26605", contentType, bytes.NewBuffer([]byte("success")))
   210  			require.NoError(t, err)
   211  			require.Equal(t, http.StatusAccepted, resp.StatusCode)
   212  			require.NotNil(t, resp)
   213  
   214  			err = resp.Body.Close()
   215  			require.NoError(t, err)
   216  		}
   217  
   218  		// stop server
   219  		err = inbound.Stop()
   220  		require.NoError(t, err)
   221  
   222  		// try after server stop
   223  		for _, contentType := range contentTypes {
   224  			_, err = client.Post("http://localhost:26605", contentType, bytes.NewBuffer([]byte("success"))) // nolint
   225  			require.Error(t, err)
   226  		}
   227  	})
   228  }
   229  
   230  func listenFor(host string, d time.Duration) error {
   231  	timeout := time.After(d)
   232  
   233  	for {
   234  		select {
   235  		case <-timeout:
   236  			return errors.New("timeout: server is not available")
   237  		default:
   238  			conn, err := net.Dial("tcp", host)
   239  			if err != nil {
   240  				continue
   241  			}
   242  
   243  			return conn.Close()
   244  		}
   245  	}
   246  }