github.com/cantara/gober@v0.18.8/websocket/server.go (about)

     1  package websocket
     2  
     3  import (
     4  	"context"
     5  	"encoding/binary"
     6  	"errors"
     7  	log "github.com/cantara/bragi/sbragi"
     8  	"github.com/gin-gonic/gin"
     9  	"github.com/gobwas/ws"
    10  	jsoniter "github.com/json-iterator/go"
    11  	"io"
    12  	"net"
    13  	"reflect"
    14  	"sync"
    15  	"time"
    16  )
    17  
    18  var json = jsoniter.ConfigDefault
    19  
    20  var BufferSize = 100
    21  
    22  func Serve[T any](r *gin.RouterGroup, path string, acceptFunc func(c *gin.Context) bool, wsfunc WSHandler[T]) {
    23  	r.GET(path, func(c *gin.Context) {
    24  		if acceptFunc != nil && !acceptFunc(c) {
    25  			return //Could be smart to have some check of weather or not the statuscode code has been set.
    26  		}
    27  		conn, _, _, err := ws.UpgradeHTTP(c.Request, c.Writer)
    28  		if err != nil {
    29  			log.WithError(err).Fatal("while accepting websocket", "request", c.Request)
    30  		}
    31  		ctx, cancel := context.WithCancel(c.Request.Context())
    32  		defer cancel()
    33  		clientClosed := false
    34  		reader := make(chan T, BufferSize)
    35  		writer := make(chan Write[T], BufferSize)
    36  		tick := time.Second * 50
    37  		sucker := webSucker[T]{
    38  			pingTimout: tick,
    39  			pingTicker: time.NewTicker(tick),
    40  			writeLock:  sync.Mutex{},
    41  			conn:       conn,
    42  		}
    43  		/*
    44  			connWriter := make(chan []byte, 1)
    45  			go func() {
    46  				defer func() {
    47  					if !clientClosed {
    48  						err = ws.WriteFrame(conn, ws.NewCloseFrame(ws.NewCloseFrameBody(ws.StatusNormalClosure, "writer closed")))
    49  						log.WithError(err).Info("writing client websocket close frame")
    50  					}
    51  					log.WithError(conn.Close()).Info("closing client net conn")
    52  				}()
    53  				tickD := time.Second * 50
    54  				tkr := time.NewTicker(tickD)
    55  				defer tkr.Stop()
    56  				for {
    57  					select {
    58  					case write, ok := <-connWriter:
    59  						if !ok {
    60  							return
    61  						}
    62  						n, err := conn.Write(write)
    63  						total := n
    64  						for err == nil && total < len(write) {
    65  							n, err = conn.Write(write[total:])
    66  							total += n
    67  						}
    68  						if err != nil {
    69  							log.WithError(err).Error("while writing to websocket", "path", path, "type", reflect.TypeOf(write).String(), "data", write) // This could end up logging person sensitive data.
    70  							return
    71  						}
    72  						tkr.Reset(tickD)
    73  					case <-tkr.C:
    74  						connWriter <- ws.CompiledPing
    75  						log.WithError(err).Info("wrote ping from server")
    76  					}
    77  				}
    78  			}()
    79  		*/
    80  		go func() {
    81  			defer func() {
    82  				if !clientClosed {
    83  					err = ws.WriteFrame(conn, ws.NewCloseFrame(ws.NewCloseFrameBody(ws.StatusNormalClosure, "writer closed")))
    84  					log.WithError(err).Info("writing client websocket close frame")
    85  				}
    86  				log.WithError(conn.Close()).Info("closing client net conn")
    87  			}()
    88  			for {
    89  				select {
    90  				case <-ctx.Done():
    91  					return
    92  				case write, ok := <-writer:
    93  					if !ok {
    94  						return
    95  					}
    96  					//err := WriteWebsocket[T](connWriter, write)
    97  					err := sucker.Write(write)
    98  					if err != nil {
    99  						if errors.Is(err, net.ErrClosed) {
   100  							clientClosed = true
   101  							cancel()
   102  							return
   103  						}
   104  						log.WithError(err).Error("while writing to websocket", "path", path, "request", c.Request, "type", reflect.TypeOf(write).String()) // This could end up logging person sensitive data.
   105  						return
   106  					}
   107  				case <-sucker.pingTicker.C:
   108  					err = sucker.Ping()
   109  					if err != nil {
   110  						if errors.Is(err, ErrNoErrorHandled) {
   111  							//log.Debug("no ping already waiting for pong from client")
   112  							continue
   113  						}
   114  						if errors.Is(err, net.ErrClosed) {
   115  							clientClosed = true
   116  							cancel()
   117  							return
   118  						}
   119  					}
   120  					log.WithError(err).Debug("wrote ping from server")
   121  				}
   122  			}
   123  		}()
   124  		go func() {
   125  			defer close(reader)
   126  			var read T
   127  			var err error
   128  			for {
   129  				select {
   130  				case <-ctx.Done():
   131  					return
   132  				default:
   133  					//read, err = ReadWebsocket[T](conn, connWriter)
   134  					read, err = sucker.Read()
   135  					if err != nil {
   136  						if errors.Is(err, ErrNoErrorHandled) {
   137  							continue
   138  						}
   139  						if errors.Is(err, ErrNotImplemented) {
   140  							log.WithError(err).Warning("continuing after packet is discarded")
   141  							continue
   142  						}
   143  						if errors.Is(err, net.ErrClosed) {
   144  							clientClosed = true
   145  							cancel()
   146  							return
   147  						}
   148  						if errors.Is(err, io.EOF) {
   149  							clientClosed = true
   150  							cancel()
   151  							log.Info("websocket is closed, server closing...") //This works, but gave a wrong impression, changed slightly
   152  							return
   153  						}
   154  						log.WithError(err).Error("while server reading from websocket", "path", path, "request", c.Request, "type", reflect.TypeOf(read).String()) // This could end up logging person sensitive data.
   155  						return
   156  					}
   157  					reader <- read
   158  				}
   159  			}
   160  		}()
   161  		wsfunc(reader, writer, c.Params, ctx)
   162  	})
   163  }
   164  
   165  type webSucker[T any] struct {
   166  	pingTimout time.Duration
   167  	pingTicker *time.Ticker
   168  	pingLock   sync.Mutex
   169  	writeLock  sync.Mutex
   170  	conn       net.Conn
   171  }
   172  
   173  func (sucker *webSucker[T]) Ping() (err error) {
   174  	if !sucker.pingLock.TryLock() {
   175  		return ErrNoErrorHandled
   176  	}
   177  	return sucker.WriteConn(ws.CompiledPing)
   178  }
   179  
   180  func (sucker *webSucker[T]) WriteConn(write []byte) (err error) {
   181  	defer sucker.pingTicker.Reset(sucker.pingTimout)
   182  	sucker.writeLock.Lock()
   183  	defer sucker.writeLock.Unlock()
   184  	var n int
   185  	n, err = sucker.conn.Write(write)
   186  	total := n
   187  	for err == nil && total < len(write) {
   188  		n, err = sucker.conn.Write(write[total:])
   189  		total += n
   190  	}
   191  	return
   192  }
   193  
   194  func (sucker *webSucker[T]) Write(write Write[T]) (err error) {
   195  	defer func() {
   196  		if write.Err != nil {
   197  			close(write.Err)
   198  		}
   199  	}()
   200  	payload, err := json.Marshal(write.Data)
   201  	if err != nil {
   202  		if write.Err != nil {
   203  			write.Err <- err
   204  		}
   205  		return err
   206  	}
   207  	var frame []byte
   208  	frame, err = ws.CompileFrame(ws.NewTextFrame(payload))
   209  	if err != nil {
   210  		return
   211  	}
   212  	err = sucker.WriteConn(frame)
   213  	/*
   214  		err = sucker.WriteConn(append(websocketHeaderBytes(ws.Header{
   215  			Fin:    true,
   216  			Rsv:    0,
   217  			OpCode: ws.OpText,
   218  			Masked: false,
   219  			Mask:   [4]byte{},
   220  			Length: int64(len(payload)),
   221  		}), payload...))
   222  	*/
   223  	if err != nil {
   224  		if write.Err != nil {
   225  			write.Err <- err
   226  		}
   227  		return err
   228  	}
   229  	return
   230  }
   231  
   232  func (sucker *webSucker[T]) Read() (out T, err error) {
   233  	//defer sucker.pingTicker.Reset(sucker.pingTimout)
   234  	header, err := ws.ReadHeader(sucker.conn)
   235  	if err != nil {
   236  		if errors.Is(err, net.ErrClosed) {
   237  			err = io.EOF
   238  			return
   239  		}
   240  		return
   241  	}
   242  	log.Trace("packet received", "type", packetTypeToString(header.OpCode))
   243  	sucker.pingTicker.Stop()
   244  	defer sucker.pingTicker.Reset(sucker.pingTimout)
   245  	if header.OpCode == ws.OpClose {
   246  		err = io.EOF
   247  		return
   248  	}
   249  	if header.OpCode == ws.OpPing {
   250  		log.Debug("ping received, ponging...")
   251  		payload := make([]byte, header.Length)
   252  		_, err = io.ReadFull(sucker.conn, payload)
   253  		if err != nil {
   254  			return
   255  		}
   256  		/*
   257  			var frame []byte
   258  			frame, err = ws.CompileFrame(ws.NewPongFrame(payload))
   259  			if err != nil {
   260  				return
   261  			}
   262  		*/
   263  		err = sucker.WriteConn(ws.CompiledPong)
   264  		/*
   265  			err = sucker.WriteConn(append(websocketHeaderBytes(ws.Header{
   266  				Fin:    true,
   267  				Rsv:    0,
   268  				OpCode: ws.OpPong,
   269  				Masked: false,
   270  				Mask:   [4]byte{},
   271  				Length: header.Length,
   272  			}), payload...))
   273  		*/
   274  		log.WithError(err).Trace("while ponging")
   275  		err = ErrNoErrorHandled
   276  		return
   277  	}
   278  
   279  	/*
   280  		1. Should verify against outstanding ping TODO
   281  		2. Should ignore if no outstanding ping
   282  	*/
   283  	if header.OpCode == ws.OpPong {
   284  		log.Debug("pong received")
   285  		sucker.pingLock.Unlock()
   286  		if header.Length == 0 {
   287  			err = ErrNoErrorHandled
   288  			return
   289  		}
   290  		_, err = io.CopyN(io.Discard, sucker.conn, header.Length)
   291  		err = ErrNoErrorHandled
   292  		return
   293  	}
   294  
   295  	if header.OpCode == ws.OpContinuation {
   296  		_, err = io.CopyN(io.Discard, sucker.conn, header.Length)
   297  		err = ErrNotImplemented
   298  		return
   299  	}
   300  
   301  	if header.OpCode == ws.OpBinary {
   302  		_, err = io.CopyN(io.Discard, sucker.conn, header.Length)
   303  		err = ErrNotImplemented
   304  		return
   305  	}
   306  
   307  	payload := make([]byte, header.Length)
   308  	_, err = io.ReadFull(sucker.conn, payload)
   309  	if err != nil {
   310  		if errors.Is(err, net.ErrClosed) {
   311  			err = io.EOF
   312  			return
   313  		}
   314  		return
   315  	}
   316  	if header.Masked {
   317  		ws.Cipher(payload, header.Mask, 0)
   318  	}
   319  	err = json.Unmarshal(payload, &out)
   320  	return
   321  }
   322  
   323  /*
   324  func ReadWebsocket[T any](conn io.Reader, writer chan<- []byte) (out T, err error) {
   325  	header, err := ws.ReadHeader(conn)
   326  	if err != nil {
   327  		if errors.Is(err, net.ErrClosed) {
   328  			err = io.EOF
   329  			return
   330  		}
   331  		return
   332  	}
   333  	if header.OpCode == ws.OpClose {
   334  		err = io.EOF
   335  		return
   336  	}
   337  	if header.OpCode == ws.OpPing {
   338  		log.Info("ping received, ponging...")
   339  		//Could also use ws.NewPingFrame(body)
   340  		payload := make([]byte, header.Length)
   341  		_, err = io.ReadFull(conn, payload)
   342  		if err != nil {
   343  			return
   344  		}
   345  
   346  		writer <- append(websocketHeaderBytes(ws.Header{ //This can write to a closed channel
   347  			Fin:    true,
   348  			Rsv:    0,
   349  			OpCode: ws.OpPong,
   350  			Masked: false,
   351  			Mask:   [4]byte{},
   352  			Length: header.Length,
   353  		}), payload...)
   354  		/*
   355  			err = ws.WriteHeader(conn, ws.Header{
   356  				Fin:    true,
   357  				Rsv:    0,
   358  				OpCode: ws.OpPong,
   359  				Masked: false,
   360  				Mask:   [4]byte{},
   361  				Length: header.Length,
   362  			})
   363  			if err != nil {
   364  				return
   365  			}
   366  			_, err = io.CopyN(conn, conn, header.Length)
   367  */ /*
   368  		err = ErrNoErrorHandled
   369  		return
   370  	}
   371  
   372  	/*
   373  		1. Should verify against outstanding ping TODO
   374  		2. Should ignore if no outstanding ping
   375  */ /*
   376  	if header.OpCode == ws.OpPong {
   377  		log.Info("pong received")
   378  		if header.Length == 0 {
   379  			err = ErrNoErrorHandled
   380  			return
   381  		}
   382  		_, err = io.CopyN(io.Discard, conn, header.Length)
   383  		err = ErrNoErrorHandled
   384  		return
   385  	}
   386  
   387  	if header.OpCode == ws.OpContinuation {
   388  		_, err = io.CopyN(io.Discard, conn, header.Length)
   389  		err = ErrNotImplemented
   390  		return
   391  	}
   392  
   393  	if header.OpCode == ws.OpBinary {
   394  		_, err = io.CopyN(io.Discard, conn, header.Length)
   395  		err = ErrNotImplemented
   396  		return
   397  	}
   398  
   399  	payload := make([]byte, header.Length)
   400  	_, err = io.ReadFull(conn, payload) //Could be an idea to change this to ReadAll to not have EOF errors. Or silence them ourselves
   401  	/*
   402  		total, err := conn.Read(payload)
   403  		var n int
   404  		for err == nil && total < int(header.Length) {
   405  			n, err = conn.Read(payload[total:])
   406  			total += n
   407  		}
   408  */ /*
   409  	if err != nil {
   410  		if errors.Is(err, net.ErrClosed) {
   411  			err = io.EOF
   412  			return
   413  		}
   414  		return
   415  	}
   416  	if header.Masked {
   417  		ws.Cipher(payload, header.Mask, 0)
   418  	}
   419  	err = json.Unmarshal(payload, &out)
   420  	return
   421  }
   422  
   423  
   424  func WriteWebsocket[T any](writer chan<- []byte, write Write[T]) error {
   425  	defer func() {
   426  		if write.Err != nil {
   427  			close(write.Err)
   428  		}
   429  	}()
   430  	payload, err := json.Marshal(write.Data)
   431  	if err != nil {
   432  		if write.Err != nil {
   433  			write.Err <- err
   434  		}
   435  		return err
   436  	}
   437  	writer <- append(websocketHeaderBytes(ws.Header{
   438  		Fin:    true,
   439  		Rsv:    0,
   440  		OpCode: ws.OpText,
   441  		Masked: false,
   442  		Mask:   [4]byte{},
   443  		Length: int64(len(payload)),
   444  	}), payload...)
   445  	/*
   446  		err = ws.WriteFrame(conn, ws.Frame{
   447  			Header: ws.Header{
   448  				Fin:    true,
   449  				Rsv:    0,
   450  				OpCode: ws.OpText,
   451  				Masked: false,
   452  				Mask:   [4]byte{},
   453  				Length: int64(len(payload)),
   454  			},
   455  			Payload: payload,
   456  		})
   457  		//_, err = conn.Write(payload)
   458  		if err != nil {
   459  			if write.Err != nil {
   460  				write.Err <- err
   461  			}
   462  			return err
   463  		}
   464  */ /*
   465  	return nil
   466  }
   467  */
   468  
   469  func websocketHeaderBytes(h ws.Header) []byte {
   470  	bts := make([]byte, ws.MaxHeaderSize)
   471  
   472  	if h.Fin {
   473  		bts[0] |= bit0
   474  	}
   475  	bts[0] |= h.Rsv << 4
   476  	bts[0] |= byte(h.OpCode)
   477  
   478  	var n int
   479  	switch {
   480  	case h.Length <= len7:
   481  		bts[1] = byte(h.Length)
   482  		n = 2
   483  
   484  	case h.Length <= len16:
   485  		bts[1] = 126
   486  		binary.BigEndian.PutUint16(bts[2:4], uint16(h.Length))
   487  		n = 4
   488  
   489  	case h.Length <= len64:
   490  		bts[1] = 127
   491  		binary.BigEndian.PutUint64(bts[2:10], uint64(h.Length))
   492  		n = 10
   493  
   494  	default:
   495  		log.WithError(ws.ErrHeaderLengthUnexpected).Fatal("while creating websocket header bytes")
   496  	}
   497  
   498  	if h.Masked {
   499  		bts[1] |= bit0
   500  		n += copy(bts[n:], h.Mask[:])
   501  	}
   502  	return bts[:n]
   503  }
   504  
   505  type WSHandler[T any] func(<-chan T, chan<- Write[T], gin.Params, context.Context)
   506  
   507  var ErrNotImplemented = errors.New("operation not implemented")
   508  var ErrNoErrorHandled = errors.New("handled")
   509  
   510  const (
   511  	bit0 = 0x80
   512  
   513  	len7  = int64(125)
   514  	len16 = int64(^(uint16(0)))
   515  	len64 = int64(^(uint64(0)) >> 1)
   516  )
   517  
   518  func packetTypeToString(code ws.OpCode) string {
   519  	switch code {
   520  	case ws.OpText:
   521  		return "text"
   522  	case ws.OpBinary:
   523  		return "binary"
   524  	case ws.OpClose:
   525  		return "close"
   526  	case ws.OpPing:
   527  		return "ping"
   528  	case ws.OpPong:
   529  		return "pong"
   530  	case ws.OpContinuation:
   531  		return "continuation"
   532  	}
   533  	return ""
   534  }