github.com/kaleido-io/firefly@v0.0.0-20210622132723-8b4b6aacb971/internal/events/websockets/websocket_connection.go (about)

     1  // Copyright © 2021 Kaleido, Inc.
     2  //
     3  // SPDX-License-Identifier: Apache-2.0
     4  //
     5  // Licensed under the Apache License, Version 2.0 (the "License");
     6  // you may not use this file except in compliance with the License.
     7  // You may obtain a copy of the License at
     8  //
     9  //     http://www.apache.org/licenses/LICENSE-2.0
    10  //
    11  // Unless required by applicable law or agreed to in writing, software
    12  // distributed under the License is distributed on an "AS IS" BASIS,
    13  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14  // See the License for the specific language governing permissions and
    15  // limitations under the License.
    16  
    17  package websockets
    18  
    19  import (
    20  	"context"
    21  	"encoding/json"
    22  	"io/ioutil"
    23  	"net/http"
    24  	"sync"
    25  
    26  	"github.com/gorilla/websocket"
    27  	"github.com/kaleido-io/firefly/internal/i18n"
    28  	"github.com/kaleido-io/firefly/internal/log"
    29  	"github.com/kaleido-io/firefly/pkg/fftypes"
    30  )
    31  
    32  type websocketConnection struct {
    33  	ctx          context.Context
    34  	ws           *WebSockets
    35  	wsConn       *websocket.Conn
    36  	cancelCtx    func()
    37  	connID       string
    38  	sendMessages chan interface{}
    39  	senderDone   chan struct{}
    40  	autoAck      bool
    41  	startedCount int
    42  	inflight     []*fftypes.EventDeliveryResponse
    43  	mux          sync.Mutex
    44  	closed       bool
    45  }
    46  
    47  func newConnection(pCtx context.Context, ws *WebSockets, wsConn *websocket.Conn) *websocketConnection {
    48  	connID := fftypes.NewUUID().String()
    49  	ctx := log.WithLogField(pCtx, "websocket", connID)
    50  	ctx, cancelCtx := context.WithCancel(ctx)
    51  	wc := &websocketConnection{
    52  		ctx:          ctx,
    53  		ws:           ws,
    54  		wsConn:       wsConn,
    55  		cancelCtx:    cancelCtx,
    56  		connID:       connID,
    57  		sendMessages: make(chan interface{}),
    58  		senderDone:   make(chan struct{}),
    59  	}
    60  	go wc.sendLoop()
    61  	go wc.receiveLoop()
    62  	return wc
    63  }
    64  
    65  // processAutoStart gives a helper to specify query parameters to auto-start your subscription
    66  func (wc *websocketConnection) processAutoStart(req *http.Request) {
    67  	query := req.URL.Query()
    68  	ephemeral, hasEphemeral := req.URL.Query()["ephemeral"]
    69  	isEphemeral := hasEphemeral && (len(ephemeral) == 0 || ephemeral[0] != "false")
    70  	_, hasName := query["name"]
    71  	autoAck, hasAutoack := req.URL.Query()["autoack"]
    72  	isAutoack := hasAutoack && (len(autoAck) == 0 || autoAck[0] != "false")
    73  	if hasEphemeral || hasName {
    74  		err := wc.handleStart(&fftypes.WSClientActionStartPayload{
    75  			AutoAck:   &isAutoack,
    76  			Ephemeral: isEphemeral,
    77  			Namespace: query.Get("namespace"),
    78  			Name:      query.Get("name"),
    79  			Filter: fftypes.SubscriptionFilter{
    80  				Events: query.Get("filter.events"),
    81  				Topics: query.Get("filter.topics"),
    82  				Group:  query.Get("filter.group"),
    83  				Tag:    query.Get("filter.tag"),
    84  			},
    85  		})
    86  		if err != nil {
    87  			wc.protocolError(err)
    88  		}
    89  	}
    90  }
    91  
    92  func (wc *websocketConnection) sendLoop() {
    93  	l := log.L(wc.ctx)
    94  	defer close(wc.senderDone)
    95  	defer wc.close()
    96  	for {
    97  		select {
    98  		case msg, ok := <-wc.sendMessages:
    99  			if !ok {
   100  				l.Debugf("Sender closing")
   101  				return
   102  			}
   103  			l.Tracef("Sending: %+v", msg)
   104  			writer, err := wc.wsConn.NextWriter(websocket.TextMessage)
   105  			if err == nil {
   106  				err = json.NewEncoder(writer).Encode(msg)
   107  				_ = writer.Close()
   108  			}
   109  			if err != nil {
   110  				l.Errorf("Write failed on socket: %s", err)
   111  				return
   112  			}
   113  		case <-wc.ctx.Done():
   114  			l.Debugf("Sender closing - context cancelled")
   115  			return
   116  		}
   117  	}
   118  }
   119  
   120  func (wc *websocketConnection) receiveLoop() {
   121  	l := log.L(wc.ctx)
   122  	defer close(wc.sendMessages)
   123  	for {
   124  		var msgData []byte
   125  		var msgHeader fftypes.WSClientActionBase
   126  		_, reader, err := wc.wsConn.NextReader()
   127  		if err == nil {
   128  			msgData, err = ioutil.ReadAll(reader)
   129  			if err == nil {
   130  				err = json.Unmarshal(msgData, &msgHeader)
   131  				if err != nil {
   132  					// We can notify the client on this one, before we bail
   133  					wc.protocolError(i18n.WrapError(wc.ctx, err, i18n.MsgWSClientSentInvalidData))
   134  				}
   135  			}
   136  		}
   137  		if err != nil {
   138  			l.Errorf("Read failed: %s", err)
   139  			return
   140  		}
   141  		l.Tracef("Received: %s", string(msgData))
   142  		switch msgHeader.Type {
   143  		case fftypes.WSClientActionStart:
   144  			var msg fftypes.WSClientActionStartPayload
   145  			err = json.Unmarshal(msgData, &msg)
   146  			if err == nil {
   147  				err = wc.handleStart(&msg)
   148  			}
   149  		case fftypes.WSClientActionAck:
   150  			var msg fftypes.WSClientActionAckPayload
   151  			err = json.Unmarshal(msgData, &msg)
   152  			if err == nil {
   153  				err = wc.handleAck(&msg)
   154  			}
   155  		default:
   156  			err = i18n.NewError(wc.ctx, i18n.MsgWSClientUnknownAction, msgHeader.Type)
   157  		}
   158  		if err != nil {
   159  			wc.protocolError(i18n.WrapError(wc.ctx, err, i18n.MsgWSClientSentInvalidData))
   160  			l.Errorf("Invalid request sent on socket: %s", err)
   161  			return
   162  		}
   163  	}
   164  }
   165  
   166  func (wc *websocketConnection) dispatch(event *fftypes.EventDelivery) error {
   167  	inflight := &fftypes.EventDeliveryResponse{
   168  		ID:           event.ID,
   169  		Subscription: event.Subscription,
   170  	}
   171  
   172  	var autoAck bool
   173  	wc.mux.Lock()
   174  	autoAck = wc.autoAck
   175  	if !autoAck {
   176  		wc.inflight = append(wc.inflight, inflight)
   177  	}
   178  	wc.mux.Unlock()
   179  
   180  	err := wc.send(event)
   181  	if err != nil {
   182  		return err
   183  	}
   184  
   185  	if autoAck {
   186  		return wc.ws.ack(wc.connID, inflight)
   187  	}
   188  
   189  	return nil
   190  }
   191  
   192  func (wc *websocketConnection) protocolError(err error) {
   193  	log.L(wc.ctx).Errorf("Sending protocol error to client: %s", err)
   194  	sendErr := wc.send(&fftypes.WSProtocolErrorPayload{
   195  		Type:  fftypes.WSProtocolErrorEventType,
   196  		Error: err.Error(),
   197  	})
   198  	if sendErr != nil {
   199  		log.L(wc.ctx).Errorf("Failed to send protocol error: %s", sendErr)
   200  	}
   201  }
   202  
   203  func (wc *websocketConnection) send(msg interface{}) error {
   204  	select {
   205  	case wc.sendMessages <- msg:
   206  		return nil
   207  	case <-wc.ctx.Done():
   208  		return i18n.NewError(wc.ctx, i18n.MsgWSClosing)
   209  	}
   210  }
   211  
   212  func (wc *websocketConnection) handleStart(start *fftypes.WSClientActionStartPayload) (err error) {
   213  	wc.mux.Lock()
   214  	if start.AutoAck != nil {
   215  		if *start.AutoAck != wc.autoAck && wc.startedCount > 0 {
   216  			return i18n.NewError(wc.ctx, i18n.MsgWSAutoAckChanged)
   217  		}
   218  		wc.autoAck = *start.AutoAck
   219  	}
   220  	wc.mux.Unlock()
   221  
   222  	err = wc.ws.start(wc.connID, start)
   223  	if err != nil {
   224  		return err
   225  	}
   226  	wc.mux.Lock()
   227  	wc.startedCount++
   228  	wc.mux.Unlock()
   229  	return nil
   230  }
   231  
   232  func (wc *websocketConnection) checkAck(ack *fftypes.WSClientActionAckPayload) (*fftypes.EventDeliveryResponse, error) {
   233  	l := log.L(wc.ctx)
   234  	var inflight *fftypes.EventDeliveryResponse
   235  	wc.mux.Lock()
   236  	defer wc.mux.Unlock()
   237  
   238  	if wc.autoAck {
   239  		return nil, i18n.NewError(wc.ctx, i18n.MsgWSAutoAckEnabled)
   240  	}
   241  
   242  	if ack.ID != nil {
   243  		newInflight := make([]*fftypes.EventDeliveryResponse, 0, len(wc.inflight))
   244  		for _, candidate := range wc.inflight {
   245  			var match bool
   246  			if *candidate.ID == *ack.ID {
   247  				if ack.Subscription != nil {
   248  					// A subscription has been explicitly specified, so it must match
   249  					if (ack.Subscription.ID != nil && *ack.Subscription.ID == *candidate.Subscription.ID) ||
   250  						(ack.Subscription.Name == candidate.Subscription.Name && ack.Subscription.Namespace == candidate.Subscription.Namespace) {
   251  						match = true
   252  					}
   253  				} else {
   254  					// If there's more than one started subscription, that's a problem
   255  					if wc.startedCount != 1 {
   256  						l.Errorf("No subscription specified on ack, and there is not exactly one started subscription")
   257  						return nil, i18n.NewError(wc.ctx, i18n.MsgWSMsgSubNotMatched)
   258  					}
   259  					match = true
   260  				}
   261  			}
   262  			// Remove from the inflight list
   263  			if match {
   264  				inflight = candidate
   265  			} else {
   266  				newInflight = append(newInflight, candidate)
   267  			}
   268  		}
   269  		wc.inflight = newInflight
   270  	} else {
   271  		// Just ack the front of the queue
   272  		if len(wc.inflight) == 0 {
   273  			l.Errorf("Ack received, but no messages in flight")
   274  		} else {
   275  			inflight = wc.inflight[0]
   276  			wc.inflight = wc.inflight[1:]
   277  		}
   278  	}
   279  	if inflight == nil {
   280  		return nil, i18n.NewError(wc.ctx, i18n.MsgWSMsgSubNotMatched)
   281  	}
   282  	return inflight, nil
   283  }
   284  
   285  func (wc *websocketConnection) handleAck(ack *fftypes.WSClientActionAckPayload) error {
   286  	// Perform a locked set of check
   287  	inflight, err := wc.checkAck(ack)
   288  	if err != nil {
   289  		return err
   290  	}
   291  
   292  	// Deliver the ack to the core, now we're unlocked
   293  	return wc.ws.ack(wc.connID, inflight)
   294  }
   295  
   296  func (wc *websocketConnection) close() {
   297  	var didClosed bool
   298  	wc.mux.Lock()
   299  	if !wc.closed {
   300  		didClosed = true
   301  		wc.closed = true
   302  		_ = wc.wsConn.Close()
   303  		wc.cancelCtx()
   304  	}
   305  	wc.mux.Unlock()
   306  	// Drop lock before callback
   307  	if didClosed {
   308  		wc.ws.connClosed(wc.connID)
   309  	}
   310  }
   311  
   312  func (wc *websocketConnection) waitClose() {
   313  	<-wc.senderDone
   314  	<-wc.sendMessages
   315  }