github.com/openziti/transport@v0.1.5/ws/connection.go (about) 1 /* 2 Copyright NetFoundry, Inc. 3 4 Licensed under the Apache License, Version 2.0 (the "License"); 5 you may not use this file except in compliance with the License. 6 You may obtain a copy of the License at 7 8 https://www.apache.org/licenses/LICENSE-2.0 9 10 Unless required by applicable law or agreed to in writing, software 11 distributed under the License is distributed on an "AS IS" BASIS, 12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 See the License for the specific language governing permissions and 14 limitations under the License. 15 */ 16 17 package ws 18 19 import ( 20 "bytes" 21 "crypto/tls" 22 "crypto/x509" 23 "errors" 24 "github.com/gorilla/websocket" 25 "github.com/openziti/transport" 26 "github.com/sirupsen/logrus" 27 "io" 28 "io/ioutil" 29 "net" 30 "sync" 31 "sync/atomic" 32 "time" 33 // _ "unsafe" // Using go:linkname requires us to import unsafe 34 ) 35 36 /** 37 * For the moment, we do not need to exploit the go:linkname mechanism(s) in order to 38 * manipulate portions of the Go runtime, but we leave this code here, commented out, 39 * in case we need to revisit. 40 41 42 // A cipherSuite is a specific combination of key agreement, cipher and MAC function. 43 type cipherSuite struct { 44 id uint16 45 // the lengths, in bytes, of the key material needed for each component. 46 keyLen int 47 macLen int 48 ivLen int 49 ka func(version uint16) 50 // flags is a bitmask of the suite* values, above. 51 flags int 52 cipher func(key, iv []byte, isRead bool) interface{} 53 mac func(version uint16, macKey []byte) 54 aead func(key, fixedNonce []byte) 55 } 56 57 //go:linkname cipherSuites crypto/tls.cipherSuites 58 var cipherSuites []*cipherSuite 59 60 const ( 61 // suiteECDHE indicates that the cipher suite involves elliptic curve 62 // Diffie-Hellman. 
This means that it should only be selected when the
	// client indicates that it supports ECC with a curve and point format
	// that we're happy with.
	suiteECDHE = 1 << iota
	// suiteECSign indicates that the cipher suite involves an ECDSA or
	// EdDSA signature and therefore may only be selected when the server's
	// certificate is ECDSA or EdDSA. If this is not set then the cipher suite
	// is RSA based.
	suiteECSign
	// suiteTLS12 indicates that the cipher suite should only be advertised
	// and accepted when using TLS 1.2.
	suiteTLS12
	// suiteSHA384 indicates that the cipher suite uses SHA384 as the
	// handshake hash.
	suiteSHA384
	// suiteDefaultOff indicates that this cipher suite is not included by
	// default.
	suiteDefaultOff
)

*/

// IDs of the TLS 1.0 - 1.2 cipher suites that ziti-sdk-js supports. They carry
// the same values as the identically named constants in crypto/tls
// (0x002f and 0x0035 respectively).
const (
	TLS_RSA_WITH_AES_128_CBC_SHA uint16 = tls.TLS_RSA_WITH_AES_128_CBC_SHA
	TLS_RSA_WITH_AES_256_CBC_SHA uint16 = tls.TLS_RSA_WITH_AES_256_CBC_SHA
)

// errClosing is the error reported once the connection's done channel has
// been closed.
var errClosing = errors.New(`Closing`)

// safeBuffer is a *bytes.Buffer whose every operation runs while holding the
// embedded mutex, so it can be shared between goroutines.
type safeBuffer struct {
	buf *bytes.Buffer
	log *logrus.Entry
	sync.Mutex
}

// Read moves up to len(p) bytes out of the buffer into p, with the semantics
// of bytes.Buffer.Read.
func (s *safeBuffer) Read(p []byte) (n int, err error) {
	s.Lock()
	n, err = s.buf.Read(p)
	s.Unlock()
	return n, err
}

// Write appends p to the buffer, with the semantics of bytes.Buffer.Write.
// The unlock is deferred so the mutex is released even if the underlying
// buffer panics while growing.
func (s *safeBuffer) Write(p []byte) (n int, err error) {
	s.Lock()
	defer s.Unlock()
	n, err = s.buf.Write(p)
	return n, err
}

// Len reports how many unread bytes the buffer currently holds.
func (s *safeBuffer) Len() (n int) {
	s.Lock()
	n = s.buf.Len()
	s.Unlock()
	return n
}

// Reset discards any unread data, leaving the buffer empty.
123 func (s *safeBuffer) Reset() { 124 s.Lock() 125 s.buf.Reset() 126 s.Unlock() 127 } 128 129 // Connection wraps gorilla websocket to provide io.ReadWriteCloser 130 type Connection struct { 131 detail *transport.ConnectionDetail 132 cfg *WSConfig 133 ws *websocket.Conn 134 tlsConn *tls.Conn 135 tlsConnHandshakeComplete bool 136 log *logrus.Entry 137 rxbuf *safeBuffer 138 txbuf *safeBuffer 139 tlsrxbuf *safeBuffer 140 tlstxbuf *safeBuffer 141 done chan struct{} 142 wmutex sync.Mutex 143 rmutex sync.Mutex 144 tlswmutex sync.Mutex 145 tlsrmutex sync.Mutex 146 incoming chan transport.Connection 147 readCallDepth int32 148 writeCallDepth int32 149 } 150 151 // Read implements io.Reader by wrapping websocket messages in a buffer. 152 func (c *Connection) Read(p []byte) (n int, err error) { 153 currentDepth := atomic.AddInt32(&c.readCallDepth, 1) 154 c.log.Tracef("Read() start currentDepth[%d]", currentDepth) 155 156 if c.rxbuf.Len() == 0 { 157 var r io.Reader 158 c.rxbuf.Reset() 159 if c.tlsConnHandshakeComplete { 160 if currentDepth == 1 { 161 c.tlsrmutex.Lock() 162 defer c.tlsrmutex.Unlock() 163 } else if currentDepth == 2 { 164 c.rmutex.Lock() 165 defer c.rmutex.Unlock() 166 } 167 } else { 168 c.rmutex.Lock() 169 defer c.rmutex.Unlock() 170 } 171 select { 172 case <-c.done: 173 err = errClosing 174 default: 175 if c.tlsConnHandshakeComplete && currentDepth == 1 { 176 n, err = c.tlsConn.Read(p) 177 atomic.SwapInt32(&c.readCallDepth, (c.readCallDepth - 1)) 178 c.log.Tracef("Read() end currentDepth[%d]", currentDepth) 179 return n, err 180 } else { 181 _, r, err = c.ws.NextReader() 182 } 183 } 184 if err != nil { 185 return n, err 186 } 187 _, err = io.Copy(c.rxbuf, r) 188 if err != nil { 189 return n, err 190 } 191 } 192 193 atomic.SwapInt32(&c.readCallDepth, (c.readCallDepth - 1)) 194 195 c.log.Tracef("Read() end currentDepth[%d]", currentDepth) 196 197 return c.rxbuf.Read(p) 198 } 199 200 // Write implements io.Writer and sends binary messages only. 
201 func (c *Connection) Write(p []byte) (n int, err error) { 202 return c.write(websocket.BinaryMessage, p) 203 } 204 205 // write wraps the websocket writer. 206 func (c *Connection) write(messageType int, p []byte) (n int, err error) { 207 var txbufLen int 208 currentDepth := atomic.AddInt32(&c.writeCallDepth, 1) 209 c.log.Tracef("Write() start currentDepth[%d] len[%d]", c.writeCallDepth, len(p)) 210 211 if c.tlsConnHandshakeComplete { 212 if currentDepth == 1 { 213 c.tlswmutex.Lock() 214 defer c.tlswmutex.Unlock() 215 } else if currentDepth == 2 { 216 c.wmutex.Lock() 217 defer c.wmutex.Unlock() 218 } 219 } else { 220 c.wmutex.Lock() 221 defer c.wmutex.Unlock() 222 } 223 224 select { 225 case <-c.done: 226 err = errClosing 227 default: 228 var txbufLen int 229 230 if !c.tlsConnHandshakeComplete { 231 c.tlstxbuf.Write(p) 232 txbufLen = c.tlstxbuf.Len() 233 c.log.Tracef("Write() doing TLS handshake (buffering); currentDepth[%d] txbufLen[%d] data[%o]", c.writeCallDepth, txbufLen, p) 234 } else if currentDepth == 1 { // if at TLS level (1st level) 235 c.tlstxbuf.Write(p) 236 txbufLen = c.tlstxbuf.Len() 237 c.log.Tracef("Write() doing TLS write; currentDepth[%d] txbufLen[%d] data[%o]", c.writeCallDepth, txbufLen, p) 238 } else { // if at websocket level (2nd level) 239 c.txbuf.Write(p) 240 txbufLen = c.txbuf.Len() 241 c.log.Tracef("Write() doing raw write; currentDepth[%d] txbufLen[%d] data[%o]", c.writeCallDepth, txbufLen, p) 242 } 243 244 err = c.ws.SetWriteDeadline(time.Now().Add(c.cfg.writeTimeout)) 245 if err == nil { 246 if !c.tlsConnHandshakeComplete { 247 m := make([]byte, txbufLen) 248 c.tlstxbuf.Read(m) 249 c.log.Tracef("Write() doing TLS handshake (to websocket); currentDepth[%d] txbufLen[%d] data[%o]", c.writeCallDepth, txbufLen, m) 250 err = c.ws.WriteMessage(messageType, m) 251 } else if currentDepth == 1 { 252 m := make([]byte, txbufLen) 253 c.tlstxbuf.Read(m) 254 c.log.Tracef("Write() doing TLS write (to conn); currentDepth[%d] txbufLen[%d] data[%o]", 
c.writeCallDepth, txbufLen, m) 255 n, err = c.tlsConn.Write(m) 256 atomic.SwapInt32(&c.writeCallDepth, (c.writeCallDepth - 1)) 257 c.log.Tracef("write() end TLS write currentDepth[%d]", c.writeCallDepth) 258 return n, err 259 } else { 260 m := make([]byte, txbufLen) 261 c.txbuf.Read(m) 262 c.log.Tracef("Write() doing raw write (to websocket); currentDepth[%d] len[%d]", c.writeCallDepth, len(m)) 263 err = c.ws.WriteMessage(messageType, m) 264 } 265 } 266 } 267 if err == nil { 268 n = txbufLen 269 } 270 atomic.SwapInt32(&c.writeCallDepth, (c.writeCallDepth - 1)) 271 c.log.Tracef("Write() end currentDepth[%d]", c.writeCallDepth) 272 273 return n, err 274 } 275 276 // Close implements io.Closer and closes the underlying connection. 277 func (c *Connection) Close() error { 278 c.rmutex.Lock() 279 c.wmutex.Lock() 280 defer func() { 281 c.rmutex.Unlock() 282 c.wmutex.Unlock() 283 }() 284 select { 285 case <-c.done: 286 return errClosing 287 default: 288 close(c.done) 289 } 290 return c.ws.Close() 291 } 292 293 // pinger sends ping messages on an interval for client keep-alive. 
294 func (c *Connection) pinger() { 295 ticker := time.NewTicker(c.cfg.pingInterval) 296 defer ticker.Stop() 297 for { 298 select { 299 case <-c.done: 300 return 301 case <-ticker.C: 302 c.log.Trace("sending websocket Ping") 303 if _, err := c.write(websocket.PingMessage, []byte{}); err != nil { 304 _ = c.Close() 305 } 306 } 307 } 308 } 309 310 /** 311 * See above note re go:linkname 312 * 313 func (c *Connection) patchCipherSuites() { 314 c.log.Debug("patchCipherSuites dump: v----------------------------------------------------------") 315 for _, cipherSuite := range cipherSuites { 316 if cipherSuite.id == TLS_RSA_WITH_AES_128_CBC_SHA { 317 c.log.Debug("cipherSuite: TLS_RSA_WITH_AES_128_CBC_SHA before: ", cipherSuite) 318 cipherSuite.flags = suiteTLS12 | suiteECDHE 319 c.log.Debug("cipherSuite: TLS_RSA_WITH_AES_128_CBC_SHA after: ", cipherSuite) 320 } 321 if cipherSuite.id == TLS_RSA_WITH_AES_256_CBC_SHA { 322 c.log.Debug("cipherSuite: TLS_RSA_WITH_AES_256_CBC_SHA before: ", cipherSuite) 323 cipherSuite.flags = suiteTLS12 | suiteECDHE 324 c.log.Debug("cipherSuite: TLS_RSA_WITH_AES_256_CBC_SHA after: ", cipherSuite) 325 } 326 } 327 c.log.Debug("patchCipherSuites dump: ^----------------------------------------------------------") 328 } 329 */ 330 331 // tlsHandshake wraps the websocket in a TLS server. 
332 func (c *Connection) tlsHandshake() error { 333 var err error 334 var serverCertPEM []byte 335 var keyPEM []byte 336 337 //patchCipherSuites() 338 339 if serverCertPEM, err = ioutil.ReadFile(c.cfg.serverCert); err != nil { 340 c.log.Error(err) 341 _ = c.Close() 342 return err 343 } 344 345 if keyPEM, err = ioutil.ReadFile(c.cfg.key); err != nil { 346 c.log.Error(err) 347 _ = c.Close() 348 return err 349 } 350 351 cert, err := tls.X509KeyPair(serverCertPEM, keyPEM) 352 if err != nil { 353 c.log.Error(err) 354 _ = c.Close() 355 return err 356 } 357 358 caCertPool := x509.NewCertPool() 359 caCertPool.AppendCertsFromPEM(serverCertPEM) 360 361 cfg := &tls.Config{ 362 ClientCAs: caCertPool, 363 Certificates: []tls.Certificate{cert}, 364 CipherSuites: []uint16{ 365 tls.TLS_RSA_WITH_AES_128_CBC_SHA, 366 tls.TLS_RSA_WITH_AES_256_CBC_SHA, 367 }, 368 ClientAuth: tls.RequireAndVerifyClientCert, 369 MinVersion: tls.VersionTLS11, 370 PreferServerCipherSuites: true, 371 } 372 373 c.tlsConn = tls.Server(c, cfg) 374 if err = c.tlsConn.Handshake(); err != nil { 375 if err != nil { 376 c.log.Error(err) 377 _ = c.Close() 378 return err 379 } 380 } 381 382 c.tlsConnHandshakeComplete = true 383 384 c.log.Debug("TLS Handshake completed successfully") 385 386 return nil 387 } 388 389 // newSafeBuffer instantiates a new safeBuffer 390 func newSafeBuffer(log *logrus.Entry) *safeBuffer { 391 return &safeBuffer{ 392 buf: bytes.NewBuffer(nil), 393 log: log, 394 } 395 } 396 397 func (self *Connection) Detail() *transport.ConnectionDetail { 398 return self.detail 399 } 400 401 func (self *Connection) PeerCertificates() []*x509.Certificate { 402 if self.tlsConnHandshakeComplete { 403 return self.tlsConn.ConnectionState().PeerCertificates 404 } else { 405 return nil 406 } 407 } 408 409 func (self *Connection) Reader() io.Reader { 410 return self 411 } 412 413 func (self *Connection) Writer() io.Writer { 414 return self 415 } 416 417 func (self *Connection) Conn() net.Conn { 418 
self.log.Debug("Conn() entered, returning TLS connection that wraps the websocket") 419 return self.tlsConn // Obtain the TLS connection that wraps the websocket 420 } 421 422 func (self *Connection) SetReadTimeout(t time.Duration) error { 423 return self.ws.UnderlyingConn().SetReadDeadline(time.Now().Add(t)) 424 } 425 426 func (self *Connection) SetWriteTimeout(t time.Duration) error { 427 return self.ws.UnderlyingConn().SetWriteDeadline(time.Now().Add(t)) 428 } 429 430 // ClearReadTimeout clears the read time for all current and future reads 431 // 432 func (self *Connection) ClearReadTimeout() error { 433 var zero time.Time 434 return self.ws.UnderlyingConn().SetReadDeadline(zero) 435 } 436 437 // ClearWriteTimeout clears the write timeout for all current and future writes 438 // 439 func (self *Connection) ClearWriteTimeout() error { 440 var zero time.Time 441 return self.ws.UnderlyingConn().SetWriteDeadline(zero) 442 } 443 444 func (self *Connection) LocalAddr() net.Addr { 445 return self.ws.UnderlyingConn().LocalAddr() 446 } 447 func (self *Connection) RemoteAddr() net.Addr { 448 return self.ws.UnderlyingConn().RemoteAddr() 449 } 450 func (self *Connection) SetDeadline(t time.Time) error { 451 return self.ws.UnderlyingConn().SetDeadline(t) 452 } 453 func (self *Connection) SetReadDeadline(t time.Time) error { 454 return self.ws.UnderlyingConn().SetReadDeadline(t) 455 } 456 func (self *Connection) SetWriteDeadline(t time.Time) error { 457 return self.ws.UnderlyingConn().SetWriteDeadline(t) 458 }