github.com/m-lab/locate@v0.17.6/connection/connection.go (about) 1 // Package connection provides a Websocket that will automatically 2 // reconnect if the connection is dropped. 3 package connection 4 5 import ( 6 "errors" 7 "log" 8 "net/http" 9 "net/url" 10 "sync" 11 "time" 12 13 "github.com/cenkalti/backoff/v4" 14 "github.com/gorilla/websocket" 15 "github.com/m-lab/locate/metrics" 16 "github.com/m-lab/locate/static" 17 ) 18 19 var ( 20 // ErrNotDialed is returned when WriteMessage is called, but 21 // the websocket has not been created yet (call Dial). 22 ErrNotDailed = errors.New("websocket not created yet, please call Dial()") 23 // retryErrors contains the list of errors that may become successful 24 // if the request is retried. 25 retryErrors = map[int]bool{408: true, 425: true, 500: true, 502: true, 503: true, 504: true} 26 ) 27 28 // Conn contains the state needed to connect, reconnect, and send 29 // messages. 30 // Default values must be updated before calling `Dial`. 31 type Conn struct { 32 // InitialInterval is the first interval at which the backoff starts 33 // running. 34 InitialInterval time.Duration 35 // RandomizationFactor is used to create the range of values: 36 // [currentInterval - randomizationFactor * currentInterval, 37 // currentInterval + randomizationFactor * currentInterval] and picking 38 // a random value from the range. 39 RandomizationFactor float64 40 // Multiplier is used to increment the backoff interval by multiplying it. 41 Multiplier float64 42 // MaxInterval is an interval such that, once reached, the backoff will 43 // retry with a constant delay of MaxInterval. 44 MaxInterval time.Duration 45 // MaxElapsedTime is the amount of time after which the ExponentialBackOff 46 // returns Stop. It never stops if MaxElapsedTime == 0. 47 MaxElapsedTime time.Duration 48 // DialMessage is the message sent when the connection is started. 49 DialMessage interface{} 50 dialer websocket.Dialer 51 ws *websocket.Conn 52 url url.URL 53 header http.Header 54 ticker time.Ticker 55 mu sync.Mutex 56 isDialed bool 57 isConnected bool 58 } 59 60 // NewConn creates a new Conn with default values. 61 func NewConn() *Conn { 62 c := &Conn{ 63 InitialInterval: static.BackoffInitialInterval, 64 RandomizationFactor: static.BackoffRandomizationFactor, 65 Multiplier: static.BackoffMultiplier, 66 MaxInterval: static.BackoffMaxInterval, 67 MaxElapsedTime: static.BackoffMaxElapsedTime, 68 } 69 return c 70 } 71 72 // Dial creates a new persistent client connection and sets 73 // the necessary state for future reconnections. It also 74 // starts a goroutine to reset the number of reconnections. 75 // 76 // A call to Dial is a prerequisite to writing any messages. 77 // The function only needs to be called once on start to create the 78 // connection. Alternatively, if Close is called, Dial will have to 79 // be called again if the connection needs to be recreated. 80 // 81 // The function returns an error if the url is invalid or if 82 // a 4XX error (except 408 and 425) is received in the HTTP 83 // response. 84 func (c *Conn) Dial(address string, header http.Header, dialMsg interface{}) error { 85 u, err := url.ParseRequestURI(address) 86 if err != nil || (u.Scheme != "ws" && u.Scheme != "wss") { 87 return errors.New("malformed ws or wss URL") 88 } 89 c.url = *u 90 c.DialMessage = dialMsg 91 c.header = header 92 c.dialer = websocket.Dialer{} 93 c.isDialed = true 94 return c.connect() 95 } 96 97 // WriteMessage sends the JSON encoding of `data` as a message. 98 // If the write fails or a disconnect has been detected, it will 99 // close the connection and try to reconnect and resend the 100 // message. 101 // 102 // The write will fail under the following conditions: 103 // 1. The client has not called Dial (ErrNotDialed). 104 // 2. The connection is disconnected and it was not able to 105 // reconnect. 106 // 3. The write call in the websocket package failed 107 // (gorilla/websocket error). 108 func (c *Conn) WriteMessage(messageType int, data interface{}) error { 109 if !c.isDialed { 110 return ErrNotDailed 111 } 112 113 // If a disconnect has already been detected, try to reconnect. 114 if !c.IsConnected() { 115 if err := c.closeAndReconnect(); err != nil { 116 return err 117 } 118 } 119 120 // If the write fails, reconnect and send the message again. 121 if err := c.write(messageType, data); err != nil { 122 if err := c.closeAndReconnect(); err != nil { 123 return err 124 } 125 return c.write(messageType, data) 126 } 127 return nil 128 } 129 130 // IsConnected returns the WebSocket connection state. 131 func (c *Conn) IsConnected() bool { 132 return c.isConnected 133 } 134 135 // Close closes the network connection and cleans up private 136 // resources after the connection is done. 137 func (c *Conn) Close() error { 138 if c.isDialed { 139 c.isDialed = false 140 } 141 return c.close() 142 } 143 144 // closeAndReconnect calls close and reconnects. 145 func (c *Conn) closeAndReconnect() error { 146 err := c.close() 147 if err != nil { 148 return err 149 } 150 return c.connect() 151 } 152 153 // close closes the underlying network connection without 154 // sending or waiting for a close frame. 155 func (c *Conn) close() error { 156 if c.IsConnected() { 157 c.isConnected = false 158 if c.ws != nil { 159 return c.ws.Close() 160 } 161 } 162 return nil 163 } 164 165 // connect creates a new client connection and sends the 166 // registration message. 167 // In case of failure, it uses an exponential backoff to 168 // increase the duration of retry attempts. 169 func (c *Conn) connect() error { 170 b := c.getBackoff() 171 ticker := backoff.NewTicker(b) 172 173 var ws *websocket.Conn 174 var resp *http.Response 175 var err error 176 for range ticker.C { 177 ws, resp, err = c.dialer.Dial(c.url.String(), c.header) 178 if err != nil { 179 if resp != nil && !retryErrors[resp.StatusCode] { 180 log.Printf("error trying to establish a connection with %s, err: %v, status: %d", 181 c.url.String(), err, resp.StatusCode) 182 metrics.ConnectionRequestsTotal.WithLabelValues("error").Inc() 183 ticker.Stop() 184 return err 185 } 186 log.Printf("could not establish a connection with %s (will retry), err: %v", 187 c.url.String(), err) 188 metrics.ConnectionRequestsTotal.WithLabelValues("retry").Inc() 189 continue 190 } 191 192 c.ws = ws 193 c.isConnected = true 194 log.Printf("successfully established a connection with %s", c.url.String()) 195 metrics.ConnectionRequestsTotal.WithLabelValues("OK").Inc() 196 ticker.Stop() 197 } 198 199 if c.isConnected { 200 err = c.write(websocket.TextMessage, c.DialMessage) 201 } 202 return err 203 } 204 205 // write is a helper function that gets a writer using NextWriter, 206 // writes the message and closes the writer. 207 // It returns an error if the calls to NextWriter or WriteJSON 208 // return errors. 209 func (c *Conn) write(messageType int, data interface{}) error { 210 // We want to identify and return write errors as soon as they occur. 211 // The supported interface for WriteMessage does not do that. 212 // Therefore, we are using NextWriter explicitly with Close 213 // to update the error. 214 // NextWriter is called with a PingMessage type because it is 215 // effectively a no-op, while using other message types can 216 // cause side-effects (e.g, loading an empty msg to the buffer). 217 w, err := c.ws.NextWriter(websocket.PingMessage) 218 if err == nil { 219 err = c.ws.WriteJSON(data) 220 w.Close() 221 } 222 return err 223 } 224 225 // getBackoff returns a backoff implementation that increases the 226 // backoff period for each retry attempt using a randomization function 227 // that grows exponentially. 228 func (c *Conn) getBackoff() *backoff.ExponentialBackOff { 229 b := backoff.NewExponentialBackOff() 230 b.InitialInterval = c.InitialInterval 231 b.RandomizationFactor = c.RandomizationFactor 232 b.Multiplier = c.Multiplier 233 b.MaxInterval = c.MaxInterval 234 b.MaxElapsedTime = c.MaxElapsedTime 235 return b 236 }