github.com/m-lab/locate@v0.17.6/connection/connection.go (about)

     1  // Package connection provides a Websocket that will automatically
     2  // reconnect if the connection is dropped.
     3  package connection
     4  
     5  import (
     6  	"errors"
     7  	"log"
     8  	"net/http"
     9  	"net/url"
    10  	"sync"
    11  	"time"
    12  
    13  	"github.com/cenkalti/backoff/v4"
    14  	"github.com/gorilla/websocket"
    15  	"github.com/m-lab/locate/metrics"
    16  	"github.com/m-lab/locate/static"
    17  )
    18  
    19  var (
    20  	// ErrNotDialed is returned when WriteMessage is called, but
    21  	// the websocket has not been created yet (call Dial).
    22  	ErrNotDailed = errors.New("websocket not created yet, please call Dial()")
    23  	// retryErrors contains the list of errors that may become successful
    24  	// if the request is retried.
    25  	retryErrors = map[int]bool{408: true, 425: true, 500: true, 502: true, 503: true, 504: true}
    26  )
    27  
// Conn contains the state needed to connect, reconnect, and send
// messages.
// Default values must be updated before calling `Dial`.
//
// The exported fields configure the exponential backoff that connect uses
// between dial attempts; they are copied into a fresh backoff on every
// (re)connection. The unexported fields are populated by Dial and connect.
type Conn struct {
	// InitialInterval is the first interval at which the backoff starts
	// running.
	InitialInterval time.Duration
	// RandomizationFactor is used to create the range of values:
	// [currentInterval - randomizationFactor * currentInterval,
	// currentInterval + randomizationFactor * currentInterval] and picking
	// a random value from the range.
	RandomizationFactor float64
	// Multiplier is used to increment the backoff interval by multiplying it.
	Multiplier float64
	// MaxInterval is an interval such that, once reached, the backoff will
	// retry with a constant delay of MaxInterval.
	MaxInterval time.Duration
	// MaxElapsedTime is the amount of time after which the ExponentialBackOff
	// returns Stop. It never stops if MaxElapsedTime == 0.
	MaxElapsedTime time.Duration
	// DialMessage is the message sent when the connection is started.
	DialMessage interface{}
	dialer      websocket.Dialer // reset to a zero-value Dialer by Dial
	ws          *websocket.Conn  // current connection; assigned by connect
	url         url.URL          // target address parsed by Dial
	header      http.Header      // request header passed to every dial attempt
	ticker      time.Ticker      // NOTE(review): not referenced by the methods visible in this file
	mu          sync.Mutex       // NOTE(review): never locked by the methods visible in this file
	isDialed    bool             // set by Dial, cleared by Close
	isConnected bool             // set by connect after a successful dial, cleared by close
}
    59  
    60  // NewConn creates a new Conn with default values.
    61  func NewConn() *Conn {
    62  	c := &Conn{
    63  		InitialInterval:     static.BackoffInitialInterval,
    64  		RandomizationFactor: static.BackoffRandomizationFactor,
    65  		Multiplier:          static.BackoffMultiplier,
    66  		MaxInterval:         static.BackoffMaxInterval,
    67  		MaxElapsedTime:      static.BackoffMaxElapsedTime,
    68  	}
    69  	return c
    70  }
    71  
    72  // Dial creates a new persistent client connection and sets
    73  // the necessary state for future reconnections. It also
    74  // starts a goroutine to reset the number of reconnections.
    75  //
    76  // A call to Dial is a prerequisite to writing any messages.
    77  // The function only needs to be called once on start to create the
    78  // connection. Alternatively, if Close is called, Dial will have to
    79  // be called again if the connection needs to be recreated.
    80  //
    81  // The function returns an error if the url is invalid or if
    82  // a 4XX error (except 408 and 425) is received in the HTTP
    83  // response.
    84  func (c *Conn) Dial(address string, header http.Header, dialMsg interface{}) error {
    85  	u, err := url.ParseRequestURI(address)
    86  	if err != nil || (u.Scheme != "ws" && u.Scheme != "wss") {
    87  		return errors.New("malformed ws or wss URL")
    88  	}
    89  	c.url = *u
    90  	c.DialMessage = dialMsg
    91  	c.header = header
    92  	c.dialer = websocket.Dialer{}
    93  	c.isDialed = true
    94  	return c.connect()
    95  }
    96  
    97  // WriteMessage sends the JSON encoding of `data` as a message.
    98  // If the write fails or a disconnect has been detected, it will
    99  // close the connection and try to reconnect and resend the
   100  // message.
   101  //
   102  // The write will fail under the following conditions:
   103  //  1. The client has not called Dial (ErrNotDialed).
   104  //  2. The connection is disconnected and it was not able to
   105  //     reconnect.
   106  //  3. The write call in the websocket package failed
   107  //     (gorilla/websocket error).
   108  func (c *Conn) WriteMessage(messageType int, data interface{}) error {
   109  	if !c.isDialed {
   110  		return ErrNotDailed
   111  	}
   112  
   113  	// If a disconnect has already been detected, try to reconnect.
   114  	if !c.IsConnected() {
   115  		if err := c.closeAndReconnect(); err != nil {
   116  			return err
   117  		}
   118  	}
   119  
   120  	// If the write fails, reconnect and send the message again.
   121  	if err := c.write(messageType, data); err != nil {
   122  		if err := c.closeAndReconnect(); err != nil {
   123  			return err
   124  		}
   125  		return c.write(messageType, data)
   126  	}
   127  	return nil
   128  }
   129  
// IsConnected returns the WebSocket connection state.
// It reports the isConnected flag, which connect sets after a successful
// dial and close clears; it does not probe the network. The flag is read
// without holding c.mu.
func (c *Conn) IsConnected() bool {
	return c.isConnected
}
   134  
   135  // Close closes the network connection and cleans up private
   136  // resources after the connection is done.
   137  func (c *Conn) Close() error {
   138  	if c.isDialed {
   139  		c.isDialed = false
   140  	}
   141  	return c.close()
   142  }
   143  
   144  // closeAndReconnect calls close and reconnects.
   145  func (c *Conn) closeAndReconnect() error {
   146  	err := c.close()
   147  	if err != nil {
   148  		return err
   149  	}
   150  	return c.connect()
   151  }
   152  
   153  // close closes the underlying network connection without
   154  // sending or waiting for a close frame.
   155  func (c *Conn) close() error {
   156  	if c.IsConnected() {
   157  		c.isConnected = false
   158  		if c.ws != nil {
   159  			return c.ws.Close()
   160  		}
   161  	}
   162  	return nil
   163  }
   164  
   165  // connect creates a new client connection and sends the
   166  // registration message.
   167  // In case of failure, it uses an exponential backoff to
   168  // increase the duration of retry attempts.
   169  func (c *Conn) connect() error {
   170  	b := c.getBackoff()
   171  	ticker := backoff.NewTicker(b)
   172  
   173  	var ws *websocket.Conn
   174  	var resp *http.Response
   175  	var err error
   176  	for range ticker.C {
   177  		ws, resp, err = c.dialer.Dial(c.url.String(), c.header)
   178  		if err != nil {
   179  			if resp != nil && !retryErrors[resp.StatusCode] {
   180  				log.Printf("error trying to establish a connection with %s, err: %v, status: %d",
   181  					c.url.String(), err, resp.StatusCode)
   182  				metrics.ConnectionRequestsTotal.WithLabelValues("error").Inc()
   183  				ticker.Stop()
   184  				return err
   185  			}
   186  			log.Printf("could not establish a connection with %s (will retry), err: %v",
   187  				c.url.String(), err)
   188  			metrics.ConnectionRequestsTotal.WithLabelValues("retry").Inc()
   189  			continue
   190  		}
   191  
   192  		c.ws = ws
   193  		c.isConnected = true
   194  		log.Printf("successfully established a connection with %s", c.url.String())
   195  		metrics.ConnectionRequestsTotal.WithLabelValues("OK").Inc()
   196  		ticker.Stop()
   197  	}
   198  
   199  	if c.isConnected {
   200  		err = c.write(websocket.TextMessage, c.DialMessage)
   201  	}
   202  	return err
   203  }
   204  
   205  // write is a helper function that gets a writer using NextWriter,
   206  // writes the message and closes the writer.
   207  // It returns an error if the calls to NextWriter or WriteJSON
   208  // return errors.
   209  func (c *Conn) write(messageType int, data interface{}) error {
   210  	// We want to identify and return write errors as soon as they occur.
   211  	// The supported interface for WriteMessage does not do that.
   212  	// Therefore, we are using NextWriter explicitly with Close
   213  	// to update the error.
   214  	// NextWriter is called with a PingMessage type because it is
   215  	// effectively a no-op, while using other message types can
   216  	// cause side-effects (e.g, loading an empty msg to the buffer).
   217  	w, err := c.ws.NextWriter(websocket.PingMessage)
   218  	if err == nil {
   219  		err = c.ws.WriteJSON(data)
   220  		w.Close()
   221  	}
   222  	return err
   223  }
   224  
   225  // getBackoff returns a backoff implementation that increases the
   226  // backoff period for each retry attempt using a randomization function
   227  // that grows exponentially.
   228  func (c *Conn) getBackoff() *backoff.ExponentialBackOff {
   229  	b := backoff.NewExponentialBackOff()
   230  	b.InitialInterval = c.InitialInterval
   231  	b.RandomizationFactor = c.RandomizationFactor
   232  	b.Multiplier = c.Multiplier
   233  	b.MaxInterval = c.MaxInterval
   234  	b.MaxElapsedTime = c.MaxElapsedTime
   235  	return b
   236  }