github.com/hyperledger/aries-framework-go@v0.3.2/pkg/didcomm/transport/ws/support_test.go (about)

     1  /*
     2  Copyright SecureKey Technologies Inc. All Rights Reserved.
     3  
     4  SPDX-License-Identifier: Apache-2.0
     5  */
     6  
     7  package ws
     8  
     9  import (
    10  	"context"
    11  	"encoding/json"
    12  	"errors"
    13  	"net/http"
    14  	"net/url"
    15  	"strconv"
    16  	"testing"
    17  	"time"
    18  
    19  	"github.com/btcsuite/btcutil/base58"
    20  	"github.com/google/uuid"
    21  	"github.com/stretchr/testify/require"
    22  	"nhooyr.io/websocket"
    23  
    24  	"github.com/hyperledger/aries-framework-go/pkg/common/model"
    25  	"github.com/hyperledger/aries-framework-go/pkg/didcomm/common/service"
    26  	"github.com/hyperledger/aries-framework-go/pkg/didcomm/protocol/decorator"
    27  	"github.com/hyperledger/aries-framework-go/pkg/didcomm/transport"
    28  	"github.com/hyperledger/aries-framework-go/pkg/internal/test/transportutil"
    29  )
    30  
    31  type mockProvider struct {
    32  	packagerValue transport.Packager
    33  }
    34  
    35  func (p *mockProvider) InboundMessageHandler() transport.InboundMessageHandler {
    36  	return func(envelope *transport.Envelope) error {
    37  		logger.Infof("message received is %s", string(envelope.Message))
    38  
    39  		if string(envelope.Message) == "invalid-data" {
    40  			return errors.New("error")
    41  		}
    42  
    43  		return nil
    44  	}
    45  }
    46  
    47  func (p *mockProvider) Packager() transport.Packager {
    48  	return p.packagerValue
    49  }
    50  
    51  func (p *mockProvider) AriesFrameworkID() string {
    52  	return uuid.New().String()
    53  }
    54  
    55  func websocketClient(t *testing.T, port string) (*websocket.Conn, func()) {
    56  	require.NoError(t, transportutil.VerifyListener("localhost"+port, time.Second))
    57  
    58  	u := url.URL{Scheme: "ws", Host: "localhost" + port, Path: ""}
    59  	c, resp, err := websocket.Dial(context.Background(), u.String(), nil) //nolint:bodyclose
    60  	require.NoError(t, err)
    61  	require.Equal(t, http.StatusSwitchingProtocols, resp.StatusCode)
    62  
    63  	return c, func() {
    64  		require.NoError(t, c.Close(websocket.StatusNormalClosure, "closing the connection"))
    65  	}
    66  }
    67  
    68  func prepareDestination(endPoint string) *service.Destination {
    69  	return &service.Destination{
    70  		ServiceEndpoint: model.NewDIDCommV1Endpoint(endPoint),
    71  	}
    72  }
    73  
    74  func prepareDestinationWithTransport(endPoint, returnRoute string,
    75  	recipientKeys, routingKeys []string) *service.Destination {
    76  	return &service.Destination{
    77  		ServiceEndpoint:      model.NewDIDCommV1Endpoint(endPoint),
    78  		RoutingKeys:          routingKeys,
    79  		RecipientKeys:        recipientKeys,
    80  		TransportReturnRoute: returnRoute,
    81  	}
    82  }
    83  
    84  func createTransportDecRequest(t *testing.T, transportReturnRoute string) []byte {
    85  	req := &decorator.Thread{
    86  		ID: uuid.New().String(),
    87  	}
    88  
    89  	outboundReq := struct {
    90  		*decorator.Transport
    91  		*decorator.Thread
    92  	}{
    93  		&decorator.Transport{ReturnRoute: &decorator.ReturnRoute{Value: transportReturnRoute}},
    94  		req,
    95  	}
    96  	request, err := json.Marshal(outboundReq)
    97  	require.NoError(t, err)
    98  	require.NotNil(t, request)
    99  
   100  	return request
   101  }
   102  
   103  func startWebSocketServer(t *testing.T, handlerFunc func(*testing.T, http.ResponseWriter, *http.Request)) string {
   104  	addr := "localhost:" + strconv.Itoa(transportutil.GetRandomPort(5))
   105  
   106  	server := &http.Server{Addr: addr}
   107  	server.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   108  		handlerFunc(t, w, r)
   109  	})
   110  
   111  	go func() {
   112  		require.NoError(t, server.ListenAndServe())
   113  	}()
   114  
   115  	require.NoError(t, transportutil.VerifyListener(addr, time.Second))
   116  
   117  	return addr
   118  }
   119  
   120  func echo(t *testing.T, w http.ResponseWriter, r *http.Request) {
   121  	c, err := Accept(w, r)
   122  	require.NoError(t, err)
   123  
   124  	defer func() {
   125  		require.NoError(t, c.Close(websocket.StatusNormalClosure, "closing the connection"))
   126  	}()
   127  
   128  	ctx := context.Background()
   129  
   130  	for {
   131  		mt, message, err := c.Read(ctx)
   132  		if err != nil {
   133  			break
   134  		}
   135  
   136  		logger.Infof("r: %s", message)
   137  
   138  		err = c.Write(ctx, mt, message)
   139  		require.NoError(t, err)
   140  	}
   141  }
   142  
   143  // mockPackager mock packager.
   144  type mockPackager struct {
   145  	verKey string
   146  }
   147  
   148  func (m *mockPackager) PackMessage(e *transport.Envelope) ([]byte, error) {
   149  	return e.Message, nil
   150  }
   151  
   152  func (m *mockPackager) UnpackMessage(encMessage []byte) (*transport.Envelope, error) {
   153  	return &transport.Envelope{Message: encMessage, FromKey: base58.Decode(m.verKey)}, nil
   154  }
   155  
   156  type mockTransportProvider struct {
   157  	packagerValue  transport.Packager
   158  	executeInbound func(envelope *transport.Envelope) error
   159  	frameworkID    string
   160  }
   161  
   162  func (p *mockTransportProvider) InboundMessageHandler() transport.InboundMessageHandler {
   163  	return p.executeInbound
   164  }
   165  
   166  func (p *mockTransportProvider) Packager() transport.Packager {
   167  	return p.packagerValue
   168  }
   169  
   170  func (p *mockTransportProvider) AriesFrameworkID() string {
   171  	if p.frameworkID != "" {
   172  		return p.frameworkID
   173  	}
   174  
   175  	return "framework-instance-1"
   176  }