github.com/kaleido-io/firefly@v0.0.0-20210622132723-8b4b6aacb971/internal/wsclient/wsclient.go (about)

     1  // Copyright © 2021 Kaleido, Inc.
     2  //
     3  // SPDX-License-Identifier: Apache-2.0
     4  //
     5  // Licensed under the Apache License, Version 2.0 (the "License");
     6  // you may not use this file except in compliance with the License.
     7  // You may obtain a copy of the License at
     8  //
     9  //     http://www.apache.org/licenses/LICENSE-2.0
    10  //
    11  // Unless required by applicable law or agreed to in writing, software
    12  // distributed under the License is distributed on an "AS IS" BASIS,
    13  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14  // See the License for the specific language governing permissions and
    15  // limitations under the License.
    16  
    17  package wsclient
    18  
    19  import (
    20  	"context"
    21  	"encoding/base64"
    22  	"fmt"
    23  	"io/ioutil"
    24  	"net/http"
    25  	"net/url"
    26  
    27  	"github.com/gorilla/websocket"
    28  	"github.com/kaleido-io/firefly/internal/config"
    29  	"github.com/kaleido-io/firefly/internal/i18n"
    30  	"github.com/kaleido-io/firefly/internal/log"
    31  	"github.com/kaleido-io/firefly/internal/restclient"
    32  	"github.com/kaleido-io/firefly/internal/retry"
    33  )
    34  
    35  type WSAuthConfig struct {
    36  	Username string `json:"username,omitempty"`
    37  	Password string `json:"password,omitempty"`
    38  }
    39  
    40  type WSClient interface {
    41  	Connect() error
    42  	Receive() <-chan []byte
    43  	URL() string
    44  	SetURL(url string)
    45  	Send(ctx context.Context, message []byte) error
    46  	Close()
    47  }
    48  
    49  type wsClient struct {
    50  	ctx                  context.Context
    51  	headers              http.Header
    52  	url                  string
    53  	initialRetryAttempts int
    54  	wsdialer             *websocket.Dialer
    55  	wsconn               *websocket.Conn
    56  	retry                retry.Retry
    57  	closed               bool
    58  	receive              chan []byte
    59  	send                 chan []byte
    60  	sendDone             chan []byte
    61  	closing              chan struct{}
    62  	afterConnect         WSPostConnectHandler
    63  }
    64  
    65  // WSPostConnectHandler will be called after every connect/reconnect. Can send data over ws, but must not block listening for data on the ws.
    66  type WSPostConnectHandler func(ctx context.Context, w WSClient) error
    67  
    68  func New(ctx context.Context, prefix config.Prefix, afterConnect WSPostConnectHandler) (WSClient, error) {
    69  
    70  	wsURL, err := buildWSUrl(ctx, prefix)
    71  	if err != nil {
    72  		return nil, err
    73  	}
    74  
    75  	w := &wsClient{
    76  		ctx: ctx,
    77  		url: wsURL,
    78  		wsdialer: &websocket.Dialer{
    79  			ReadBufferSize:  int(prefix.GetByteSize(WSConfigKeyReadBufferSize)),
    80  			WriteBufferSize: int(prefix.GetByteSize(WSConfigKeyWriteBufferSize)),
    81  		},
    82  		retry: retry.Retry{
    83  			InitialDelay: prefix.GetDuration(restclient.HTTPConfigRetryInitDelay),
    84  			MaximumDelay: prefix.GetDuration(restclient.HTTPConfigRetryMaxDelay),
    85  		},
    86  		initialRetryAttempts: prefix.GetInt(WSConfigKeyInitialConnectAttempts),
    87  		headers:              make(http.Header),
    88  		receive:              make(chan []byte),
    89  		send:                 make(chan []byte),
    90  		closing:              make(chan struct{}),
    91  		afterConnect:         afterConnect,
    92  	}
    93  	for k, v := range prefix.GetObject(restclient.HTTPConfigHeaders) {
    94  		if vs, ok := v.(string); ok {
    95  			w.headers.Set(k, vs)
    96  		}
    97  	}
    98  	authUsername := prefix.GetString(restclient.HTTPConfigAuthUsername)
    99  	authPassword := prefix.GetString(restclient.HTTPConfigAuthPassword)
   100  	if authUsername != "" && authPassword != "" {
   101  		w.headers.Set("Authorization", fmt.Sprintf("Basic %s", base64.StdEncoding.EncodeToString([]byte(fmt.Sprintf("%s:%s", authUsername, authPassword)))))
   102  	}
   103  
   104  	return w, nil
   105  }
   106  
   107  func (w *wsClient) Connect() error {
   108  
   109  	if err := w.connect(true); err != nil {
   110  		return err
   111  	}
   112  
   113  	go w.receiveReconnectLoop()
   114  
   115  	return nil
   116  }
   117  
   118  func (w *wsClient) Close() {
   119  	if !w.closed {
   120  		w.closed = true
   121  		close(w.closing)
   122  		c := w.wsconn
   123  		if c != nil {
   124  			_ = c.Close()
   125  		}
   126  	}
   127  }
   128  
   129  // Receive returns
   130  func (w *wsClient) Receive() <-chan []byte {
   131  	return w.receive
   132  }
   133  
   134  func (w *wsClient) URL() string {
   135  	return w.url
   136  }
   137  
   138  func (w *wsClient) SetURL(url string) {
   139  	w.url = url
   140  }
   141  
   142  func (w *wsClient) Send(ctx context.Context, message []byte) error {
   143  	// Send
   144  	select {
   145  	case w.send <- message:
   146  		return nil
   147  	case <-ctx.Done():
   148  		return i18n.NewError(ctx, i18n.MsgWSSendTimedOut)
   149  	case <-w.closing:
   150  		return i18n.NewError(ctx, i18n.MsgWSClosing)
   151  	}
   152  }
   153  
   154  func buildWSUrl(ctx context.Context, prefix config.Prefix) (string, error) {
   155  	urlString := prefix.GetString(restclient.HTTPConfigURL)
   156  	u, err := url.Parse(urlString)
   157  	if err != nil {
   158  		return "", i18n.WrapError(ctx, err, i18n.MsgInvalidURL, urlString)
   159  	}
   160  	wsPath := prefix.GetString(WSConfigKeyPath)
   161  	if wsPath != "" {
   162  		u.Path = wsPath
   163  	}
   164  	if u.Scheme == "http" {
   165  		u.Scheme = "ws"
   166  	}
   167  	if u.Scheme == "https" {
   168  		u.Scheme = "wss"
   169  	}
   170  	return u.String(), nil
   171  }
   172  
   173  func (w *wsClient) connect(initial bool) error {
   174  	l := log.L(w.ctx)
   175  	return w.retry.DoCustomLog(w.ctx, func(attempt int) (retry bool, err error) {
   176  		if w.closed {
   177  			return false, i18n.NewError(w.ctx, i18n.MsgWSClosing)
   178  		}
   179  		var res *http.Response
   180  		w.wsconn, res, err = w.wsdialer.Dial(w.url, w.headers)
   181  		if err != nil {
   182  			var b []byte
   183  			var status = -1
   184  			if res != nil {
   185  				b, _ = ioutil.ReadAll(res.Body)
   186  				res.Body.Close()
   187  				status = res.StatusCode
   188  			}
   189  			l.Warnf("WS %s connect attempt %d failed [%d]: %s", w.url, attempt, status, string(b))
   190  			return !initial || attempt > w.initialRetryAttempts, i18n.WrapError(w.ctx, err, i18n.MsgWSConnectFailed)
   191  		}
   192  		l.Infof("WS %s connected", w.url)
   193  		return false, nil
   194  	})
   195  }
   196  
   197  func (w *wsClient) readLoop() {
   198  	l := log.L(w.ctx)
   199  	for {
   200  		mt, message, err := w.wsconn.ReadMessage()
   201  
   202  		// Check there's not a pending send message we need to return
   203  		// before returning any error (do not block)
   204  		select {
   205  		case <-w.sendDone:
   206  			l.Debugf("WS %s closing reader after send error", w.url)
   207  			return
   208  		default:
   209  		}
   210  
   211  		// return any error
   212  		if err != nil {
   213  			l.Errorf("WS %s closed: %s", w.url, err)
   214  			return
   215  		}
   216  
   217  		// Pass the message to the consumer
   218  		l.Tracef("WS %s read (mt=%d): %s", w.url, mt, message)
   219  		w.receive <- message
   220  	}
   221  }
   222  
   223  func (w *wsClient) sendLoop(receiverDone chan struct{}) {
   224  	l := log.L(w.ctx)
   225  	defer close(w.sendDone)
   226  
   227  	for {
   228  		select {
   229  		case message := <-w.send:
   230  			l.Tracef("WS sending: %s", message)
   231  			if err := w.wsconn.WriteMessage(websocket.TextMessage, message); err != nil {
   232  				l.Errorf("WS %s send failed: %s", w.url, err)
   233  				return
   234  			}
   235  		case <-receiverDone:
   236  			l.Debugf("WS %s send loop exiting", w.url)
   237  			return
   238  		}
   239  	}
   240  }
   241  
   242  func (w *wsClient) receiveReconnectLoop() {
   243  	l := log.L(w.ctx)
   244  	defer close(w.receive)
   245  	for !w.closed {
   246  		// Start the sender, letting it close without blocking sending a notifiation on the sendDone
   247  		w.sendDone = make(chan []byte, 1)
   248  		receiverDone := make(chan struct{})
   249  		go w.sendLoop(receiverDone)
   250  
   251  		// Call the reconnect processor
   252  		var err error
   253  		if w.afterConnect != nil {
   254  			err = w.afterConnect(w.ctx, w)
   255  		}
   256  
   257  		if err == nil {
   258  			// Synchronously invoke the reader, as it's important we react immediately to any error there.
   259  			w.readLoop()
   260  			close(receiverDone)
   261  
   262  			// Ensure the connection is closed after the receiver exits
   263  			err = w.wsconn.Close()
   264  			if err != nil {
   265  				l.Debugf("WS %s close failed: %s", w.url, err)
   266  			}
   267  			<-w.sendDone
   268  			w.sendDone = nil
   269  			w.wsconn = nil
   270  		}
   271  
   272  		// Go into reconnect
   273  		if !w.closed {
   274  			err = w.connect(false)
   275  			if err != nil {
   276  				l.Debugf("WS %s exiting: %s", w.url, err)
   277  				return
   278  			}
   279  		}
   280  	}
   281  }