github.com/ronaksoft/rony@v0.16.26-0.20230807065236-1743dbfe6959/internal/gateway/tcp/util/reader.go (about)

     1  package wsutil
     2  
     3  import (
     4  	"errors"
     5  	"io"
     6  	"io/ioutil"
     7  
     8  	"github.com/gobwas/ws"
     9  )
    10  
    11  // ErrNoFrameAdvance means that Reader's Read() method was called without
    12  // preceding NextFrame() call.
    13  var ErrNoFrameAdvance = errors.New("no frame advance")
    14  
    15  // FrameHandlerFunc handles parsed frame header and its body represented by
    16  // io.Reader.
    17  //
    18  // Note that reader represents already unmasked body.
    19  type FrameHandlerFunc func(ws.Header, io.Reader) error
    20  
// Reader is a wrapper around a source io.Reader that represents a WebSocket
// connection. It holds the options used when reading messages from that
// source.
//
// Reader implements io.Reader: its Read() method reads the payload of incoming
// WebSocket frames. It also takes care of fragmented messages and of control
// frames that may arrive between their fragments.
//
// Note that Reader's methods are not goroutine safe.
type Reader struct {
	// Source is the stream frames are read from: NextFrame reads headers from
	// it, and payloads are read from it through the raw limited reader.
	Source io.Reader
	// State is passed to ws.CheckHeader to validate incoming headers. NextFrame
	// also sets or clears ws.StateFragmented in it from each data frame's Fin
	// bit.
	State  ws.State

	// SkipHeaderCheck disables checking header bits to be RFC6455 compliant.
	SkipHeaderCheck bool
	// OnContinuation, if non-nil, is called by NextFrame for every
	// continuation frame with its header and payload reader. A non-nil error
	// is returned from NextFrame.
	OnContinuation  FrameHandlerFunc
	// OnIntermediate, if non-nil, is called by NextFrame for control frames
	// received in the middle of a fragmented message. A non-nil error is
	// returned from NextFrame; otherwise any payload it left unread is
	// discarded.
	OnIntermediate  FrameHandlerFunc

	opCode ws.OpCode        // Opcode of the message's first frame, kept across fragments.
	frame  io.Reader        // Current frame payload reader (unmasked if needed); nil when there is none.
	raw    io.LimitedReader // Payload limited to the frame length; lets frames be discarded without unmasking.
}
    42  
    43  // NewReader creates new frame reader that reads from r keeping given state to
    44  // make some protocol validity checks when it needed.
    45  func NewReader(r io.Reader, s ws.State) *Reader {
    46  	return &Reader{
    47  		Source: r,
    48  		State:  s,
    49  	}
    50  }
    51  
    52  // Read implements io.Reader. It reads the next message payload into p.
    53  // It takes care on fragmented messages.
    54  //
    55  // The error is io.EOF only if all of message bytes were read.
    56  // If an io.EOF happens during reading some but not all the message bytes
    57  // Read() returns io.ErrUnexpectedEOF.
    58  //
    59  // The error is ErrNoFrameAdvance if no NextFrame() call was made before
    60  // reading next message bytes.
    61  func (r *Reader) Read(p []byte) (n int, err error) {
    62  	if r.frame == nil {
    63  		if !r.fragmented() {
    64  			// Every new Read() must be preceded by NextFrame() call.
    65  			return 0, ErrNoFrameAdvance
    66  		}
    67  		// Read next continuation or intermediate control frame.
    68  		_, err := r.NextFrame()
    69  		if err != nil {
    70  			return 0, err
    71  		}
    72  		if r.frame == nil {
    73  			// We handled intermediate control and now got nothing to read.
    74  			return 0, nil
    75  		}
    76  	}
    77  
    78  	n, err = r.frame.Read(p)
    79  	if err != nil && err != io.EOF {
    80  		return
    81  	}
    82  	if err == nil && r.raw.N != 0 {
    83  		return
    84  	}
    85  
    86  	switch {
    87  	case r.raw.N != 0:
    88  		err = io.ErrUnexpectedEOF
    89  
    90  	case r.fragmented():
    91  		err = nil
    92  		r.resetFragment()
    93  
    94  	default:
    95  		r.reset()
    96  		err = io.EOF
    97  	}
    98  
    99  	return
   100  }
   101  
   102  // Discard discards current message unread bytes.
   103  // It discards all frames of fragmented message.
   104  func (r *Reader) Discard() (err error) {
   105  	for {
   106  		_, err = io.Copy(ioutil.Discard, &r.raw)
   107  		if err != nil {
   108  			break
   109  		}
   110  		if !r.fragmented() {
   111  			break
   112  		}
   113  		if _, err = r.NextFrame(); err != nil {
   114  			break
   115  		}
   116  	}
   117  	r.reset()
   118  
   119  	return err
   120  }
   121  
   122  // NextFrame prepares r to read next message. It returns received frame header
   123  // and non-nil error on failure.
   124  //
   125  // Note that next NextFrame() call must be done after receiving or discarding
   126  // all current message bytes.
   127  func (r *Reader) NextFrame() (hdr ws.Header, err error) {
   128  	hdr, err = ws.ReadHeader(r.Source)
   129  	if err == io.EOF && r.fragmented() {
   130  		// If we are in fragmented state EOF means that is was totally
   131  		// unexpected.
   132  		//
   133  		// NOTE: This is necessary to prevent callers such that
   134  		// ioutil.ReadAll to receive some amount of bytes without an error.
   135  		// ReadAll() ignores an io.EOF error, thus caller may think that
   136  		// whole message fetched, but actually only part of it.
   137  		err = io.ErrUnexpectedEOF
   138  	}
   139  	if err == nil && !r.SkipHeaderCheck {
   140  		err = ws.CheckHeader(hdr, r.State)
   141  	}
   142  	if err != nil {
   143  		return hdr, err
   144  	}
   145  
   146  	// Save raw reader to use it on discarding frame without ciphering and
   147  	// other streaming checks.
   148  	r.raw = io.LimitedReader{R: r.Source, N: hdr.Length}
   149  
   150  	frame := io.Reader(&r.raw)
   151  	if hdr.Masked {
   152  		frame = NewCipherReader(frame, hdr.Mask)
   153  	}
   154  	if r.fragmented() {
   155  		if hdr.OpCode.IsControl() {
   156  			if cb := r.OnIntermediate; cb != nil {
   157  				err = cb(hdr, frame)
   158  			}
   159  			if err == nil {
   160  				// Ensure that src is empty.
   161  				_, err = io.Copy(ioutil.Discard, &r.raw)
   162  			}
   163  
   164  			return
   165  		}
   166  	} else {
   167  		r.opCode = hdr.OpCode
   168  	}
   169  
   170  	// Save reader with ciphering and other streaming checks.
   171  	r.frame = frame
   172  
   173  	if hdr.OpCode == ws.OpContinuation {
   174  		if cb := r.OnContinuation; cb != nil {
   175  			err = cb(hdr, frame)
   176  		}
   177  	}
   178  
   179  	if hdr.Fin {
   180  		r.State = r.State.Clear(ws.StateFragmented)
   181  	} else {
   182  		r.State = r.State.Set(ws.StateFragmented)
   183  	}
   184  
   185  	return
   186  }
   187  
   188  func (r *Reader) fragmented() bool {
   189  	return r.State.Fragmented()
   190  }
   191  
   192  func (r *Reader) resetFragment() {
   193  	r.raw = io.LimitedReader{}
   194  	r.frame = nil
   195  }
   196  
   197  func (r *Reader) reset() {
   198  	r.raw = io.LimitedReader{}
   199  	r.frame = nil
   200  	r.opCode = 0
   201  }