github.com/icyphox/x@v0.0.355-0.20220311094250-029bd783e8b8/watcherx/websocket_server.go (about)

     1  package watcherx
     2  
     3  import (
     4  	"context"
     5  	"fmt"
     6  	"net"
     7  	"net/http"
     8  	"net/url"
     9  	"strings"
    10  	"sync"
    11  
    12  	"github.com/gorilla/websocket"
    13  
    14  	"github.com/ory/herodot"
    15  )
    16  
    17  type (
    18  	eventChannelSlice struct {
    19  		sync.Mutex
    20  		cs []EventChannel
    21  	}
    22  	websocketWatcher struct {
    23  		wsWriteLock      sync.Mutex
    24  		wsReadLock       sync.Mutex
    25  		wsClientChannels eventChannelSlice
    26  	}
    27  )
    28  
    29  const (
    30  	messageSendNow     = "send values now"
    31  	messageSendNowDone = "done sending %d values"
    32  )
    33  
    34  func WatchAndServeWS(ctx context.Context, u *url.URL, writer herodot.Writer) (http.HandlerFunc, error) {
    35  	c := make(EventChannel)
    36  	watcher, err := Watch(ctx, u, c)
    37  	if err != nil {
    38  		return nil, err
    39  	}
    40  	w := &websocketWatcher{
    41  		wsClientChannels: eventChannelSlice{},
    42  	}
    43  	go w.broadcaster(ctx, c)
    44  	return w.serveWS(ctx, writer, watcher), nil
    45  }
    46  
    47  func (ww *websocketWatcher) broadcaster(ctx context.Context, c EventChannel) {
    48  	for {
    49  		select {
    50  		case <-ctx.Done():
    51  			return
    52  		case e := <-c:
    53  			ww.wsClientChannels.Lock()
    54  			for _, cc := range ww.wsClientChannels.cs {
    55  				cc <- e
    56  			}
    57  			ww.wsClientChannels.Unlock()
    58  		}
    59  	}
    60  }
    61  
    62  func (ww *websocketWatcher) readWebsocket(ws *websocket.Conn, c chan<- struct{}, watcher Watcher) {
    63  	for {
    64  		// blocking call to ReadMessage that waits for a close message
    65  		ww.wsReadLock.Lock()
    66  		_, msg, err := ws.ReadMessage()
    67  		ww.wsReadLock.Unlock()
    68  
    69  		switch errTyped := err.(type) {
    70  		case nil:
    71  			if string(msg) == messageSendNow {
    72  				done, err := watcher.DispatchNow()
    73  				if err != nil {
    74  					// we cant do much about this error
    75  					ww.wsWriteLock.Lock()
    76  					_ = ws.WriteJSON(&ErrorEvent{
    77  						error:  err,
    78  						source: "",
    79  					})
    80  					ww.wsWriteLock.Unlock()
    81  				}
    82  
    83  				go func() {
    84  					eventsSend := <-done
    85  
    86  					ww.wsWriteLock.Lock()
    87  					defer ww.wsWriteLock.Unlock()
    88  
    89  					// we cant do much about this error
    90  					_ = ws.WriteMessage(websocket.TextMessage, []byte(fmt.Sprintf(messageSendNowDone, eventsSend)))
    91  				}()
    92  			}
    93  		case *websocket.CloseError:
    94  			if errTyped.Code == websocket.CloseNormalClosure {
    95  				close(c)
    96  				return
    97  			}
    98  		case *net.OpError:
    99  			if errTyped.Op == "read" && strings.Contains(errTyped.Err.Error(), "closed") {
   100  				// the context got canceled and therefore the connection closed
   101  				close(c)
   102  				return
   103  			}
   104  		default:
   105  			// some other unexpected error, best we can do is return
   106  			return
   107  		}
   108  	}
   109  }
   110  
   111  func (ww *websocketWatcher) serveWS(ctx context.Context, writer herodot.Writer, watcher Watcher) func(w http.ResponseWriter, r *http.Request) {
   112  	return func(w http.ResponseWriter, r *http.Request) {
   113  		ws, err := (&websocket.Upgrader{
   114  			ReadBufferSize:  256, // the only message we expect is the close message
   115  			WriteBufferSize: 1024,
   116  		}).Upgrade(w, r, nil)
   117  		if err != nil {
   118  			writer.WriteError(w, r, err)
   119  			return
   120  		}
   121  
   122  		// make channel and register it at broadcaster
   123  		c := make(EventChannel)
   124  		ww.wsClientChannels.Lock()
   125  		ww.wsClientChannels.cs = append(ww.wsClientChannels.cs, c)
   126  		ww.wsClientChannels.Unlock()
   127  
   128  		wsClosed := make(chan struct{})
   129  		go ww.readWebsocket(ws, wsClosed, watcher)
   130  
   131  		defer func() {
   132  			// attempt to close the websocket
   133  			// ignore errors as we are closing everything anyway
   134  			ww.wsWriteLock.Lock()
   135  			_ = ws.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, "server context canceled"))
   136  			ww.wsWriteLock.Unlock()
   137  
   138  			_ = ws.Close()
   139  
   140  			ww.wsClientChannels.Lock()
   141  			for i, cc := range ww.wsClientChannels.cs {
   142  				if c == cc {
   143  					ww.wsClientChannels.cs[i] = ww.wsClientChannels.cs[len(ww.wsClientChannels.cs)-1]
   144  					ww.wsClientChannels.cs[len(ww.wsClientChannels.cs)-1] = nil
   145  					ww.wsClientChannels.cs = ww.wsClientChannels.cs[:len(ww.wsClientChannels.cs)-1]
   146  				}
   147  			}
   148  			ww.wsClientChannels.Unlock()
   149  			close(c)
   150  		}()
   151  
   152  		for {
   153  			select {
   154  			case <-ctx.Done():
   155  				return
   156  			case <-wsClosed:
   157  				return
   158  			case e, ok := <-c:
   159  				if !ok {
   160  					return
   161  				}
   162  
   163  				ww.wsWriteLock.Lock()
   164  				err := ws.WriteJSON(e)
   165  				ww.wsWriteLock.Unlock()
   166  
   167  				if err != nil {
   168  					return
   169  				}
   170  			}
   171  		}
   172  	}
   173  }