github.com/kaleido-io/firefly@v0.0.0-20210622132723-8b4b6aacb971/internal/events/websockets/websockets.go (about)

     1  // Copyright © 2021 Kaleido, Inc.
     2  //
     3  // SPDX-License-Identifier: Apache-2.0
     4  //
     5  // Licensed under the Apache License, Version 2.0 (the "License");
     6  // you may not use this file except in compliance with the License.
     7  // You may obtain a copy of the License at
     8  //
     9  //     http://www.apache.org/licenses/LICENSE-2.0
    10  //
    11  // Unless required by applicable law or agreed to in writing, software
    12  // distributed under the License is distributed on an "AS IS" BASIS,
    13  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14  // See the License for the specific language governing permissions and
    15  // limitations under the License.
    16  
    17  package websockets
    18  
    19  import (
    20  	"context"
    21  	"net/http"
    22  	"sync"
    23  
    24  	"github.com/gorilla/websocket"
    25  	"github.com/kaleido-io/firefly/internal/config"
    26  	"github.com/kaleido-io/firefly/internal/i18n"
    27  	"github.com/kaleido-io/firefly/internal/log"
    28  	"github.com/kaleido-io/firefly/pkg/events"
    29  	"github.com/kaleido-io/firefly/pkg/fftypes"
    30  )
    31  
// WebSockets is the "websockets" events plugin. It upgrades incoming HTTP
// requests to WebSocket connections (see ServeHTTP) and keeps track of them by
// connection ID so that event deliveries can be routed to the right socket.
//
// ctx is used for logging and for constructing errors, and is handed to each
// new connection. capabilities is the value returned by Capabilities.
// callbacks is where acks, start requests and connection closures are
// reported. connections holds the tracked connections keyed by connection ID,
// and connMux is taken around every read and write of that map after Init.
// upgrader performs the HTTP-to-WebSocket upgrade.
type WebSockets struct {
	ctx          context.Context
	capabilities *events.Capabilities
	callbacks    events.Callbacks
	connections  map[string]*websocketConnection
	connMux      sync.Mutex
	upgrader     websocket.Upgrader
}
    40  
    41  func (ws *WebSockets) Name() string { return "websockets" }
    42  
    43  func (ws *WebSockets) Init(ctx context.Context, prefix config.Prefix, callbacks events.Callbacks) error {
    44  	*ws = WebSockets{
    45  		ctx:          ctx,
    46  		connections:  make(map[string]*websocketConnection),
    47  		capabilities: &events.Capabilities{},
    48  		callbacks:    callbacks,
    49  		upgrader: websocket.Upgrader{
    50  			ReadBufferSize:  int(prefix.GetByteSize(ReadBufferSize)),
    51  			WriteBufferSize: int(prefix.GetByteSize(WriteBufferSize)),
    52  			CheckOrigin: func(r *http.Request) bool {
    53  				// Cors is handled by the API server that wraps this handler
    54  				return true
    55  			},
    56  		},
    57  	}
    58  	return nil
    59  }
    60  
// Capabilities returns the capabilities value stored on the plugin. Init sets
// it to an empty events.Capabilities; before Init has run it is nil.
func (ws *WebSockets) Capabilities() *events.Capabilities {
	return ws.capabilities
}
    64  
    65  func (ws *WebSockets) DeliveryRequest(connID string, event *fftypes.EventDelivery) error {
    66  	ws.connMux.Lock()
    67  	conn, ok := ws.connections[connID]
    68  	ws.connMux.Unlock()
    69  	if !ok {
    70  		return i18n.NewError(ws.ctx, i18n.MsgWSConnectionNotActive, connID)
    71  	}
    72  	return conn.dispatch(event)
    73  }
    74  
    75  func (ws *WebSockets) ServeHTTP(res http.ResponseWriter, req *http.Request) {
    76  	wsConn, err := ws.upgrader.Upgrade(res, req, nil)
    77  	if err != nil {
    78  		log.L(ws.ctx).Errorf("WebSocket upgrade failed: %s", err)
    79  		return
    80  	}
    81  
    82  	ws.connMux.Lock()
    83  	wc := newConnection(ws.ctx, ws, wsConn)
    84  	ws.connections[wc.connID] = wc
    85  	ws.connMux.Unlock()
    86  
    87  	wc.processAutoStart(req)
    88  }
    89  
// ack forwards the delivery response received on connection connID to the
// registered callbacks and returns their result. inflight is dereferenced
// unconditionally, so it must not be nil.
func (ws *WebSockets) ack(connID string, inflight *fftypes.EventDeliveryResponse) error {
	return ws.callbacks.DeliveryResponse(connID, *inflight)
}
    93  
    94  func (ws *WebSockets) start(connID string, start *fftypes.WSClientActionStartPayload) error {
    95  	if start.Namespace == "" || (!start.Ephemeral && start.Name == "") {
    96  		return i18n.NewError(ws.ctx, i18n.MsgWSInvalidStartAction)
    97  	}
    98  	if start.Ephemeral {
    99  		return ws.callbacks.EphemeralSubscription(connID, start.Namespace, start.Filter, start.Options)
   100  	}
   101  	return ws.callbacks.RegisterConnection(connID, func(sr fftypes.SubscriptionRef) bool {
   102  		return sr.Namespace == start.Namespace && sr.Name == start.Name
   103  	})
   104  }
   105  
   106  func (ws *WebSockets) connClosed(connID string) {
   107  	ws.connMux.Lock()
   108  	delete(ws.connections, connID)
   109  	ws.connMux.Unlock()
   110  	// Drop lock before calling back
   111  	ws.callbacks.ConnnectionClosed(connID)
   112  }
   113  
   114  func (ws *WebSockets) WaitClosed() {
   115  	closedConnections := []*websocketConnection{}
   116  	ws.connMux.Lock()
   117  	for _, ws := range ws.connections {
   118  		closedConnections = append(closedConnections, ws)
   119  	}
   120  	ws.connMux.Unlock()
   121  	for _, ws := range closedConnections {
   122  		ws.waitClose()
   123  	}
   124  }