github.com/ezoic/ws@v1.0.4-0.20220713205711-5c1d69e074c5/wsutil/reader.go (about)

     1  package wsutil
     2  
     3  import (
     4  	"errors"
     5  	"io"
     6  	"io/ioutil"
     7  
     8  	"github.com/ezoic/ws"
     9  )
    10  
// ErrNoFrameAdvance is returned by Reader's Read() method when it is called
// without a preceding NextFrame() call while no fragmented message is in
// progress.
var ErrNoFrameAdvance = errors.New("no frame advance")
    14  
// FrameHandlerFunc handles a parsed frame header together with the frame body,
// which is exposed as an io.Reader.
//
// Note that the reader yields an already unmasked body: Reader wraps masked
// frames in a cipher reader before invoking the handler.
//
// It is the type of Reader's OnContinuation and OnIntermediate callbacks.
type FrameHandlerFunc func(ws.Header, io.Reader) error
    20  
// Reader is a wrapper around a source io.Reader which represents a WebSocket
// connection. It contains options for reading messages from the source.
//
// Reader implements io.Reader; its Read() method reads the payload of incoming
// WebSocket frames. It also takes care of fragmented messages and of control
// frames that may arrive between their fragments.
//
// Note that Reader's methods are not goroutine safe.
type Reader struct {
	// Source is the stream that frame headers and payloads are read from.
	// State is passed to ws.CheckHeader for header validation; NextFrame also
	// sets or clears ws.StateFragmented on it depending on the Fin bit of the
	// frame it has just read.
	Source io.Reader
	State  ws.State

	// SkipHeaderCheck disables checking header bits to be RFC6455 compliant.
	SkipHeaderCheck bool

	// CheckUTF8 enables UTF-8 checks for text frames payload. If incoming
	// bytes are not a valid UTF-8 sequence, ErrInvalidUTF8 is returned.
	CheckUTF8 bool

	// TODO(ezoic): add max frame size limit here.

	// OnContinuation, if non-nil, is called by NextFrame for every frame whose
	// op code is ws.OpContinuation, after the frame reader has been prepared.
	// OnIntermediate, if non-nil, is called by NextFrame for every control
	// frame received while a fragmented message is in progress. When it
	// returns nil (or is not set), the rest of that frame's payload is
	// discarded.
	OnContinuation FrameHandlerFunc
	OnIntermediate FrameHandlerFunc

	opCode ws.OpCode        // Op code of the message's first frame; kept while the message is fragmented.
	frame  io.Reader        // Reader of the current frame payload (unmasking and UTF-8 checks applied).
	raw    io.LimitedReader // Raw payload of the current frame; used to discard frames without deciphering.
	utf8   UTF8Reader       // Used to check UTF-8 sequences if CheckUTF8 is true.
}
    50  
    51  // NewReader creates new frame reader that reads from r keeping given state to
    52  // make some protocol validity checks when it needed.
    53  func NewReader(r io.Reader, s ws.State) *Reader {
    54  	return &Reader{
    55  		Source: r,
    56  		State:  s,
    57  	}
    58  }
    59  
    60  // NewClientSideReader is a helper function that calls NewReader with r and
    61  // ws.StateClientSide.
    62  func NewClientSideReader(r io.Reader) *Reader {
    63  	return NewReader(r, ws.StateClientSide)
    64  }
    65  
    66  // NewServerSideReader is a helper function that calls NewReader with r and
    67  // ws.StateServerSide.
    68  func NewServerSideReader(r io.Reader) *Reader {
    69  	return NewReader(r, ws.StateServerSide)
    70  }
    71  
    72  // Read implements io.Reader. It reads the next message payload into p.
    73  // It takes care on fragmented messages.
    74  //
    75  // The error is io.EOF only if all of message bytes were read.
    76  // If an io.EOF happens during reading some but not all the message bytes
    77  // Read() returns io.ErrUnexpectedEOF.
    78  //
    79  // The error is ErrNoFrameAdvance if no NextFrame() call was made before
    80  // reading next message bytes.
    81  func (r *Reader) Read(p []byte) (n int, err error) {
    82  	if r.frame == nil {
    83  		if !r.fragmented() {
    84  			// Every new Read() must be preceded by NextFrame() call.
    85  			return 0, ErrNoFrameAdvance
    86  		}
    87  		// Read next continuation or intermediate control frame.
    88  		_, err := r.NextFrame()
    89  		if err != nil {
    90  			return 0, err
    91  		}
    92  		if r.frame == nil {
    93  			// We handled intermediate control and now got nothing to read.
    94  			return 0, nil
    95  		}
    96  	}
    97  
    98  	n, err = r.frame.Read(p)
    99  	if err != nil && err != io.EOF {
   100  		return
   101  	}
   102  	if err == nil && r.raw.N != 0 {
   103  		return
   104  	}
   105  
   106  	switch {
   107  	case r.raw.N != 0:
   108  		err = io.ErrUnexpectedEOF
   109  
   110  	case r.fragmented():
   111  		err = nil
   112  		r.resetFragment()
   113  
   114  	case r.CheckUTF8 && !r.utf8.Valid():
   115  		n = r.utf8.Accepted()
   116  		err = ErrInvalidUTF8
   117  
   118  	default:
   119  		r.reset()
   120  		err = io.EOF
   121  	}
   122  
   123  	return
   124  }
   125  
   126  // Discard discards current message unread bytes.
   127  // It discards all frames of fragmented message.
   128  func (r *Reader) Discard() (err error) {
   129  	for {
   130  		_, err = io.Copy(ioutil.Discard, &r.raw)
   131  		if err != nil {
   132  			break
   133  		}
   134  		if !r.fragmented() {
   135  			break
   136  		}
   137  		if _, err = r.NextFrame(); err != nil {
   138  			break
   139  		}
   140  	}
   141  	r.reset()
   142  	return err
   143  }
   144  
   145  // NextFrame prepares r to read next message. It returns received frame header
   146  // and non-nil error on failure.
   147  //
   148  // Note that next NextFrame() call must be done after receiving or discarding
   149  // all current message bytes.
   150  func (r *Reader) NextFrame() (hdr ws.Header, err error) {
   151  	hdr, err = ws.ReadHeader(r.Source)
   152  	if err == io.EOF && r.fragmented() {
   153  		// If we are in fragmented state EOF means that is was totally
   154  		// unexpected.
   155  		//
   156  		// NOTE: This is necessary to prevent callers such that
   157  		// ioutil.ReadAll to receive some amount of bytes without an error.
   158  		// ReadAll() ignores an io.EOF error, thus caller may think that
   159  		// whole message fetched, but actually only part of it.
   160  		err = io.ErrUnexpectedEOF
   161  	}
   162  	if err == nil && !r.SkipHeaderCheck {
   163  		err = ws.CheckHeader(hdr, r.State)
   164  	}
   165  	if err != nil {
   166  		return hdr, err
   167  	}
   168  
   169  	// Save raw reader to use it on discarding frame without ciphering and
   170  	// other streaming checks.
   171  	r.raw = io.LimitedReader{r.Source, hdr.Length}
   172  
   173  	frame := io.Reader(&r.raw)
   174  	if hdr.Masked {
   175  		frame = NewCipherReader(frame, hdr.Mask)
   176  	}
   177  	if r.fragmented() {
   178  		if hdr.OpCode.IsControl() {
   179  			if cb := r.OnIntermediate; cb != nil {
   180  				err = cb(hdr, frame)
   181  			}
   182  			if err == nil {
   183  				// Ensure that src is empty.
   184  				_, err = io.Copy(ioutil.Discard, &r.raw)
   185  			}
   186  			return
   187  		}
   188  	} else {
   189  		r.opCode = hdr.OpCode
   190  	}
   191  	if r.CheckUTF8 && (hdr.OpCode == ws.OpText || (r.fragmented() && r.opCode == ws.OpText)) {
   192  		r.utf8.Source = frame
   193  		frame = &r.utf8
   194  	}
   195  
   196  	// Save reader with ciphering and other streaming checks.
   197  	r.frame = frame
   198  
   199  	if hdr.OpCode == ws.OpContinuation {
   200  		if cb := r.OnContinuation; cb != nil {
   201  			err = cb(hdr, frame)
   202  		}
   203  	}
   204  
   205  	if hdr.Fin {
   206  		r.State = r.State.Clear(ws.StateFragmented)
   207  	} else {
   208  		r.State = r.State.Set(ws.StateFragmented)
   209  	}
   210  
   211  	return
   212  }
   213  
   214  func (r *Reader) fragmented() bool {
   215  	return r.State.Fragmented()
   216  }
   217  
   218  func (r *Reader) resetFragment() {
   219  	r.raw = io.LimitedReader{}
   220  	r.frame = nil
   221  	// Reset source of the UTF8Reader, but not the state.
   222  	r.utf8.Source = nil
   223  }
   224  
   225  func (r *Reader) reset() {
   226  	r.raw = io.LimitedReader{}
   227  	r.frame = nil
   228  	r.utf8 = UTF8Reader{}
   229  	r.opCode = 0
   230  }
   231  
   232  // NextReader prepares next message read from r. It returns header that
   233  // describes the message and io.Reader to read message's payload. It returns
   234  // non-nil error when it is not possible to read message's initial frame.
   235  //
   236  // Note that next NextReader() on the same r should be done after reading all
   237  // bytes from previously returned io.Reader. For more performant way to discard
   238  // message use Reader and its Discard() method.
   239  //
   240  // Note that it will not handle any "intermediate" frames, that possibly could
   241  // be received between text/binary continuation frames. That is, if peer sent
   242  // text/binary frame with fin flag "false", then it could send ping frame, and
   243  // eventually remaining part of text/binary frame with fin "true" – with
   244  // NextReader() the ping frame will be dropped without any notice. To handle
   245  // this rare, but possible situation (and if you do not know exactly which
   246  // frames peer could send), you could use Reader with OnIntermediate field set.
   247  func NextReader(r io.Reader, s ws.State) (ws.Header, io.Reader, error) {
   248  	rd := &Reader{
   249  		Source: r,
   250  		State:  s,
   251  	}
   252  	header, err := rd.NextFrame()
   253  	if err != nil {
   254  		return header, nil, err
   255  	}
   256  	return header, rd, nil
   257  }