github.com/simonmittag/ws@v1.1.0-rc.5.0.20210419231947-82b846128245/wsutil/reader.go (about)

     1  package wsutil
     2  
     3  import (
     4  	"errors"
     5  	"io"
     6  	"io/ioutil"
     7  
     8  	"github.com/simonmittag/ws"
     9  )
    10  
    11  // ErrNoFrameAdvance means that Reader's Read() method was called without
    12  // preceding NextFrame() call.
    13  var ErrNoFrameAdvance = errors.New("no frame advance")
    14  
    15  // ErrFrameTooLarge indicates that a message of length higher than
    16  // MaxFrameSize was being read.
    17  var ErrFrameTooLarge = errors.New("frame too large")
    18  
    19  // FrameHandlerFunc handles parsed frame header and its body represented by
    20  // io.Reader.
    21  //
    22  // Note that reader represents already unmasked body.
    23  type FrameHandlerFunc func(ws.Header, io.Reader) error
    24  
    25  // Reader is a wrapper around source io.Reader which represents WebSocket
    26  // connection. It contains options for reading messages from source.
    27  //
    28  // Reader implements io.Reader, which Read() method reads payload of incoming
    29  // WebSocket frames. It also takes care on fragmented frames and possibly
    30  // intermediate control frames between them.
    31  //
    32  // Note that Reader's methods are not goroutine safe.
    33  type Reader struct {
    34  	Source io.Reader
    35  	State  ws.State
    36  
    37  	// SkipHeaderCheck disables checking header bits to be RFC6455 compliant.
    38  	SkipHeaderCheck bool
    39  
    40  	// CheckUTF8 enables UTF-8 checks for text frames payload. If incoming
    41  	// bytes are not valid UTF-8 sequence, ErrInvalidUTF8 returned.
    42  	CheckUTF8 bool
    43  
    44  	// Extensions is a list of negotiated extensions for reader Source.
    45  	// It is used to meet the specs and clear appropriate bits in fragment
    46  	// header RSV segment.
    47  	Extensions []RecvExtension
    48  
    49  	// MaxFrameSize controls the maximum frame size in bytes
    50  	// that can be read. A message exceeding that size will return
    51  	// a ErrFrameTooLarge to the application.
    52  	//
    53  	// Not setting this field means there is no limit.
    54  	MaxFrameSize int64
    55  
    56  	OnContinuation FrameHandlerFunc
    57  	OnIntermediate FrameHandlerFunc
    58  
    59  	opCode ws.OpCode        // Used to store message op code on fragmentation.
    60  	frame  io.Reader        // Used to as frame reader.
    61  	raw    io.LimitedReader // Used to discard frames without cipher.
    62  	utf8   UTF8Reader       // Used to check UTF8 sequences if CheckUTF8 is true.
    63  }
    64  
    65  // NewReader creates new frame reader that reads from r keeping given state to
    66  // make some protocol validity checks when it needed.
    67  func NewReader(r io.Reader, s ws.State) *Reader {
    68  	return &Reader{
    69  		Source: r,
    70  		State:  s,
    71  	}
    72  }
    73  
    74  // NewClientSideReader is a helper function that calls NewReader with r and
    75  // ws.StateClientSide.
    76  func NewClientSideReader(r io.Reader) *Reader {
    77  	return NewReader(r, ws.StateClientSide)
    78  }
    79  
    80  // NewServerSideReader is a helper function that calls NewReader with r and
    81  // ws.StateServerSide.
    82  func NewServerSideReader(r io.Reader) *Reader {
    83  	return NewReader(r, ws.StateServerSide)
    84  }
    85  
    86  // Read implements io.Reader. It reads the next message payload into p.
    87  // It takes care on fragmented messages.
    88  //
    89  // The error is io.EOF only if all of message bytes were read.
    90  // If an io.EOF happens during reading some but not all the message bytes
    91  // Read() returns io.ErrUnexpectedEOF.
    92  //
    93  // The error is ErrNoFrameAdvance if no NextFrame() call was made before
    94  // reading next message bytes.
    95  func (r *Reader) Read(p []byte) (n int, err error) {
    96  	if r.frame == nil {
    97  		if !r.fragmented() {
    98  			// Every new Read() must be preceded by NextFrame() call.
    99  			return 0, ErrNoFrameAdvance
   100  		}
   101  		// Read next continuation or intermediate control frame.
   102  		_, err := r.NextFrame()
   103  		if err != nil {
   104  			return 0, err
   105  		}
   106  		if r.frame == nil {
   107  			// We handled intermediate control and now got nothing to read.
   108  			return 0, nil
   109  		}
   110  	}
   111  
   112  	n, err = r.frame.Read(p)
   113  	if err != nil && err != io.EOF {
   114  		return
   115  	}
   116  	if err == nil && r.raw.N != 0 {
   117  		return n, nil
   118  	}
   119  
   120  	// EOF condition (either err is io.EOF or r.raw.N is zero).
   121  	switch {
   122  	case r.raw.N != 0:
   123  		err = io.ErrUnexpectedEOF
   124  
   125  	case r.fragmented():
   126  		err = nil
   127  		r.resetFragment()
   128  
   129  	case r.CheckUTF8 && !r.utf8.Valid():
   130  		// NOTE: check utf8 only when full message received, since partial
   131  		// reads may be invalid.
   132  		n = r.utf8.Accepted()
   133  		err = ErrInvalidUTF8
   134  
   135  	default:
   136  		r.reset()
   137  		err = io.EOF
   138  	}
   139  
   140  	return
   141  }
   142  
   143  // Discard discards current message unread bytes.
   144  // It discards all frames of fragmented message.
   145  func (r *Reader) Discard() (err error) {
   146  	for {
   147  		_, err = io.Copy(ioutil.Discard, &r.raw)
   148  		if err != nil {
   149  			break
   150  		}
   151  		if !r.fragmented() {
   152  			break
   153  		}
   154  		if _, err = r.NextFrame(); err != nil {
   155  			break
   156  		}
   157  	}
   158  	r.reset()
   159  	return err
   160  }
   161  
   162  // NextFrame prepares r to read next message. It returns received frame header
   163  // and non-nil error on failure.
   164  //
   165  // Note that next NextFrame() call must be done after receiving or discarding
   166  // all current message bytes.
   167  func (r *Reader) NextFrame() (hdr ws.Header, err error) {
   168  	hdr, err = ws.ReadHeader(r.Source)
   169  	if err == io.EOF && r.fragmented() {
   170  		// If we are in fragmented state EOF means that is was totally
   171  		// unexpected.
   172  		//
   173  		// NOTE: This is necessary to prevent callers such that
   174  		// ioutil.ReadAll to receive some amount of bytes without an error.
   175  		// ReadAll() ignores an io.EOF error, thus caller may think that
   176  		// whole message fetched, but actually only part of it.
   177  		err = io.ErrUnexpectedEOF
   178  	}
   179  	if err == nil && !r.SkipHeaderCheck {
   180  		err = ws.CheckHeader(hdr, r.State)
   181  	}
   182  	if err != nil {
   183  		return hdr, err
   184  	}
   185  
   186  	if n := r.MaxFrameSize; n > 0 && hdr.Length > n {
   187  		return hdr, ErrFrameTooLarge
   188  	}
   189  
   190  	// Save raw reader to use it on discarding frame without ciphering and
   191  	// other streaming checks.
   192  	r.raw = io.LimitedReader{
   193  		R: r.Source,
   194  		N: hdr.Length,
   195  	}
   196  
   197  	frame := io.Reader(&r.raw)
   198  	if hdr.Masked {
   199  		frame = NewCipherReader(frame, hdr.Mask)
   200  	}
   201  
   202  	for _, x := range r.Extensions {
   203  		hdr, err = x.UnsetBits(hdr)
   204  		if err != nil {
   205  			return hdr, err
   206  		}
   207  	}
   208  
   209  	if r.fragmented() {
   210  		if hdr.OpCode.IsControl() {
   211  			if cb := r.OnIntermediate; cb != nil {
   212  				err = cb(hdr, frame)
   213  			}
   214  			if err == nil {
   215  				// Ensure that src is empty.
   216  				_, err = io.Copy(ioutil.Discard, &r.raw)
   217  			}
   218  			return
   219  		}
   220  	} else {
   221  		r.opCode = hdr.OpCode
   222  	}
   223  	if r.CheckUTF8 && (hdr.OpCode == ws.OpText || (r.fragmented() && r.opCode == ws.OpText)) {
   224  		r.utf8.Source = frame
   225  		frame = &r.utf8
   226  	}
   227  
   228  	// Save reader with ciphering and other streaming checks.
   229  	r.frame = frame
   230  
   231  	if hdr.OpCode == ws.OpContinuation {
   232  		if cb := r.OnContinuation; cb != nil {
   233  			err = cb(hdr, frame)
   234  		}
   235  	}
   236  
   237  	if hdr.Fin {
   238  		r.State = r.State.Clear(ws.StateFragmented)
   239  	} else {
   240  		r.State = r.State.Set(ws.StateFragmented)
   241  	}
   242  
   243  	return
   244  }
   245  
   246  func (r *Reader) fragmented() bool {
   247  	return r.State.Fragmented()
   248  }
   249  
   250  func (r *Reader) resetFragment() {
   251  	r.raw = io.LimitedReader{}
   252  	r.frame = nil
   253  	// Reset source of the UTF8Reader, but not the state.
   254  	r.utf8.Source = nil
   255  }
   256  
   257  func (r *Reader) reset() {
   258  	r.raw = io.LimitedReader{}
   259  	r.frame = nil
   260  	r.utf8 = UTF8Reader{}
   261  	r.opCode = 0
   262  }
   263  
   264  // NextReader prepares next message read from r. It returns header that
   265  // describes the message and io.Reader to read message's payload. It returns
   266  // non-nil error when it is not possible to read message's initial frame.
   267  //
   268  // Note that next NextReader() on the same r should be done after reading all
   269  // bytes from previously returned io.Reader. For more performant way to discard
   270  // message use Reader and its Discard() method.
   271  //
   272  // Note that it will not handle any "intermediate" frames, that possibly could
   273  // be received between text/binary continuation frames. That is, if peer sent
   274  // text/binary frame with fin flag "false", then it could send ping frame, and
   275  // eventually remaining part of text/binary frame with fin "true" – with
   276  // NextReader() the ping frame will be dropped without any notice. To handle
   277  // this rare, but possible situation (and if you do not know exactly which
   278  // frames peer could send), you could use Reader with OnIntermediate field set.
   279  func NextReader(r io.Reader, s ws.State) (ws.Header, io.Reader, error) {
   280  	rd := &Reader{
   281  		Source: r,
   282  		State:  s,
   283  	}
   284  	header, err := rd.NextFrame()
   285  	if err != nil {
   286  		return header, nil, err
   287  	}
   288  	return header, rd, nil
   289  }