github.com/hyperledger/aries-framework-go@v0.3.2/pkg/didcomm/transport/ws/inbound.go (about)

     1  /*
     2  Copyright SecureKey Technologies Inc. All Rights Reserved.
     3  
     4  SPDX-License-Identifier: Apache-2.0
     5  */
     6  
     7  package ws
     8  
import (
	"context"
	"errors"
	"fmt"
	"net/http"
	"time"

	"nhooyr.io/websocket"

	"github.com/hyperledger/aries-framework-go/pkg/common/log"
	"github.com/hyperledger/aries-framework-go/pkg/didcomm/transport"
)
    20  
    21  var logger = log.New("aries-framework/ws")
    22  
    23  type inboundOpts struct {
    24  	readLimit int64
    25  }
    26  
    27  // InboundOpt is an inbound ws option.
    28  type InboundOpt func(opts *inboundOpts)
    29  
    30  // WithInboundReadLimit sets the custom max number of bytes to read for a single message.
    31  func WithInboundReadLimit(n int64) InboundOpt {
    32  	return func(opts *inboundOpts) {
    33  		opts.readLimit = n
    34  	}
    35  }
    36  
    37  // Inbound http(ws) type.
    38  type Inbound struct {
    39  	externalAddr      string
    40  	server            *http.Server
    41  	pool              *connPool
    42  	certFile, keyFile string
    43  	readLimit         int64
    44  }
    45  
    46  // NewInbound creates a new WebSocket inbound transport instance.
    47  func NewInbound(internalAddr, externalAddr, certFile, keyFile string, opts ...InboundOpt) (*Inbound, error) {
    48  	inOpts := &inboundOpts{}
    49  
    50  	for _, opt := range opts {
    51  		opt(inOpts)
    52  	}
    53  
    54  	if internalAddr == "" {
    55  		return nil, errors.New("websocket address is mandatory")
    56  	}
    57  
    58  	if externalAddr == "" {
    59  		externalAddr = internalAddr
    60  	}
    61  
    62  	return &Inbound{
    63  		certFile:     certFile,
    64  		keyFile:      keyFile,
    65  		externalAddr: externalAddr,
    66  		server:       &http.Server{Addr: internalAddr},
    67  		readLimit:    inOpts.readLimit,
    68  	}, nil
    69  }
    70  
    71  // Start the http(ws) server.
    72  func (i *Inbound) Start(prov transport.Provider) error {
    73  	if prov == nil || prov.InboundMessageHandler() == nil {
    74  		return errors.New("creation of inbound handler failed")
    75  	}
    76  
    77  	i.server.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
    78  		i.processRequest(w, r)
    79  	})
    80  
    81  	i.pool = getConnPool(prov)
    82  
    83  	go func() {
    84  		if err := i.listenAndServe(); !errors.Is(err, http.ErrServerClosed) {
    85  			logger.Fatalf("websocket server start with address [%s] failed, cause:  %s", i.server.Addr, err)
    86  		}
    87  	}()
    88  
    89  	return nil
    90  }
    91  
    92  func (i *Inbound) listenAndServe() error {
    93  	if i.certFile != "" && i.keyFile != "" {
    94  		return i.server.ListenAndServeTLS(i.certFile, i.keyFile)
    95  	}
    96  
    97  	return i.server.ListenAndServe()
    98  }
    99  
   100  // Stop the http(ws) server.
   101  func (i *Inbound) Stop() error {
   102  	if err := i.server.Shutdown(context.Background()); err != nil {
   103  		return fmt.Errorf("websocket server shutdown failed: %w", err)
   104  	}
   105  
   106  	return nil
   107  }
   108  
   109  // Endpoint provides the http(ws) connection details.
   110  func (i *Inbound) Endpoint() string {
   111  	return i.externalAddr
   112  }
   113  
   114  func (i *Inbound) processRequest(w http.ResponseWriter, r *http.Request) {
   115  	c, err := upgradeConnection(w, r)
   116  	if err != nil {
   117  		logger.Errorf("failed to upgrade the connection : %v", err)
   118  		return
   119  	}
   120  
   121  	if i.readLimit > 0 {
   122  		c.SetReadLimit(i.readLimit)
   123  	}
   124  
   125  	i.pool.listener(c, false)
   126  }
   127  
   128  func upgradeConnection(w http.ResponseWriter, r *http.Request) (*websocket.Conn, error) {
   129  	c, err := Accept(w, r)
   130  	if err != nil {
   131  		logger.Errorf("failed to upgrade the connection : %v", err)
   132  		return nil, err
   133  	}
   134  
   135  	return c, nil
   136  }