github.com/hyperledger/aries-framework-go@v0.3.2/pkg/didcomm/transport/ws/inbound.go (about) 1 /* 2 Copyright SecureKey Technologies Inc. All Rights Reserved. 3 4 SPDX-License-Identifier: Apache-2.0 5 */ 6 7 package ws 8 9 import ( 10 "context" 11 "errors" 12 "fmt" 13 "net/http" 14 15 "nhooyr.io/websocket" 16 17 "github.com/hyperledger/aries-framework-go/pkg/common/log" 18 "github.com/hyperledger/aries-framework-go/pkg/didcomm/transport" 19 ) 20 21 var logger = log.New("aries-framework/ws") 22 23 type inboundOpts struct { 24 readLimit int64 25 } 26 27 // InboundOpt is an inbound ws option. 28 type InboundOpt func(opts *inboundOpts) 29 30 // WithInboundReadLimit sets the custom max number of bytes to read for a single message. 31 func WithInboundReadLimit(n int64) InboundOpt { 32 return func(opts *inboundOpts) { 33 opts.readLimit = n 34 } 35 } 36 37 // Inbound http(ws) type. 38 type Inbound struct { 39 externalAddr string 40 server *http.Server 41 pool *connPool 42 certFile, keyFile string 43 readLimit int64 44 } 45 46 // NewInbound creates a new WebSocket inbound transport instance. 47 func NewInbound(internalAddr, externalAddr, certFile, keyFile string, opts ...InboundOpt) (*Inbound, error) { 48 inOpts := &inboundOpts{} 49 50 for _, opt := range opts { 51 opt(inOpts) 52 } 53 54 if internalAddr == "" { 55 return nil, errors.New("websocket address is mandatory") 56 } 57 58 if externalAddr == "" { 59 externalAddr = internalAddr 60 } 61 62 return &Inbound{ 63 certFile: certFile, 64 keyFile: keyFile, 65 externalAddr: externalAddr, 66 server: &http.Server{Addr: internalAddr}, 67 readLimit: inOpts.readLimit, 68 }, nil 69 } 70 71 // Start the http(ws) server. 72 func (i *Inbound) Start(prov transport.Provider) error { 73 if prov == nil || prov.InboundMessageHandler() == nil { 74 return errors.New("creation of inbound handler failed") 75 } 76 77 i.server.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 78 i.processRequest(w, r) 79 }) 80 81 i.pool = getConnPool(prov) 82 83 go func() { 84 if err := i.listenAndServe(); !errors.Is(err, http.ErrServerClosed) { 85 logger.Fatalf("websocket server start with address [%s] failed, cause: %s", i.server.Addr, err) 86 } 87 }() 88 89 return nil 90 } 91 92 func (i *Inbound) listenAndServe() error { 93 if i.certFile != "" && i.keyFile != "" { 94 return i.server.ListenAndServeTLS(i.certFile, i.keyFile) 95 } 96 97 return i.server.ListenAndServe() 98 } 99 100 // Stop the http(ws) server. 101 func (i *Inbound) Stop() error { 102 if err := i.server.Shutdown(context.Background()); err != nil { 103 return fmt.Errorf("websocket server shutdown failed: %w", err) 104 } 105 106 return nil 107 } 108 109 // Endpoint provides the http(ws) connection details. 110 func (i *Inbound) Endpoint() string { 111 return i.externalAddr 112 } 113 114 func (i *Inbound) processRequest(w http.ResponseWriter, r *http.Request) { 115 c, err := upgradeConnection(w, r) 116 if err != nil { 117 logger.Errorf("failed to upgrade the connection : %v", err) 118 return 119 } 120 121 if i.readLimit > 0 { 122 c.SetReadLimit(i.readLimit) 123 } 124 125 i.pool.listener(c, false) 126 } 127 128 func upgradeConnection(w http.ResponseWriter, r *http.Request) (*websocket.Conn, error) { 129 c, err := Accept(w, r) 130 if err != nil { 131 logger.Errorf("failed to upgrade the connection : %v", err) 132 return nil, err 133 } 134 135 return c, nil 136 }