github.com/ezoic/ws@v1.0.4-0.20220713205711-5c1d69e074c5/wsutil/reader.go (about) 1 package wsutil 2 3 import ( 4 "errors" 5 "io" 6 "io/ioutil" 7 8 "github.com/ezoic/ws" 9 ) 10 11 // ErrNoFrameAdvance means that Reader's Read() method was called without 12 // preceding NextFrame() call. 13 var ErrNoFrameAdvance = errors.New("no frame advance") 14 15 // FrameHandlerFunc handles parsed frame header and its body represented by 16 // io.Reader. 17 // 18 // Note that reader represents already unmasked body. 19 type FrameHandlerFunc func(ws.Header, io.Reader) error 20 21 // Reader is a wrapper around source io.Reader which represents WebSocket 22 // connection. It contains options for reading messages from source. 23 // 24 // Reader implements io.Reader, which Read() method reads payload of incoming 25 // WebSocket frames. It also takes care on fragmented frames and possibly 26 // intermediate control frames between them. 27 // 28 // Note that Reader's methods are not goroutine safe. 29 type Reader struct { 30 Source io.Reader 31 State ws.State 32 33 // SkipHeaderCheck disables checking header bits to be RFC6455 compliant. 34 SkipHeaderCheck bool 35 36 // CheckUTF8 enables UTF-8 checks for text frames payload. If incoming 37 // bytes are not valid UTF-8 sequence, ErrInvalidUTF8 returned. 38 CheckUTF8 bool 39 40 // TODO(ezoic): add max frame size limit here. 41 42 OnContinuation FrameHandlerFunc 43 OnIntermediate FrameHandlerFunc 44 45 opCode ws.OpCode // Used to store message op code on fragmentation. 46 frame io.Reader // Used to as frame reader. 47 raw io.LimitedReader // Used to discard frames without cipher. 48 utf8 UTF8Reader // Used to check UTF8 sequences if CheckUTF8 is true. 49 } 50 51 // NewReader creates new frame reader that reads from r keeping given state to 52 // make some protocol validity checks when it needed. 53 func NewReader(r io.Reader, s ws.State) *Reader { 54 return &Reader{ 55 Source: r, 56 State: s, 57 } 58 } 59 60 // NewClientSideReader is a helper function that calls NewReader with r and 61 // ws.StateClientSide. 62 func NewClientSideReader(r io.Reader) *Reader { 63 return NewReader(r, ws.StateClientSide) 64 } 65 66 // NewServerSideReader is a helper function that calls NewReader with r and 67 // ws.StateServerSide. 68 func NewServerSideReader(r io.Reader) *Reader { 69 return NewReader(r, ws.StateServerSide) 70 } 71 72 // Read implements io.Reader. It reads the next message payload into p. 73 // It takes care on fragmented messages. 74 // 75 // The error is io.EOF only if all of message bytes were read. 76 // If an io.EOF happens during reading some but not all the message bytes 77 // Read() returns io.ErrUnexpectedEOF. 78 // 79 // The error is ErrNoFrameAdvance if no NextFrame() call was made before 80 // reading next message bytes. 81 func (r *Reader) Read(p []byte) (n int, err error) { 82 if r.frame == nil { 83 if !r.fragmented() { 84 // Every new Read() must be preceded by NextFrame() call. 85 return 0, ErrNoFrameAdvance 86 } 87 // Read next continuation or intermediate control frame. 88 _, err := r.NextFrame() 89 if err != nil { 90 return 0, err 91 } 92 if r.frame == nil { 93 // We handled intermediate control and now got nothing to read. 94 return 0, nil 95 } 96 } 97 98 n, err = r.frame.Read(p) 99 if err != nil && err != io.EOF { 100 return 101 } 102 if err == nil && r.raw.N != 0 { 103 return 104 } 105 106 switch { 107 case r.raw.N != 0: 108 err = io.ErrUnexpectedEOF 109 110 case r.fragmented(): 111 err = nil 112 r.resetFragment() 113 114 case r.CheckUTF8 && !r.utf8.Valid(): 115 n = r.utf8.Accepted() 116 err = ErrInvalidUTF8 117 118 default: 119 r.reset() 120 err = io.EOF 121 } 122 123 return 124 } 125 126 // Discard discards current message unread bytes. 127 // It discards all frames of fragmented message. 128 func (r *Reader) Discard() (err error) { 129 for { 130 _, err = io.Copy(ioutil.Discard, &r.raw) 131 if err != nil { 132 break 133 } 134 if !r.fragmented() { 135 break 136 } 137 if _, err = r.NextFrame(); err != nil { 138 break 139 } 140 } 141 r.reset() 142 return err 143 } 144 145 // NextFrame prepares r to read next message. It returns received frame header 146 // and non-nil error on failure. 147 // 148 // Note that next NextFrame() call must be done after receiving or discarding 149 // all current message bytes. 150 func (r *Reader) NextFrame() (hdr ws.Header, err error) { 151 hdr, err = ws.ReadHeader(r.Source) 152 if err == io.EOF && r.fragmented() { 153 // If we are in fragmented state EOF means that is was totally 154 // unexpected. 155 // 156 // NOTE: This is necessary to prevent callers such that 157 // ioutil.ReadAll to receive some amount of bytes without an error. 158 // ReadAll() ignores an io.EOF error, thus caller may think that 159 // whole message fetched, but actually only part of it. 160 err = io.ErrUnexpectedEOF 161 } 162 if err == nil && !r.SkipHeaderCheck { 163 err = ws.CheckHeader(hdr, r.State) 164 } 165 if err != nil { 166 return hdr, err 167 } 168 169 // Save raw reader to use it on discarding frame without ciphering and 170 // other streaming checks. 171 r.raw = io.LimitedReader{r.Source, hdr.Length} 172 173 frame := io.Reader(&r.raw) 174 if hdr.Masked { 175 frame = NewCipherReader(frame, hdr.Mask) 176 } 177 if r.fragmented() { 178 if hdr.OpCode.IsControl() { 179 if cb := r.OnIntermediate; cb != nil { 180 err = cb(hdr, frame) 181 } 182 if err == nil { 183 // Ensure that src is empty. 184 _, err = io.Copy(ioutil.Discard, &r.raw) 185 } 186 return 187 } 188 } else { 189 r.opCode = hdr.OpCode 190 } 191 if r.CheckUTF8 && (hdr.OpCode == ws.OpText || (r.fragmented() && r.opCode == ws.OpText)) { 192 r.utf8.Source = frame 193 frame = &r.utf8 194 } 195 196 // Save reader with ciphering and other streaming checks. 197 r.frame = frame 198 199 if hdr.OpCode == ws.OpContinuation { 200 if cb := r.OnContinuation; cb != nil { 201 err = cb(hdr, frame) 202 } 203 } 204 205 if hdr.Fin { 206 r.State = r.State.Clear(ws.StateFragmented) 207 } else { 208 r.State = r.State.Set(ws.StateFragmented) 209 } 210 211 return 212 } 213 214 func (r *Reader) fragmented() bool { 215 return r.State.Fragmented() 216 } 217 218 func (r *Reader) resetFragment() { 219 r.raw = io.LimitedReader{} 220 r.frame = nil 221 // Reset source of the UTF8Reader, but not the state. 222 r.utf8.Source = nil 223 } 224 225 func (r *Reader) reset() { 226 r.raw = io.LimitedReader{} 227 r.frame = nil 228 r.utf8 = UTF8Reader{} 229 r.opCode = 0 230 } 231 232 // NextReader prepares next message read from r. It returns header that 233 // describes the message and io.Reader to read message's payload. It returns 234 // non-nil error when it is not possible to read message's initial frame. 235 // 236 // Note that next NextReader() on the same r should be done after reading all 237 // bytes from previously returned io.Reader. For more performant way to discard 238 // message use Reader and its Discard() method. 239 // 240 // Note that it will not handle any "intermediate" frames, that possibly could 241 // be received between text/binary continuation frames. That is, if peer sent 242 // text/binary frame with fin flag "false", then it could send ping frame, and 243 // eventually remaining part of text/binary frame with fin "true" – with 244 // NextReader() the ping frame will be dropped without any notice. To handle 245 // this rare, but possible situation (and if you do not know exactly which 246 // frames peer could send), you could use Reader with OnIntermediate field set. 247 func NextReader(r io.Reader, s ws.State) (ws.Header, io.Reader, error) { 248 rd := &Reader{ 249 Source: r, 250 State: s, 251 } 252 header, err := rd.NextFrame() 253 if err != nil { 254 return header, nil, err 255 } 256 return header, rd, nil 257 }