github.com/stampzilla/stampzilla-go@v2.0.0-rc9+incompatible/pkg/websocket/websocket.go (about) 1 package websocket 2 3 import ( 4 "context" 5 "crypto/tls" 6 "fmt" 7 "net/http" 8 "sync" 9 "time" 10 11 "github.com/gorilla/websocket" 12 "github.com/sirupsen/logrus" 13 ) 14 15 const ( 16 // Time allowed to write a message to the peer. 17 writeWait = 10 * time.Second 18 19 // Time allowed to read the next pong message from the peer. 20 pongWait = 10 * time.Second 21 22 // Send pings to peer with this period. Must be less than pongWait. 23 pingPeriod = (pongWait * 9) / 10 24 25 reconnectWait = 2 * time.Second 26 ) 27 28 // Websocket implements a websocket client 29 type Websocket interface { 30 OnConnect(cb func()) 31 ConnectContext(ctx context.Context, addr string, headers http.Header) error 32 ConnectWithRetry(parentCtx context.Context, addr string, headers http.Header) 33 Wait() 34 Read() <-chan []byte 35 // WriteJSON writes interface{} encoded as JSON to our connection 36 WriteJSON(v interface{}) error 37 SetTLSConfig(c *tls.Config) 38 } 39 40 type websocketClient struct { 41 conn *websocket.Conn 42 tlsClientConfig *tls.Config 43 write chan func() 44 read chan []byte 45 wg *sync.WaitGroup 46 disconnected chan error 47 connected chan struct{} 48 onConnect func() 49 sync.Mutex 50 } 51 52 // New creates a new Websocket 53 func New() Websocket { 54 return &websocketClient{ 55 write: make(chan func()), 56 read: make(chan []byte, 100), 57 wg: &sync.WaitGroup{}, 58 disconnected: make(chan error), 59 connected: make(chan struct{}), 60 } 61 } 62 63 func (ws *websocketClient) SetTLSConfig(c *tls.Config) { 64 ws.tlsClientConfig = c 65 } 66 67 func (ws *websocketClient) OnConnect(cb func()) { 68 ws.Lock() 69 ws.onConnect = cb 70 ws.Unlock() 71 } 72 func (ws *websocketClient) getOnConnect() func() { 73 ws.Lock() 74 defer ws.Unlock() 75 return ws.onConnect 76 } 77 78 func (ws *websocketClient) ConnectContext(ctx context.Context, addr string, headers http.Header) error { 79 var err error 80 var c *websocket.Conn 81 logrus.Info("websocket: connecting to ", addr) 82 if ws.tlsClientConfig != nil { 83 dialer := &websocket.Dialer{ 84 Proxy: http.ProxyFromEnvironment, 85 HandshakeTimeout: 45 * time.Second, 86 TLSClientConfig: ws.tlsClientConfig, 87 } 88 c, _, err = dialer.DialContext(ctx, addr, headers) 89 } else { 90 c, _, err = websocket.DefaultDialer.DialContext(ctx, addr, headers) 91 } 92 if err != nil { 93 ws.wasDisconnected(err) 94 return err 95 } 96 logrus.Infof("websocket: connected to %s", addr) 97 ws.wasConnected() 98 ws.conn = c 99 ws.readPump() 100 ws.writePump(ctx) <- struct{}{} 101 102 if oncon := ws.getOnConnect(); oncon != nil { 103 oncon() 104 } 105 return nil 106 } 107 108 // ConnectWithRetry tries to connect and blocks until connected. 109 // if disconnected because an error tries to reconnect again every 5th second 110 func (ws *websocketClient) ConnectWithRetry(parentCtx context.Context, addr string, headers http.Header) { 111 112 ctx, cancel := context.WithCancel(parentCtx) 113 ws.wg.Add(1) 114 go func() { 115 defer ws.wg.Done() 116 for { 117 select { 118 case <-parentCtx.Done(): 119 logrus.Info("websocket: stopping reconnect because err: ", parentCtx.Err()) 120 return 121 case err := <-ws.disconnected: 122 cancel() // Stop any write/read pumps so we dont get duplicate write panic 123 logrus.Error("websocket: disconnected") 124 if websocket.IsCloseError(err, websocket.CloseNormalClosure) { 125 logrus.Info("websocket: Skipping reconnect due to CloseNormalClosure") 126 return 127 } 128 logrus.Info("websocket: Reconnect because error: ", err) 129 go func() { 130 time.Sleep(5 * time.Second) 131 ctx, cancel = context.WithCancel(parentCtx) 132 err := ws.ConnectContext(ctx, addr, headers) 133 if err != nil { 134 logrus.Error("websocket: Reconnect failed with error: ", err) 135 } 136 }() 137 } 138 } 139 }() 140 go ws.ConnectContext(ctx, addr, headers) 141 select { 142 case <-parentCtx.Done(): 143 return 144 case <-ws.connected: 145 return 146 } 147 } 148 149 func (ws *websocketClient) Wait() { 150 ws.wg.Wait() 151 } 152 153 func (ws *websocketClient) Read() <-chan []byte { 154 return ws.read 155 } 156 157 // WriteJSON writes interface{} encoded as JSON to our connection 158 func (ws *websocketClient) WriteJSON(v interface{}) error { 159 errCh := make(chan error, 1) 160 select { 161 case ws.write <- func() { 162 err := ws.conn.WriteJSON(v) 163 errCh <- err 164 }: 165 case <-time.After(time.Millisecond * 10): 166 errCh <- fmt.Errorf("websocket: no one listening on write channel") 167 } 168 return <-errCh 169 } 170 171 func (ws *websocketClient) readPump() { 172 ws.wg.Add(1) 173 go func() { 174 defer ws.wg.Done() 175 ws.conn.SetReadDeadline(time.Now().Add(pongWait)) 176 ws.conn.SetPongHandler(func(string) error { ws.conn.SetReadDeadline(time.Now().Add(pongWait)); return nil }) 177 for { 178 _, message, err := ws.conn.ReadMessage() 179 if err != nil { 180 logrus.Error("websocket: readPump error:", err) 181 ws.wasDisconnected(err) 182 return 183 } 184 logrus.Debugf("websocket: readPump got msg: %s", message) 185 select { 186 case ws.read <- message: 187 default: 188 } 189 } 190 }() 191 } 192 193 func (ws *websocketClient) writePump(ctx context.Context) chan struct{} { 194 ready := make(chan struct{}) 195 ws.wg.Add(1) 196 go func() { 197 defer ws.wg.Done() 198 ticker := time.NewTicker(pingPeriod) 199 defer ticker.Stop() 200 for { 201 select { 202 case t := <-ws.write: 203 t() 204 case <-ctx.Done(): 205 logrus.Error("websocket: Stopping writePump because err: ", ctx.Err()) 206 err := ws.conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, "")) 207 if err != nil { 208 logrus.Error("websocket: write close:", err) 209 return 210 } 211 return 212 case <-ticker.C: 213 if err := ws.conn.WriteControl(websocket.PingMessage, []byte{}, time.Now().Add(writeWait)); err != nil { 214 logrus.Error("websocket: ping:", err) 215 } 216 case <-ready: 217 } 218 } 219 }() 220 return ready 221 } 222 223 func (ws *websocketClient) wasDisconnected(err error) { 224 select { 225 case ws.disconnected <- err: 226 default: 227 } 228 } 229 230 func (ws *websocketClient) wasConnected() { 231 select { 232 case ws.connected <- struct{}{}: 233 default: 234 } 235 }