github.com/stampzilla/stampzilla-go@v2.0.0-rc9+incompatible/pkg/websocket/websocket.go (about)

     1  package websocket
     2  
     3  import (
     4  	"context"
     5  	"crypto/tls"
     6  	"fmt"
     7  	"net/http"
     8  	"sync"
     9  	"time"
    10  
    11  	"github.com/gorilla/websocket"
    12  	"github.com/sirupsen/logrus"
    13  )
    14  
    15  const (
    16  	// Time allowed to write a message to the peer.
    17  	writeWait = 10 * time.Second
    18  
    19  	// Time allowed to read the next pong message from the peer.
    20  	pongWait = 10 * time.Second
    21  
    22  	// Send pings to peer with this period. Must be less than pongWait.
    23  	pingPeriod = (pongWait * 9) / 10
    24  
    25  	reconnectWait = 2 * time.Second
    26  )
    27  
    28  // Websocket implements a websocket client
    29  type Websocket interface {
    30  	OnConnect(cb func())
    31  	ConnectContext(ctx context.Context, addr string, headers http.Header) error
    32  	ConnectWithRetry(parentCtx context.Context, addr string, headers http.Header)
    33  	Wait()
    34  	Read() <-chan []byte
    35  	// WriteJSON writes interface{} encoded as JSON to our connection
    36  	WriteJSON(v interface{}) error
    37  	SetTLSConfig(c *tls.Config)
    38  }
    39  
    40  type websocketClient struct {
    41  	conn            *websocket.Conn
    42  	tlsClientConfig *tls.Config
    43  	write           chan func()
    44  	read            chan []byte
    45  	wg              *sync.WaitGroup
    46  	disconnected    chan error
    47  	connected       chan struct{}
    48  	onConnect       func()
    49  	sync.Mutex
    50  }
    51  
    52  // New creates a new Websocket
    53  func New() Websocket {
    54  	return &websocketClient{
    55  		write:        make(chan func()),
    56  		read:         make(chan []byte, 100),
    57  		wg:           &sync.WaitGroup{},
    58  		disconnected: make(chan error),
    59  		connected:    make(chan struct{}),
    60  	}
    61  }
    62  
    63  func (ws *websocketClient) SetTLSConfig(c *tls.Config) {
    64  	ws.tlsClientConfig = c
    65  }
    66  
    67  func (ws *websocketClient) OnConnect(cb func()) {
    68  	ws.Lock()
    69  	ws.onConnect = cb
    70  	ws.Unlock()
    71  }
    72  func (ws *websocketClient) getOnConnect() func() {
    73  	ws.Lock()
    74  	defer ws.Unlock()
    75  	return ws.onConnect
    76  }
    77  
    78  func (ws *websocketClient) ConnectContext(ctx context.Context, addr string, headers http.Header) error {
    79  	var err error
    80  	var c *websocket.Conn
    81  	logrus.Info("websocket: connecting to ", addr)
    82  	if ws.tlsClientConfig != nil {
    83  		dialer := &websocket.Dialer{
    84  			Proxy:            http.ProxyFromEnvironment,
    85  			HandshakeTimeout: 45 * time.Second,
    86  			TLSClientConfig:  ws.tlsClientConfig,
    87  		}
    88  		c, _, err = dialer.DialContext(ctx, addr, headers)
    89  	} else {
    90  		c, _, err = websocket.DefaultDialer.DialContext(ctx, addr, headers)
    91  	}
    92  	if err != nil {
    93  		ws.wasDisconnected(err)
    94  		return err
    95  	}
    96  	logrus.Infof("websocket: connected to %s", addr)
    97  	ws.wasConnected()
    98  	ws.conn = c
    99  	ws.readPump()
   100  	ws.writePump(ctx) <- struct{}{}
   101  
   102  	if oncon := ws.getOnConnect(); oncon != nil {
   103  		oncon()
   104  	}
   105  	return nil
   106  }
   107  
   108  // ConnectWithRetry tries to connect and blocks until connected.
   109  // if disconnected because an error tries to reconnect again every 5th second
   110  func (ws *websocketClient) ConnectWithRetry(parentCtx context.Context, addr string, headers http.Header) {
   111  
   112  	ctx, cancel := context.WithCancel(parentCtx)
   113  	ws.wg.Add(1)
   114  	go func() {
   115  		defer ws.wg.Done()
   116  		for {
   117  			select {
   118  			case <-parentCtx.Done():
   119  				logrus.Info("websocket: stopping reconnect because err: ", parentCtx.Err())
   120  				return
   121  			case err := <-ws.disconnected:
   122  				cancel() // Stop any write/read pumps so we dont get duplicate write panic
   123  				logrus.Error("websocket: disconnected")
   124  				if websocket.IsCloseError(err, websocket.CloseNormalClosure) {
   125  					logrus.Info("websocket: Skipping reconnect due to CloseNormalClosure")
   126  					return
   127  				}
   128  				logrus.Info("websocket: Reconnect because error: ", err)
   129  				go func() {
   130  					time.Sleep(5 * time.Second)
   131  					ctx, cancel = context.WithCancel(parentCtx)
   132  					err := ws.ConnectContext(ctx, addr, headers)
   133  					if err != nil {
   134  						logrus.Error("websocket: Reconnect failed with error: ", err)
   135  					}
   136  				}()
   137  			}
   138  		}
   139  	}()
   140  	go ws.ConnectContext(ctx, addr, headers)
   141  	select {
   142  	case <-parentCtx.Done():
   143  		return
   144  	case <-ws.connected:
   145  		return
   146  	}
   147  }
   148  
   149  func (ws *websocketClient) Wait() {
   150  	ws.wg.Wait()
   151  }
   152  
   153  func (ws *websocketClient) Read() <-chan []byte {
   154  	return ws.read
   155  }
   156  
   157  // WriteJSON writes interface{} encoded as JSON to our connection
   158  func (ws *websocketClient) WriteJSON(v interface{}) error {
   159  	errCh := make(chan error, 1)
   160  	select {
   161  	case ws.write <- func() {
   162  		err := ws.conn.WriteJSON(v)
   163  		errCh <- err
   164  	}:
   165  	case <-time.After(time.Millisecond * 10):
   166  		errCh <- fmt.Errorf("websocket: no one listening on write channel")
   167  	}
   168  	return <-errCh
   169  }
   170  
   171  func (ws *websocketClient) readPump() {
   172  	ws.wg.Add(1)
   173  	go func() {
   174  		defer ws.wg.Done()
   175  		ws.conn.SetReadDeadline(time.Now().Add(pongWait))
   176  		ws.conn.SetPongHandler(func(string) error { ws.conn.SetReadDeadline(time.Now().Add(pongWait)); return nil })
   177  		for {
   178  			_, message, err := ws.conn.ReadMessage()
   179  			if err != nil {
   180  				logrus.Error("websocket: readPump error:", err)
   181  				ws.wasDisconnected(err)
   182  				return
   183  			}
   184  			logrus.Debugf("websocket: readPump got msg: %s", message)
   185  			select {
   186  			case ws.read <- message:
   187  			default:
   188  			}
   189  		}
   190  	}()
   191  }
   192  
   193  func (ws *websocketClient) writePump(ctx context.Context) chan struct{} {
   194  	ready := make(chan struct{})
   195  	ws.wg.Add(1)
   196  	go func() {
   197  		defer ws.wg.Done()
   198  		ticker := time.NewTicker(pingPeriod)
   199  		defer ticker.Stop()
   200  		for {
   201  			select {
   202  			case t := <-ws.write:
   203  				t()
   204  			case <-ctx.Done():
   205  				logrus.Error("websocket: Stopping writePump because err: ", ctx.Err())
   206  				err := ws.conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""))
   207  				if err != nil {
   208  					logrus.Error("websocket: write close:", err)
   209  					return
   210  				}
   211  				return
   212  			case <-ticker.C:
   213  				if err := ws.conn.WriteControl(websocket.PingMessage, []byte{}, time.Now().Add(writeWait)); err != nil {
   214  					logrus.Error("websocket: ping:", err)
   215  				}
   216  			case <-ready:
   217  			}
   218  		}
   219  	}()
   220  	return ready
   221  }
   222  
   223  func (ws *websocketClient) wasDisconnected(err error) {
   224  	select {
   225  	case ws.disconnected <- err:
   226  	default:
   227  	}
   228  }
   229  
   230  func (ws *websocketClient) wasConnected() {
   231  	select {
   232  	case ws.connected <- struct{}{}:
   233  	default:
   234  	}
   235  }