github.com/openziti/transport@v0.1.5/ws/connection.go (about)

     1  /*
     2  	Copyright NetFoundry, Inc.
     3  
     4  	Licensed under the Apache License, Version 2.0 (the "License");
     5  	you may not use this file except in compliance with the License.
     6  	You may obtain a copy of the License at
     7  
     8  	https://www.apache.org/licenses/LICENSE-2.0
     9  
    10  	Unless required by applicable law or agreed to in writing, software
    11  	distributed under the License is distributed on an "AS IS" BASIS,
    12  	WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  	See the License for the specific language governing permissions and
    14  	limitations under the License.
    15  */
    16  
    17  package ws
    18  
    19  import (
    20  	"bytes"
    21  	"crypto/tls"
    22  	"crypto/x509"
    23  	"errors"
    24  	"github.com/gorilla/websocket"
    25  	"github.com/openziti/transport"
    26  	"github.com/sirupsen/logrus"
    27  	"io"
    28  	"io/ioutil"
    29  	"net"
    30  	"sync"
    31  	"sync/atomic"
    32  	"time"
    33  	// _ "unsafe"	// Using go:linkname requires us to import unsafe
    34  )
    35  
    36  /**
    37   *	For the moment, we do not need to exploit the go:linkname mechanism(s) in order to
    38   *	manipulate portions of the Go runtime, but we leave this code here, commented out,
    39   *	in case we need to revisit.
    40  
    41  
    42  // A cipherSuite is a specific combination of key agreement, cipher and MAC function.
    43  type cipherSuite struct {
    44  	id uint16
    45  	// the lengths, in bytes, of the key material needed for each component.
    46  	keyLen int
    47  	macLen int
    48  	ivLen  int
    49  	ka     func(version uint16)
    50  	// flags is a bitmask of the suite* values, above.
    51  	flags  int
    52  	cipher func(key, iv []byte, isRead bool) interface{}
    53  	mac    func(version uint16, macKey []byte)
    54  	aead   func(key, fixedNonce []byte)
    55  }
    56  
    57  //go:linkname cipherSuites crypto/tls.cipherSuites
    58  var cipherSuites []*cipherSuite
    59  
    60  const (
    61  	// suiteECDHE indicates that the cipher suite involves elliptic curve
    62  	// Diffie-Hellman. This means that it should only be selected when the
    63  	// client indicates that it supports ECC with a curve and point format
    64  	// that we're happy with.
    65  	suiteECDHE = 1 << iota
    66  	// suiteECSign indicates that the cipher suite involves an ECDSA or
    67  	// EdDSA signature and therefore may only be selected when the server's
    68  	// certificate is ECDSA or EdDSA. If this is not set then the cipher suite
    69  	// is RSA based.
    70  	suiteECSign
    71  	// suiteTLS12 indicates that the cipher suite should only be advertised
    72  	// and accepted when using TLS 1.2.
    73  	suiteTLS12
    74  	// suiteSHA384 indicates that the cipher suite uses SHA384 as the
    75  	// handshake hash.
    76  	suiteSHA384
    77  	// suiteDefaultOff indicates that this cipher suite is not included by
    78  	// default.
    79  	suiteDefaultOff
    80  )
    81  
    82  */
    83  
    84  // TLS 1.0 - 1.2 cipher suites supported by ziti-sdk-js
    85  const (
    86  	TLS_RSA_WITH_AES_128_CBC_SHA uint16 = 0x002f
    87  	TLS_RSA_WITH_AES_256_CBC_SHA uint16 = 0x0035
    88  )
    89  
    90  var (
    91  	errClosing = errors.New(`Closing`)
    92  )
    93  
    94  // safeBuffer adds thread-safety to *bytes.Buffer
    95  type safeBuffer struct {
    96  	buf *bytes.Buffer
    97  	log *logrus.Entry
    98  	sync.Mutex
    99  }
   100  
   101  // Read reads the next len(p) bytes from the buffer or until the buffer is drained.
   102  func (s *safeBuffer) Read(p []byte) (int, error) {
   103  	s.Lock()
   104  	defer s.Unlock()
   105  	return s.buf.Read(p)
   106  }
   107  
   108  // Write appends the contents of p to the buffer.
   109  func (s *safeBuffer) Write(p []byte) (int, error) {
   110  	s.Lock()
   111  	defer s.Unlock()
   112  	return s.buf.Write(p)
   113  }
   114  
   115  // Len returns the number of bytes of the unread portion of the buffer.
   116  func (s *safeBuffer) Len() int {
   117  	s.Lock()
   118  	defer s.Unlock()
   119  	return s.buf.Len()
   120  }
   121  
   122  // Reset resets the buffer to be empty.
   123  func (s *safeBuffer) Reset() {
   124  	s.Lock()
   125  	s.buf.Reset()
   126  	s.Unlock()
   127  }
   128  
   129  // Connection wraps gorilla websocket to provide io.ReadWriteCloser
   130  type Connection struct {
   131  	detail                   *transport.ConnectionDetail
   132  	cfg                      *WSConfig
   133  	ws                       *websocket.Conn
   134  	tlsConn                  *tls.Conn
   135  	tlsConnHandshakeComplete bool
   136  	log                      *logrus.Entry
   137  	rxbuf                    *safeBuffer
   138  	txbuf                    *safeBuffer
   139  	tlsrxbuf                 *safeBuffer
   140  	tlstxbuf                 *safeBuffer
   141  	done                     chan struct{}
   142  	wmutex                   sync.Mutex
   143  	rmutex                   sync.Mutex
   144  	tlswmutex                sync.Mutex
   145  	tlsrmutex                sync.Mutex
   146  	incoming                 chan transport.Connection
   147  	readCallDepth            int32
   148  	writeCallDepth           int32
   149  }
   150  
   151  // Read implements io.Reader by wrapping websocket messages in a buffer.
   152  func (c *Connection) Read(p []byte) (n int, err error) {
   153  	currentDepth := atomic.AddInt32(&c.readCallDepth, 1)
   154  	c.log.Tracef("Read() start currentDepth[%d]", currentDepth)
   155  
   156  	if c.rxbuf.Len() == 0 {
   157  		var r io.Reader
   158  		c.rxbuf.Reset()
   159  		if c.tlsConnHandshakeComplete {
   160  			if currentDepth == 1 {
   161  				c.tlsrmutex.Lock()
   162  				defer c.tlsrmutex.Unlock()
   163  			} else if currentDepth == 2 {
   164  				c.rmutex.Lock()
   165  				defer c.rmutex.Unlock()
   166  			}
   167  		} else {
   168  			c.rmutex.Lock()
   169  			defer c.rmutex.Unlock()
   170  		}
   171  		select {
   172  		case <-c.done:
   173  			err = errClosing
   174  		default:
   175  			if c.tlsConnHandshakeComplete && currentDepth == 1 {
   176  				n, err = c.tlsConn.Read(p)
   177  				atomic.SwapInt32(&c.readCallDepth, (c.readCallDepth - 1))
   178  				c.log.Tracef("Read() end currentDepth[%d]", currentDepth)
   179  				return n, err
   180  			} else {
   181  				_, r, err = c.ws.NextReader()
   182  			}
   183  		}
   184  		if err != nil {
   185  			return n, err
   186  		}
   187  		_, err = io.Copy(c.rxbuf, r)
   188  		if err != nil {
   189  			return n, err
   190  		}
   191  	}
   192  
   193  	atomic.SwapInt32(&c.readCallDepth, (c.readCallDepth - 1))
   194  
   195  	c.log.Tracef("Read() end currentDepth[%d]", currentDepth)
   196  
   197  	return c.rxbuf.Read(p)
   198  }
   199  
   200  // Write implements io.Writer and sends binary messages only.
   201  func (c *Connection) Write(p []byte) (n int, err error) {
   202  	return c.write(websocket.BinaryMessage, p)
   203  }
   204  
   205  // write wraps the websocket writer.
   206  func (c *Connection) write(messageType int, p []byte) (n int, err error) {
   207  	var txbufLen int
   208  	currentDepth := atomic.AddInt32(&c.writeCallDepth, 1)
   209  	c.log.Tracef("Write() start currentDepth[%d] len[%d]", c.writeCallDepth, len(p))
   210  
   211  	if c.tlsConnHandshakeComplete {
   212  		if currentDepth == 1 {
   213  			c.tlswmutex.Lock()
   214  			defer c.tlswmutex.Unlock()
   215  		} else if currentDepth == 2 {
   216  			c.wmutex.Lock()
   217  			defer c.wmutex.Unlock()
   218  		}
   219  	} else {
   220  		c.wmutex.Lock()
   221  		defer c.wmutex.Unlock()
   222  	}
   223  
   224  	select {
   225  	case <-c.done:
   226  		err = errClosing
   227  	default:
   228  		var txbufLen int
   229  
   230  		if !c.tlsConnHandshakeComplete {
   231  			c.tlstxbuf.Write(p)
   232  			txbufLen = c.tlstxbuf.Len()
   233  			c.log.Tracef("Write() doing TLS handshake (buffering); currentDepth[%d] txbufLen[%d] data[%o]", c.writeCallDepth, txbufLen, p)
   234  		} else if currentDepth == 1 { // if at TLS level (1st level)
   235  			c.tlstxbuf.Write(p)
   236  			txbufLen = c.tlstxbuf.Len()
   237  			c.log.Tracef("Write() doing TLS write; currentDepth[%d] txbufLen[%d] data[%o]", c.writeCallDepth, txbufLen, p)
   238  		} else { // if at websocket level (2nd level)
   239  			c.txbuf.Write(p)
   240  			txbufLen = c.txbuf.Len()
   241  			c.log.Tracef("Write() doing raw write; currentDepth[%d] txbufLen[%d] data[%o]", c.writeCallDepth, txbufLen, p)
   242  		}
   243  
   244  		err = c.ws.SetWriteDeadline(time.Now().Add(c.cfg.writeTimeout))
   245  		if err == nil {
   246  			if !c.tlsConnHandshakeComplete {
   247  				m := make([]byte, txbufLen)
   248  				c.tlstxbuf.Read(m)
   249  				c.log.Tracef("Write() doing TLS handshake (to websocket); currentDepth[%d] txbufLen[%d] data[%o]", c.writeCallDepth, txbufLen, m)
   250  				err = c.ws.WriteMessage(messageType, m)
   251  			} else if currentDepth == 1 {
   252  				m := make([]byte, txbufLen)
   253  				c.tlstxbuf.Read(m)
   254  				c.log.Tracef("Write() doing TLS write (to conn); currentDepth[%d] txbufLen[%d] data[%o]", c.writeCallDepth, txbufLen, m)
   255  				n, err = c.tlsConn.Write(m)
   256  				atomic.SwapInt32(&c.writeCallDepth, (c.writeCallDepth - 1))
   257  				c.log.Tracef("write() end TLS write currentDepth[%d]", c.writeCallDepth)
   258  				return n, err
   259  			} else {
   260  				m := make([]byte, txbufLen)
   261  				c.txbuf.Read(m)
   262  				c.log.Tracef("Write() doing raw write (to websocket); currentDepth[%d] len[%d]", c.writeCallDepth, len(m))
   263  				err = c.ws.WriteMessage(messageType, m)
   264  			}
   265  		}
   266  	}
   267  	if err == nil {
   268  		n = txbufLen
   269  	}
   270  	atomic.SwapInt32(&c.writeCallDepth, (c.writeCallDepth - 1))
   271  	c.log.Tracef("Write() end currentDepth[%d]", c.writeCallDepth)
   272  
   273  	return n, err
   274  }
   275  
   276  // Close implements io.Closer and closes the underlying connection.
   277  func (c *Connection) Close() error {
   278  	c.rmutex.Lock()
   279  	c.wmutex.Lock()
   280  	defer func() {
   281  		c.rmutex.Unlock()
   282  		c.wmutex.Unlock()
   283  	}()
   284  	select {
   285  	case <-c.done:
   286  		return errClosing
   287  	default:
   288  		close(c.done)
   289  	}
   290  	return c.ws.Close()
   291  }
   292  
   293  // pinger sends ping messages on an interval for client keep-alive.
   294  func (c *Connection) pinger() {
   295  	ticker := time.NewTicker(c.cfg.pingInterval)
   296  	defer ticker.Stop()
   297  	for {
   298  		select {
   299  		case <-c.done:
   300  			return
   301  		case <-ticker.C:
   302  			c.log.Trace("sending websocket Ping")
   303  			if _, err := c.write(websocket.PingMessage, []byte{}); err != nil {
   304  				_ = c.Close()
   305  			}
   306  		}
   307  	}
   308  }
   309  
   310  /**
   311   *	See above note re go:linkname
   312   *
   313  func (c *Connection) patchCipherSuites() {
   314  	c.log.Debug("patchCipherSuites dump: v----------------------------------------------------------")
   315  	for _, cipherSuite := range cipherSuites {
   316  		if cipherSuite.id == TLS_RSA_WITH_AES_128_CBC_SHA {
   317  			c.log.Debug("cipherSuite: TLS_RSA_WITH_AES_128_CBC_SHA before: ", cipherSuite)
   318  			cipherSuite.flags = suiteTLS12 | suiteECDHE
   319  			c.log.Debug("cipherSuite: TLS_RSA_WITH_AES_128_CBC_SHA after: ", cipherSuite)
   320  		}
   321  		if cipherSuite.id == TLS_RSA_WITH_AES_256_CBC_SHA {
   322  			c.log.Debug("cipherSuite: TLS_RSA_WITH_AES_256_CBC_SHA before: ", cipherSuite)
   323  			cipherSuite.flags = suiteTLS12 | suiteECDHE
   324  			c.log.Debug("cipherSuite: TLS_RSA_WITH_AES_256_CBC_SHA after: ", cipherSuite)
   325  		}
   326  	}
   327  	c.log.Debug("patchCipherSuites dump: ^----------------------------------------------------------")
   328  }
   329  */
   330  
   331  // tlsHandshake wraps the websocket in a TLS server.
   332  func (c *Connection) tlsHandshake() error {
   333  	var err error
   334  	var serverCertPEM []byte
   335  	var keyPEM []byte
   336  
   337  	//patchCipherSuites()
   338  
   339  	if serverCertPEM, err = ioutil.ReadFile(c.cfg.serverCert); err != nil {
   340  		c.log.Error(err)
   341  		_ = c.Close()
   342  		return err
   343  	}
   344  
   345  	if keyPEM, err = ioutil.ReadFile(c.cfg.key); err != nil {
   346  		c.log.Error(err)
   347  		_ = c.Close()
   348  		return err
   349  	}
   350  
   351  	cert, err := tls.X509KeyPair(serverCertPEM, keyPEM)
   352  	if err != nil {
   353  		c.log.Error(err)
   354  		_ = c.Close()
   355  		return err
   356  	}
   357  
   358  	caCertPool := x509.NewCertPool()
   359  	caCertPool.AppendCertsFromPEM(serverCertPEM)
   360  
   361  	cfg := &tls.Config{
   362  		ClientCAs:    caCertPool,
   363  		Certificates: []tls.Certificate{cert},
   364  		CipherSuites: []uint16{
   365  			tls.TLS_RSA_WITH_AES_128_CBC_SHA,
   366  			tls.TLS_RSA_WITH_AES_256_CBC_SHA,
   367  		},
   368  		ClientAuth:               tls.RequireAndVerifyClientCert,
   369  		MinVersion:               tls.VersionTLS11,
   370  		PreferServerCipherSuites: true,
   371  	}
   372  
   373  	c.tlsConn = tls.Server(c, cfg)
   374  	if err = c.tlsConn.Handshake(); err != nil {
   375  		if err != nil {
   376  			c.log.Error(err)
   377  			_ = c.Close()
   378  			return err
   379  		}
   380  	}
   381  
   382  	c.tlsConnHandshakeComplete = true
   383  
   384  	c.log.Debug("TLS Handshake completed successfully")
   385  
   386  	return nil
   387  }
   388  
   389  // newSafeBuffer instantiates a new safeBuffer
   390  func newSafeBuffer(log *logrus.Entry) *safeBuffer {
   391  	return &safeBuffer{
   392  		buf: bytes.NewBuffer(nil),
   393  		log: log,
   394  	}
   395  }
   396  
   397  func (self *Connection) Detail() *transport.ConnectionDetail {
   398  	return self.detail
   399  }
   400  
   401  func (self *Connection) PeerCertificates() []*x509.Certificate {
   402  	if self.tlsConnHandshakeComplete {
   403  		return self.tlsConn.ConnectionState().PeerCertificates
   404  	} else {
   405  		return nil
   406  	}
   407  }
   408  
   409  func (self *Connection) Reader() io.Reader {
   410  	return self
   411  }
   412  
   413  func (self *Connection) Writer() io.Writer {
   414  	return self
   415  }
   416  
   417  func (self *Connection) Conn() net.Conn {
   418  	self.log.Debug("Conn() entered, returning TLS connection that wraps the websocket")
   419  	return self.tlsConn // Obtain the TLS connection that wraps the websocket
   420  }
   421  
   422  func (self *Connection) SetReadTimeout(t time.Duration) error {
   423  	return self.ws.UnderlyingConn().SetReadDeadline(time.Now().Add(t))
   424  }
   425  
   426  func (self *Connection) SetWriteTimeout(t time.Duration) error {
   427  	return self.ws.UnderlyingConn().SetWriteDeadline(time.Now().Add(t))
   428  }
   429  
   430  // ClearReadTimeout clears the read time for all current and future reads
   431  //
   432  func (self *Connection) ClearReadTimeout() error {
   433  	var zero time.Time
   434  	return self.ws.UnderlyingConn().SetReadDeadline(zero)
   435  }
   436  
   437  // ClearWriteTimeout clears the write timeout for all current and future writes
   438  //
   439  func (self *Connection) ClearWriteTimeout() error {
   440  	var zero time.Time
   441  	return self.ws.UnderlyingConn().SetWriteDeadline(zero)
   442  }
   443  
   444  func (self *Connection) LocalAddr() net.Addr {
   445  	return self.ws.UnderlyingConn().LocalAddr()
   446  }
   447  func (self *Connection) RemoteAddr() net.Addr {
   448  	return self.ws.UnderlyingConn().RemoteAddr()
   449  }
   450  func (self *Connection) SetDeadline(t time.Time) error {
   451  	return self.ws.UnderlyingConn().SetDeadline(t)
   452  }
   453  func (self *Connection) SetReadDeadline(t time.Time) error {
   454  	return self.ws.UnderlyingConn().SetReadDeadline(t)
   455  }
   456  func (self *Connection) SetWriteDeadline(t time.Time) error {
   457  	return self.ws.UnderlyingConn().SetWriteDeadline(t)
   458  }