github.com/openziti/transport@v0.1.5/wss/connection.go (about) 1 /* 2 Copyright NetFoundry, Inc. 3 4 Licensed under the Apache License, Version 2.0 (the "License"); 5 you may not use this file except in compliance with the License. 6 You may obtain a copy of the License at 7 8 https://www.apache.org/licenses/LICENSE-2.0 9 10 Unless required by applicable law or agreed to in writing, software 11 distributed under the License is distributed on an "AS IS" BASIS, 12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 See the License for the specific language governing permissions and 14 limitations under the License. 15 */ 16 17 package wss 18 19 import ( 20 "bytes" 21 "crypto/tls" 22 "crypto/x509" 23 "errors" 24 "github.com/gorilla/websocket" 25 "github.com/openziti/transport" 26 "github.com/sirupsen/logrus" 27 "io" 28 "net" 29 "sync" 30 "time" 31 ) 32 33 var ( 34 errClosing = errors.New(`Closing`) 35 ) 36 37 // safeBuffer adds thread-safety to *bytes.Buffer 38 type safeBuffer struct { 39 buf *bytes.Buffer 40 log *logrus.Entry 41 sync.Mutex 42 } 43 44 // Read reads the next len(p) bytes from the buffer or until the buffer is drained. 45 func (s *safeBuffer) Read(p []byte) (int, error) { 46 s.Lock() 47 defer s.Unlock() 48 return s.buf.Read(p) 49 } 50 51 // Write appends the contents of p to the buffer. 52 func (s *safeBuffer) Write(p []byte) (int, error) { 53 s.Lock() 54 defer s.Unlock() 55 return s.buf.Write(p) 56 } 57 58 // Len returns the number of bytes of the unread portion of the buffer. 59 func (s *safeBuffer) Len() int { 60 s.Lock() 61 defer s.Unlock() 62 return s.buf.Len() 63 } 64 65 // Reset resets the buffer to be empty. 66 func (s *safeBuffer) Reset() { 67 s.Lock() 68 s.buf.Reset() 69 s.Unlock() 70 } 71 72 // Connection wraps gorilla websocket to provide io.ReadWriteCloser 73 type Connection struct { 74 detail *transport.ConnectionDetail 75 cfg *WSSConfig 76 ws *websocket.Conn 77 log *logrus.Entry 78 rxbuf *safeBuffer 79 txbuf *safeBuffer 80 done chan struct{} 81 wmutex sync.Mutex 82 rmutex sync.Mutex 83 } 84 85 // Read implements io.Reader by wrapping websocket messages in a buffer. 86 func (c *Connection) Read(p []byte) (n int, err error) { 87 if c.rxbuf.Len() == 0 { 88 var r io.Reader 89 c.rxbuf.Reset() 90 c.rmutex.Lock() 91 defer c.rmutex.Unlock() 92 select { 93 case <-c.done: 94 err = errClosing 95 default: 96 _, r, err = c.ws.NextReader() 97 } 98 if err != nil { 99 return n, err 100 } 101 _, err = io.Copy(c.rxbuf, r) 102 if err != nil { 103 return n, err 104 } 105 } 106 107 return c.rxbuf.Read(p) 108 } 109 110 // Write implements io.Writer and sends binary messages only. 111 func (c *Connection) Write(p []byte) (n int, err error) { 112 return c.write(websocket.BinaryMessage, p) 113 } 114 115 // write wraps the websocket writer. 116 func (c *Connection) write(messageType int, p []byte) (n int, err error) { 117 var txbufLen int 118 c.wmutex.Lock() 119 defer c.wmutex.Unlock() 120 select { 121 case <-c.done: 122 err = errClosing 123 default: 124 c.txbuf.Write(p) 125 txbufLen = c.txbuf.Len() 126 if txbufLen > 20 { // TEMP HACK: (until I refactor the JS-SDK to accept the message section and data section in separate salvos) 127 err = c.ws.SetWriteDeadline(time.Now().Add(c.cfg.writeTimeout)) 128 if err == nil { 129 m := make([]byte, txbufLen) 130 c.txbuf.Read(m) 131 err = c.ws.WriteMessage(messageType, m) 132 } 133 } 134 } 135 if err == nil { 136 n = txbufLen 137 } 138 return n, err 139 } 140 141 // Close implements io.Closer and closes the underlying connection. 142 func (c *Connection) Close() error { 143 c.rmutex.Lock() 144 c.wmutex.Lock() 145 defer func() { 146 c.rmutex.Unlock() 147 c.wmutex.Unlock() 148 }() 149 select { 150 case <-c.done: 151 return errClosing 152 default: 153 close(c.done) 154 } 155 return c.ws.Close() 156 } 157 158 // pinger sends ping messages on an interval for client keep-alive. 159 func (c *Connection) pinger() { 160 ticker := time.NewTicker(c.cfg.pingInterval) 161 defer ticker.Stop() 162 for { 163 select { 164 case <-c.done: 165 return 166 case <-ticker.C: 167 c.log.Trace("sending websocket Ping") 168 if _, err := c.write(websocket.PingMessage, []byte{}); err != nil { 169 _ = c.Close() 170 } 171 } 172 } 173 } 174 175 // newSafeBuffer instantiates a new safeBuffer 176 func newSafeBuffer(log *logrus.Entry) *safeBuffer { 177 return &safeBuffer{ 178 buf: bytes.NewBuffer(nil), 179 log: log, 180 } 181 } 182 183 func (self *Connection) Detail() *transport.ConnectionDetail { 184 return self.detail 185 } 186 187 func (self *Connection) PeerCertificates() []*x509.Certificate { 188 var tlsConn (*tls.Conn) = self.ws.UnderlyingConn().(*tls.Conn) 189 return tlsConn.ConnectionState().PeerCertificates 190 } 191 192 func (self *Connection) Reader() io.Reader { 193 return self 194 } 195 196 func (self *Connection) Writer() io.Writer { 197 return self 198 } 199 200 func (self *Connection) Conn() net.Conn { 201 return self.ws.UnderlyingConn() // Obtain the socket underneath the websocket 202 203 } 204 205 func (self *Connection) SetReadTimeout(t time.Duration) error { 206 return self.ws.UnderlyingConn().SetReadDeadline(time.Now().Add(t)) 207 } 208 209 func (self *Connection) SetWriteTimeout(t time.Duration) error { 210 return self.ws.UnderlyingConn().SetWriteDeadline(time.Now().Add(t)) 211 } 212 213 // ClearReadTimeout clears the read time for all current and future reads 214 // 215 func (self *Connection) ClearReadTimeout() error { 216 var zero time.Time 217 return self.ws.UnderlyingConn().SetReadDeadline(zero) 218 } 219 220 // ClearWriteTimeout clears the write timeout for all current and future writes 221 // 222 func (self *Connection) ClearWriteTimeout() error { 223 var zero time.Time 224 return self.ws.UnderlyingConn().SetWriteDeadline(zero) 225 }