github.com/daeuniverse/quic-go@v0.0.0-20240413031024-943f218e0810/http3/http_stream.go (about)

     1  package http3
     2  
     3  import (
     4  	"errors"
     5  	"fmt"
     6  
     7  	"github.com/daeuniverse/quic-go"
     8  )
     9  
    10  // A Stream is a HTTP/3 stream.
    11  // When writing to and reading from the stream, data is framed in HTTP/3 DATA frames.
    12  type Stream quic.Stream
    13  
    14  // The stream conforms to the quic.Stream interface, but instead of writing to and reading directly
    15  // from the QUIC stream, it writes to and reads from the HTTP stream.
    16  type stream struct {
    17  	quic.Stream
    18  
    19  	buf []byte
    20  
    21  	onFrameError          func()
    22  	bytesRemainingInFrame uint64
    23  }
    24  
    25  var _ Stream = &stream{}
    26  
    27  func newStream(str quic.Stream, onFrameError func()) *stream {
    28  	return &stream{
    29  		Stream:       str,
    30  		onFrameError: onFrameError,
    31  		buf:          make([]byte, 0, 16),
    32  	}
    33  }
    34  
    35  func (s *stream) Read(b []byte) (int, error) {
    36  	if s.bytesRemainingInFrame == 0 {
    37  	parseLoop:
    38  		for {
    39  			frame, err := parseNextFrame(s.Stream, nil)
    40  			if err != nil {
    41  				return 0, err
    42  			}
    43  			switch f := frame.(type) {
    44  			case *headersFrame:
    45  				// skip HEADERS frames
    46  				continue
    47  			case *dataFrame:
    48  				s.bytesRemainingInFrame = f.Length
    49  				break parseLoop
    50  			default:
    51  				s.onFrameError()
    52  				// parseNextFrame skips over unknown frame types
    53  				// Therefore, this condition is only entered when we parsed another known frame type.
    54  				return 0, fmt.Errorf("peer sent an unexpected frame: %T", f)
    55  			}
    56  		}
    57  	}
    58  
    59  	var n int
    60  	var err error
    61  	if s.bytesRemainingInFrame < uint64(len(b)) {
    62  		n, err = s.Stream.Read(b[:s.bytesRemainingInFrame])
    63  	} else {
    64  		n, err = s.Stream.Read(b)
    65  	}
    66  	s.bytesRemainingInFrame -= uint64(n)
    67  	return n, err
    68  }
    69  
    70  func (s *stream) hasMoreData() bool {
    71  	return s.bytesRemainingInFrame > 0
    72  }
    73  
    74  func (s *stream) Write(b []byte) (int, error) {
    75  	s.buf = s.buf[:0]
    76  	s.buf = (&dataFrame{Length: uint64(len(b))}).Append(s.buf)
    77  	if _, err := s.Stream.Write(s.buf); err != nil {
    78  		return 0, err
    79  	}
    80  	return s.Stream.Write(b)
    81  }
    82  
    83  var errTooMuchData = errors.New("peer sent too much data")
    84  
    85  type lengthLimitedStream struct {
    86  	*stream
    87  	contentLength int64
    88  	read          int64
    89  	resetStream   bool
    90  }
    91  
    92  var _ Stream = &lengthLimitedStream{}
    93  
    94  func newLengthLimitedStream(str *stream, contentLength int64) *lengthLimitedStream {
    95  	return &lengthLimitedStream{
    96  		stream:        str,
    97  		contentLength: contentLength,
    98  	}
    99  }
   100  
   101  func (s *lengthLimitedStream) checkContentLengthViolation() error {
   102  	if s.read > s.contentLength || s.read == s.contentLength && s.hasMoreData() {
   103  		if !s.resetStream {
   104  			s.CancelRead(quic.StreamErrorCode(ErrCodeMessageError))
   105  			s.CancelWrite(quic.StreamErrorCode(ErrCodeMessageError))
   106  			s.resetStream = true
   107  		}
   108  		return errTooMuchData
   109  	}
   110  	return nil
   111  }
   112  
   113  func (s *lengthLimitedStream) Read(b []byte) (int, error) {
   114  	if err := s.checkContentLengthViolation(); err != nil {
   115  		return 0, err
   116  	}
   117  	n, err := s.stream.Read(b[:min(int64(len(b)), s.contentLength-s.read)])
   118  	s.read += int64(n)
   119  	if err := s.checkContentLengthViolation(); err != nil {
   120  		return n, err
   121  	}
   122  	return n, err
   123  }