github.com/hyperledger/aries-framework-go@v0.3.2/pkg/didcomm/transport/ws/outbound.go (about)

     1  /*
     2  Copyright SecureKey Technologies Inc. All Rights Reserved.
     3  
     4  SPDX-License-Identifier: Apache-2.0
     5  */
     6  
     7  package ws
     8  
     9  import (
    10  	"context"
    11  	"fmt"
    12  	"strings"
    13  
    14  	"nhooyr.io/websocket"
    15  
    16  	"github.com/hyperledger/aries-framework-go/pkg/didcomm/common/service"
    17  	"github.com/hyperledger/aries-framework-go/pkg/didcomm/protocol/decorator"
    18  	"github.com/hyperledger/aries-framework-go/pkg/didcomm/transport"
    19  )
    20  
// webSocketScheme is the URL scheme of WebSocket endpoints; Accept matches
// service endpoint URLs against it.
const webSocketScheme = "ws"
    22  
// OutboundClient is the outbound transport that delivers messages over
// WebSocket connections. Start must be called before the client is used,
// because that is where the connection pool and the provider are set.
type OutboundClient struct {
	// pool holds the connections that are kept open and looked up by key;
	// it is obtained from getConnPool in Start.
	pool *connPool
	// prov is the transport provider handed to Start.
	prov transport.Provider
	// readLimit, when greater than zero, is applied as the read limit of
	// every newly dialed connection. Zero leaves the limit untouched.
	readLimit int64
}
    29  
// OutboundClientOpt configures outbound client.
// Options are passed to NewOutbound, which applies them to the new client
// in the order given.
type OutboundClientOpt func(c *OutboundClient)
    32  
    33  // WithOutboundReadLimit sets the custom max number of bytes to read for a single message.
    34  func WithOutboundReadLimit(n int64) OutboundClientOpt {
    35  	return func(c *OutboundClient) {
    36  		c.readLimit = n
    37  	}
    38  }
    39  
    40  // NewOutbound creates a client for Outbound WS transport.
    41  func NewOutbound(opts ...OutboundClientOpt) *OutboundClient {
    42  	c := &OutboundClient{}
    43  
    44  	for _, opt := range opts {
    45  		opt(c)
    46  	}
    47  
    48  	return c
    49  }
    50  
    51  // Start starts the outbound transport.
    52  func (cs *OutboundClient) Start(prov transport.Provider) error {
    53  	cs.pool = getConnPool(prov)
    54  	cs.prov = prov
    55  
    56  	return nil
    57  }
    58  
    59  // Send sends a2a data via WS.
    60  func (cs *OutboundClient) Send(data []byte, destination *service.Destination) (string, error) {
    61  	conn, cleanup, err := cs.getConnection(destination)
    62  	defer cleanup()
    63  
    64  	if err != nil {
    65  		return "", fmt.Errorf("get websocket connection : %w", err)
    66  	}
    67  
    68  	err = conn.Write(context.Background(), websocket.MessageText, data)
    69  	if err != nil {
    70  		logger.Errorf("didcomm failed : transport=ws serviceEndpoint=%s errMsg=%s",
    71  			destination.ServiceEndpoint, err.Error())
    72  
    73  		return "", fmt.Errorf("websocket write message : %w", err)
    74  	}
    75  
    76  	return "", nil
    77  }
    78  
    79  // Accept checks for the url scheme.
    80  func (cs *OutboundClient) Accept(url string) bool {
    81  	return strings.HasPrefix(url, webSocketScheme)
    82  }
    83  
    84  // AcceptRecipient checks if there is a connection for the list of recipient keys.
    85  func (cs *OutboundClient) AcceptRecipient(keys []string) bool {
    86  	return acceptRecipient(cs.pool, keys)
    87  }
    88  
    89  //nolint:gocyclo,funlen
    90  func (cs *OutboundClient) getConnection(destination *service.Destination) (*websocket.Conn, func(), error) {
    91  	var conn *websocket.Conn
    92  
    93  	// get the connection for the routing or recipient keys
    94  	keys := destination.RecipientKeys
    95  	if routingKeys, err := destination.ServiceEndpoint.RoutingKeys(); err == nil && len(routingKeys) != 0 {
    96  		keys = routingKeys
    97  	} else if len(destination.RoutingKeys) != 0 {
    98  		keys = destination.RoutingKeys
    99  	}
   100  
   101  	for _, v := range keys {
   102  		if c := cs.pool.fetch(v); c != nil {
   103  			conn = c
   104  
   105  			break
   106  		}
   107  	}
   108  
   109  	cleanup := func() {}
   110  
   111  	if conn != nil {
   112  		return conn, cleanup, nil
   113  	}
   114  
   115  	var (
   116  		err error
   117  		uri string
   118  	)
   119  
   120  	uri, err = destination.ServiceEndpoint.URI()
   121  	if err != nil {
   122  		return nil, cleanup, fmt.Errorf("unable to send ws outbound request: %w", err)
   123  	}
   124  
   125  	conn, _, err = websocket.Dial(context.Background(), uri, nil)
   126  	if err != nil {
   127  		return nil, cleanup, fmt.Errorf("websocket client : %w", err)
   128  	}
   129  
   130  	if cs.readLimit > 0 {
   131  		conn.SetReadLimit(cs.readLimit)
   132  	}
   133  
   134  	// keep the connection open to listen to the response in case of return route option set
   135  	if destination.TransportReturnRoute == decorator.TransportReturnRouteAll {
   136  		for _, v := range destination.RecipientKeys {
   137  			cs.pool.add(v, conn)
   138  		}
   139  
   140  		go cs.pool.listener(conn, true)
   141  
   142  		return conn, cleanup, nil
   143  	}
   144  
   145  	cleanup = func() {
   146  		err = conn.Close(websocket.StatusNormalClosure, "closing the connection")
   147  		if err != nil && websocket.CloseStatus(err) != websocket.StatusNormalClosure {
   148  			logger.Errorf("failed to close connection: %v", err)
   149  		}
   150  	}
   151  
   152  	return conn, cleanup, nil
   153  }