github.com/Asutorufa/yuhaiin@v0.3.6-0.20240502055049-7984da7023a0/pkg/net/proxy/websocket/x/conn.go (about) 1 // Copyright 2011 The Go Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package websocket 6 7 // This file implements a protocol of hybi draft. 8 // http://tools.ietf.org/html/draft-ietf-hybi-thewebsocketprotocol-17 9 10 import ( 11 "crypto/rand" 12 "encoding/binary" 13 "errors" 14 "io" 15 "net" 16 "sync" 17 "time" 18 19 "github.com/Asutorufa/yuhaiin/pkg/net/netapi" 20 "github.com/Asutorufa/yuhaiin/pkg/utils/pool" 21 "github.com/Asutorufa/yuhaiin/pkg/utils/relay" 22 ) 23 24 // Conn represents a WebSocket connection. 25 // 26 // Multiple goroutines may invoke methods on a Conn simultaneously. 27 type Conn struct { 28 IsServer bool 29 30 LastPayloadType opcode 31 PayloadType opcode 32 33 readHeaderBuf [8]byte 34 writeHeaderBuf [8]byte 35 36 rio sync.Mutex 37 wio sync.Mutex 38 39 Frame io.ReadCloser 40 41 RawConn net.Conn 42 } 43 44 // newConn creates a new WebSocket connection speaking hybi draft protocol. 45 func newConn(rwc net.Conn, isServer bool) *Conn { 46 return &Conn{ 47 IsServer: isServer, 48 RawConn: rwc, 49 PayloadType: opBinary, 50 } 51 } 52 53 // Read implements the io.Reader interface: 54 // it reads data of a frame from the WebSocket connection. 55 // if msg is not large enough for the frame data, it fills the msg and next Read 56 // will read the rest of the frame data. 57 // it reads Text frame or Binary frame. 58 func (ws *Conn) Read(msg []byte) (n int, err error) { 59 ws.rio.Lock() 60 defer ws.rio.Unlock() 61 62 for { 63 if ws.Frame == nil { 64 _, ws.Frame, err = ws.nextFrameReader() 65 if err != nil { 66 return 0, err 67 } 68 } 69 70 n, err = ws.Frame.Read(msg) 71 if err == nil || n != 0 { 72 return n, err 73 } 74 75 if !errors.Is(err, io.EOF) { 76 return n, err 77 } 78 79 ws.Frame = nil 80 } 81 82 } 83 84 func (ws *Conn) NextFrameReader(handle func(*Header, io.ReadCloser) error) error { 85 ws.rio.Lock() 86 defer ws.rio.Unlock() 87 88 if ws.Frame != nil { 89 _ = ws.Frame.Close() 90 ws.Frame = nil 91 } 92 93 h, r, err := ws.nextFrameReader() 94 if err != nil { 95 return err 96 } 97 defer r.Close() 98 99 if err := handle(h, r); err != nil { 100 return err 101 } 102 103 return nil 104 } 105 106 func (ws *Conn) nextFrameReader() (*Header, io.ReadCloser, error) { 107 for { 108 header, err := readFrameHeader(netapi.NewReader(ws.RawConn), ws.readHeaderBuf[:]) 109 if err != nil { 110 return nil, nil, err 111 } 112 113 frame := &frameReader{ 114 masked: header.masked, 115 maskKey: header.maskKey, 116 reader: io.LimitReader(ws.RawConn, header.payloadLength), 117 } 118 119 frameReader, err := ws.handleFrame(&header, frame) 120 if err != nil { 121 return nil, nil, err 122 } 123 124 if frameReader != nil { 125 return &header, frameReader, nil 126 } 127 } 128 } 129 130 // Write implements the io.Writer interface: 131 // it writes data as a frame to the WebSocket connection. 132 func (ws *Conn) Write(msg []byte) (n int, err error) { return ws.WriteMsg(msg, ws.PayloadType) } 133 134 func (ws *Conn) WriteMsg(msg []byte, payloadType opcode) (int, error) { 135 136 frameHeader := Header{ 137 fin: true, 138 opcode: payloadType, 139 masked: !ws.IsServer, 140 payloadLength: int64(len(msg)), 141 } 142 143 if frameHeader.masked { 144 _ = binary.Read(rand.Reader, binary.BigEndian, &frameHeader.maskKey) 145 } 146 147 buf := pool.GetBytesWriter(pool.DefaultSize + len(msg)) 148 defer buf.Free() 149 150 if err := writeFrameHeader(frameHeader, buf, ws.writeHeaderBuf[:]); err != nil { 151 return 0, err 152 } 153 154 headerLength := buf.Len() 155 156 _, _ = buf.Write(msg) 157 158 if frameHeader.masked { 159 mask(frameHeader.maskKey, buf.Bytes()[headerLength:]) 160 } 161 162 ws.wio.Lock() 163 n, err := ws.RawConn.Write(buf.Bytes()) 164 ws.wio.Unlock() 165 if err != nil { 166 return n, err 167 } 168 169 return int(frameHeader.payloadLength), nil 170 } 171 172 func (ws *Conn) handleFrame(header *Header, frame io.ReadCloser) (io.ReadCloser, error) { 173 if ws.IsServer && !header.masked { 174 // client --> server 175 // The client MUST mask all frames sent to the server. 176 ws.WriteClose(closeStatusProtocolError) 177 return nil, io.EOF 178 } else if !ws.IsServer && header.masked { 179 // server --> client 180 // The server MUST NOT mask all frames. 181 ws.WriteClose(closeStatusProtocolError) 182 return nil, io.EOF 183 } 184 185 switch header.opcode { 186 case opContinuation: 187 header.opcode = ws.LastPayloadType 188 case opText, opBinary: 189 ws.LastPayloadType = header.opcode 190 case opClose: 191 ws.Close() 192 return nil, io.EOF 193 case opPing, opPong: 194 b := make([]byte, maxControlFramePayloadLength) 195 n, err := io.ReadFull(frame, b) 196 if err != nil && err != io.EOF && err != io.ErrUnexpectedEOF { 197 return nil, err 198 } 199 _ = frame.Close() 200 if header.opcode == opPing { 201 if _, err := ws.WritePong(b[:n]); err != nil { 202 return nil, err 203 } 204 } 205 return nil, nil 206 } 207 return frame, nil 208 } 209 210 func (ws *Conn) WriteClose(status int) (err error) { 211 _, err = ws.WriteMsg(binary.BigEndian.AppendUint16(nil, uint16(status)), opClose) 212 return err 213 } 214 215 func (ws *Conn) WritePong(msg []byte) (n int, err error) { return ws.WriteMsg(msg, opPong) } 216 217 // Close implements the io.Closer interface. 218 func (ws *Conn) Close() error { 219 return ws.RawConn.Close() 220 } 221 222 func (ws *Conn) LocalAddr() net.Addr { return ws.RawConn.LocalAddr() } 223 func (ws *Conn) RemoteAddr() net.Addr { return ws.RawConn.RemoteAddr() } 224 func (ws *Conn) SetDeadline(t time.Time) error { return ws.RawConn.SetDeadline(t) } 225 func (ws *Conn) SetReadDeadline(t time.Time) error { return ws.RawConn.SetReadDeadline(t) } 226 func (ws *Conn) SetWriteDeadline(t time.Time) error { return ws.RawConn.SetWriteDeadline(t) } 227 228 // A frameReader is a reader for hybi frame. 229 type frameReader struct { 230 reader io.Reader 231 232 masked bool 233 maskKey uint32 234 } 235 236 func (frame *frameReader) Read(msg []byte) (n int, err error) { 237 n, err = frame.reader.Read(msg) 238 if frame.masked { 239 frame.maskKey = mask(frame.maskKey, msg[:n]) 240 } 241 return n, err 242 } 243 244 func (f *frameReader) Close() error { 245 _, err := relay.Copy(io.Discard, f.reader) 246 return err 247 }