github.com/openziti/transport@v0.1.5/wss/connection.go (about)

     1  /*
     2  	Copyright NetFoundry, Inc.
     3  
     4  	Licensed under the Apache License, Version 2.0 (the "License");
     5  	you may not use this file except in compliance with the License.
     6  	You may obtain a copy of the License at
     7  
     8  	https://www.apache.org/licenses/LICENSE-2.0
     9  
    10  	Unless required by applicable law or agreed to in writing, software
    11  	distributed under the License is distributed on an "AS IS" BASIS,
    12  	WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  	See the License for the specific language governing permissions and
    14  	limitations under the License.
    15  */
    16  
    17  package wss
    18  
    19  import (
    20  	"bytes"
    21  	"crypto/tls"
    22  	"crypto/x509"
    23  	"errors"
    24  	"github.com/gorilla/websocket"
    25  	"github.com/openziti/transport"
    26  	"github.com/sirupsen/logrus"
    27  	"io"
    28  	"net"
    29  	"sync"
    30  	"time"
    31  )
    32  
    33  var (
    34  	errClosing = errors.New(`Closing`)
    35  )
    36  
    37  // safeBuffer adds thread-safety to *bytes.Buffer
    38  type safeBuffer struct {
    39  	buf *bytes.Buffer
    40  	log *logrus.Entry
    41  	sync.Mutex
    42  }
    43  
    44  // Read reads the next len(p) bytes from the buffer or until the buffer is drained.
    45  func (s *safeBuffer) Read(p []byte) (int, error) {
    46  	s.Lock()
    47  	defer s.Unlock()
    48  	return s.buf.Read(p)
    49  }
    50  
    51  // Write appends the contents of p to the buffer.
    52  func (s *safeBuffer) Write(p []byte) (int, error) {
    53  	s.Lock()
    54  	defer s.Unlock()
    55  	return s.buf.Write(p)
    56  }
    57  
    58  // Len returns the number of bytes of the unread portion of the buffer.
    59  func (s *safeBuffer) Len() int {
    60  	s.Lock()
    61  	defer s.Unlock()
    62  	return s.buf.Len()
    63  }
    64  
    65  // Reset resets the buffer to be empty.
    66  func (s *safeBuffer) Reset() {
    67  	s.Lock()
    68  	s.buf.Reset()
    69  	s.Unlock()
    70  }
    71  
    72  // Connection wraps gorilla websocket to provide io.ReadWriteCloser
    73  type Connection struct {
    74  	detail *transport.ConnectionDetail
    75  	cfg    *WSSConfig
    76  	ws     *websocket.Conn
    77  	log    *logrus.Entry
    78  	rxbuf  *safeBuffer
    79  	txbuf  *safeBuffer
    80  	done   chan struct{}
    81  	wmutex sync.Mutex
    82  	rmutex sync.Mutex
    83  }
    84  
    85  // Read implements io.Reader by wrapping websocket messages in a buffer.
    86  func (c *Connection) Read(p []byte) (n int, err error) {
    87  	if c.rxbuf.Len() == 0 {
    88  		var r io.Reader
    89  		c.rxbuf.Reset()
    90  		c.rmutex.Lock()
    91  		defer c.rmutex.Unlock()
    92  		select {
    93  		case <-c.done:
    94  			err = errClosing
    95  		default:
    96  			_, r, err = c.ws.NextReader()
    97  		}
    98  		if err != nil {
    99  			return n, err
   100  		}
   101  		_, err = io.Copy(c.rxbuf, r)
   102  		if err != nil {
   103  			return n, err
   104  		}
   105  	}
   106  
   107  	return c.rxbuf.Read(p)
   108  }
   109  
   110  // Write implements io.Writer and sends binary messages only.
   111  func (c *Connection) Write(p []byte) (n int, err error) {
   112  	return c.write(websocket.BinaryMessage, p)
   113  }
   114  
   115  // write wraps the websocket writer.
   116  func (c *Connection) write(messageType int, p []byte) (n int, err error) {
   117  	var txbufLen int
   118  	c.wmutex.Lock()
   119  	defer c.wmutex.Unlock()
   120  	select {
   121  	case <-c.done:
   122  		err = errClosing
   123  	default:
   124  		c.txbuf.Write(p)
   125  		txbufLen = c.txbuf.Len()
   126  		if txbufLen > 20 { // TEMP HACK:  (until I refactor the JS-SDK to accept the message section and data section in separate salvos)
   127  			err = c.ws.SetWriteDeadline(time.Now().Add(c.cfg.writeTimeout))
   128  			if err == nil {
   129  				m := make([]byte, txbufLen)
   130  				c.txbuf.Read(m)
   131  				err = c.ws.WriteMessage(messageType, m)
   132  			}
   133  		}
   134  	}
   135  	if err == nil {
   136  		n = txbufLen
   137  	}
   138  	return n, err
   139  }
   140  
   141  // Close implements io.Closer and closes the underlying connection.
   142  func (c *Connection) Close() error {
   143  	c.rmutex.Lock()
   144  	c.wmutex.Lock()
   145  	defer func() {
   146  		c.rmutex.Unlock()
   147  		c.wmutex.Unlock()
   148  	}()
   149  	select {
   150  	case <-c.done:
   151  		return errClosing
   152  	default:
   153  		close(c.done)
   154  	}
   155  	return c.ws.Close()
   156  }
   157  
   158  // pinger sends ping messages on an interval for client keep-alive.
   159  func (c *Connection) pinger() {
   160  	ticker := time.NewTicker(c.cfg.pingInterval)
   161  	defer ticker.Stop()
   162  	for {
   163  		select {
   164  		case <-c.done:
   165  			return
   166  		case <-ticker.C:
   167  			c.log.Trace("sending websocket Ping")
   168  			if _, err := c.write(websocket.PingMessage, []byte{}); err != nil {
   169  				_ = c.Close()
   170  			}
   171  		}
   172  	}
   173  }
   174  
   175  // newSafeBuffer instantiates a new safeBuffer
   176  func newSafeBuffer(log *logrus.Entry) *safeBuffer {
   177  	return &safeBuffer{
   178  		buf: bytes.NewBuffer(nil),
   179  		log: log,
   180  	}
   181  }
   182  
   183  func (self *Connection) Detail() *transport.ConnectionDetail {
   184  	return self.detail
   185  }
   186  
   187  func (self *Connection) PeerCertificates() []*x509.Certificate {
   188  	var tlsConn (*tls.Conn) = self.ws.UnderlyingConn().(*tls.Conn)
   189  	return tlsConn.ConnectionState().PeerCertificates
   190  }
   191  
   192  func (self *Connection) Reader() io.Reader {
   193  	return self
   194  }
   195  
   196  func (self *Connection) Writer() io.Writer {
   197  	return self
   198  }
   199  
   200  func (self *Connection) Conn() net.Conn {
   201  	return self.ws.UnderlyingConn() // Obtain the socket underneath the websocket
   202  
   203  }
   204  
   205  func (self *Connection) SetReadTimeout(t time.Duration) error {
   206  	return self.ws.UnderlyingConn().SetReadDeadline(time.Now().Add(t))
   207  }
   208  
   209  func (self *Connection) SetWriteTimeout(t time.Duration) error {
   210  	return self.ws.UnderlyingConn().SetWriteDeadline(time.Now().Add(t))
   211  }
   212  
   213  // ClearReadTimeout clears the read time for all current and future reads
   214  //
   215  func (self *Connection) ClearReadTimeout() error {
   216  	var zero time.Time
   217  	return self.ws.UnderlyingConn().SetReadDeadline(zero)
   218  }
   219  
   220  // ClearWriteTimeout clears the write timeout for all current and future writes
   221  //
   222  func (self *Connection) ClearWriteTimeout() error {
   223  	var zero time.Time
   224  	return self.ws.UnderlyingConn().SetWriteDeadline(zero)
   225  }