github.com/MerlinKodo/quic-go@v0.39.2/http3/http_stream_test.go (about)

     1  package http3
     2  
     3  import (
     4  	"bytes"
     5  	"io"
     6  
     7  	"github.com/MerlinKodo/quic-go"
     8  	mockquic "github.com/MerlinKodo/quic-go/internal/mocks/quic"
     9  
    10  	. "github.com/onsi/ginkgo/v2"
    11  	. "github.com/onsi/gomega"
    12  	"go.uber.org/mock/gomock"
    13  )
    14  
    15  func getDataFrame(data []byte) []byte {
    16  	b := (&dataFrame{Length: uint64(len(data))}).Append(nil)
    17  	return append(b, data...)
    18  }
    19  
    20  var _ = Describe("Stream", func() {
    21  	Context("reading", func() {
    22  		var (
    23  			str           Stream
    24  			qstr          *mockquic.MockStream
    25  			buf           *bytes.Buffer
    26  			errorCbCalled bool
    27  		)
    28  
    29  		errorCb := func() { errorCbCalled = true }
    30  
    31  		BeforeEach(func() {
    32  			buf = &bytes.Buffer{}
    33  			errorCbCalled = false
    34  			qstr = mockquic.NewMockStream(mockCtrl)
    35  			qstr.EXPECT().Write(gomock.Any()).DoAndReturn(buf.Write).AnyTimes()
    36  			qstr.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes()
    37  			str = newStream(qstr, errorCb)
    38  		})
    39  
    40  		It("reads DATA frames in a single run", func() {
    41  			buf.Write(getDataFrame([]byte("foobar")))
    42  			b := make([]byte, 6)
    43  			n, err := str.Read(b)
    44  			Expect(err).ToNot(HaveOccurred())
    45  			Expect(n).To(Equal(6))
    46  			Expect(b).To(Equal([]byte("foobar")))
    47  		})
    48  
    49  		It("reads DATA frames in multiple runs", func() {
    50  			buf.Write(getDataFrame([]byte("foobar")))
    51  			b := make([]byte, 3)
    52  			n, err := str.Read(b)
    53  			Expect(err).ToNot(HaveOccurred())
    54  			Expect(n).To(Equal(3))
    55  			Expect(b).To(Equal([]byte("foo")))
    56  			n, err = str.Read(b)
    57  			Expect(err).ToNot(HaveOccurred())
    58  			Expect(n).To(Equal(3))
    59  			Expect(b).To(Equal([]byte("bar")))
    60  		})
    61  
    62  		It("reads DATA frames into too large buffers", func() {
    63  			buf.Write(getDataFrame([]byte("foobar")))
    64  			b := make([]byte, 10)
    65  			n, err := str.Read(b)
    66  			Expect(err).ToNot(HaveOccurred())
    67  			Expect(n).To(Equal(6))
    68  			Expect(b[:n]).To(Equal([]byte("foobar")))
    69  		})
    70  
    71  		It("reads DATA frames into too large buffers, in multiple runs", func() {
    72  			buf.Write(getDataFrame([]byte("foobar")))
    73  			b := make([]byte, 4)
    74  			n, err := str.Read(b)
    75  			Expect(err).ToNot(HaveOccurred())
    76  			Expect(n).To(Equal(4))
    77  			Expect(b).To(Equal([]byte("foob")))
    78  			n, err = str.Read(b)
    79  			Expect(err).ToNot(HaveOccurred())
    80  			Expect(n).To(Equal(2))
    81  			Expect(b[:n]).To(Equal([]byte("ar")))
    82  		})
    83  
    84  		It("reads multiple DATA frames", func() {
    85  			buf.Write(getDataFrame([]byte("foo")))
    86  			buf.Write(getDataFrame([]byte("bar")))
    87  			b := make([]byte, 6)
    88  			n, err := str.Read(b)
    89  			Expect(err).ToNot(HaveOccurred())
    90  			Expect(n).To(Equal(3))
    91  			Expect(b[:n]).To(Equal([]byte("foo")))
    92  			n, err = str.Read(b)
    93  			Expect(err).ToNot(HaveOccurred())
    94  			Expect(n).To(Equal(3))
    95  			Expect(b[:n]).To(Equal([]byte("bar")))
    96  		})
    97  
    98  		It("skips HEADERS frames", func() {
    99  			b := getDataFrame([]byte("foo"))
   100  			b = (&headersFrame{Length: 10}).Append(b)
   101  			b = append(b, make([]byte, 10)...)
   102  			b = append(b, getDataFrame([]byte("bar"))...)
   103  			buf.Write(b)
   104  			r := make([]byte, 6)
   105  			n, err := io.ReadFull(str, r)
   106  			Expect(err).ToNot(HaveOccurred())
   107  			Expect(n).To(Equal(6))
   108  			Expect(r).To(Equal([]byte("foobar")))
   109  		})
   110  
   111  		It("errors when it can't parse the frame", func() {
   112  			buf.Write([]byte("invalid"))
   113  			_, err := str.Read([]byte{0})
   114  			Expect(err).To(HaveOccurred())
   115  		})
   116  
   117  		It("errors on unexpected frames, and calls the error callback", func() {
   118  			b := (&settingsFrame{}).Append(nil)
   119  			buf.Write(b)
   120  			_, err := str.Read([]byte{0})
   121  			Expect(err).To(MatchError("peer sent an unexpected frame: *http3.settingsFrame"))
   122  			Expect(errorCbCalled).To(BeTrue())
   123  		})
   124  	})
   125  
   126  	Context("writing", func() {
   127  		It("writes data frames", func() {
   128  			buf := &bytes.Buffer{}
   129  			qstr := mockquic.NewMockStream(mockCtrl)
   130  			qstr.EXPECT().Write(gomock.Any()).DoAndReturn(buf.Write).AnyTimes()
   131  			str := newStream(qstr, nil)
   132  			str.Write([]byte("foo"))
   133  			str.Write([]byte("foobar"))
   134  
   135  			f, err := parseNextFrame(buf, nil)
   136  			Expect(err).ToNot(HaveOccurred())
   137  			Expect(f).To(Equal(&dataFrame{Length: 3}))
   138  			b := make([]byte, 3)
   139  			_, err = io.ReadFull(buf, b)
   140  			Expect(err).ToNot(HaveOccurred())
   141  			Expect(b).To(Equal([]byte("foo")))
   142  
   143  			f, err = parseNextFrame(buf, nil)
   144  			Expect(err).ToNot(HaveOccurred())
   145  			Expect(f).To(Equal(&dataFrame{Length: 6}))
   146  			b = make([]byte, 6)
   147  			_, err = io.ReadFull(buf, b)
   148  			Expect(err).ToNot(HaveOccurred())
   149  			Expect(b).To(Equal([]byte("foobar")))
   150  		})
   151  	})
   152  })
   153  
   154  var _ = Describe("length-limited streams", func() {
   155  	var (
   156  		str  *stream
   157  		qstr *mockquic.MockStream
   158  		buf  *bytes.Buffer
   159  	)
   160  
   161  	BeforeEach(func() {
   162  		buf = &bytes.Buffer{}
   163  		qstr = mockquic.NewMockStream(mockCtrl)
   164  		qstr.EXPECT().Write(gomock.Any()).DoAndReturn(buf.Write).AnyTimes()
   165  		qstr.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes()
   166  		str = newStream(qstr, func() { Fail("didn't expect error callback to be called") })
   167  	})
   168  
   169  	It("reads all frames", func() {
   170  		s := newLengthLimitedStream(str, 6)
   171  		buf.Write(getDataFrame([]byte("foo")))
   172  		buf.Write(getDataFrame([]byte("bar")))
   173  		data, err := io.ReadAll(s)
   174  		Expect(err).ToNot(HaveOccurred())
   175  		Expect(data).To(Equal([]byte("foobar")))
   176  	})
   177  
   178  	It("errors if more data than the maximum length is sent, in the middle of a frame", func() {
   179  		s := newLengthLimitedStream(str, 4)
   180  		buf.Write(getDataFrame([]byte("foo")))
   181  		buf.Write(getDataFrame([]byte("bar")))
   182  		qstr.EXPECT().CancelRead(quic.StreamErrorCode(ErrCodeMessageError))
   183  		qstr.EXPECT().CancelWrite(quic.StreamErrorCode(ErrCodeMessageError))
   184  		data, err := io.ReadAll(s)
   185  		Expect(err).To(MatchError(errTooMuchData))
   186  		Expect(data).To(Equal([]byte("foob")))
   187  		// check that repeated calls to Read also return the right error
   188  		n, err := s.Read([]byte{0})
   189  		Expect(n).To(BeZero())
   190  		Expect(err).To(MatchError(errTooMuchData))
   191  	})
   192  
   193  	It("errors if more data than the maximum length is sent, as an additional frame", func() {
   194  		s := newLengthLimitedStream(str, 3)
   195  		buf.Write(getDataFrame([]byte("foo")))
   196  		buf.Write(getDataFrame([]byte("bar")))
   197  		qstr.EXPECT().CancelRead(quic.StreamErrorCode(ErrCodeMessageError))
   198  		qstr.EXPECT().CancelWrite(quic.StreamErrorCode(ErrCodeMessageError))
   199  		data, err := io.ReadAll(s)
   200  		Expect(err).To(MatchError(errTooMuchData))
   201  		Expect(data).To(Equal([]byte("foo")))
   202  	})
   203  })