github.com/quic-go/quic-go@v0.44.0/http3/state_tracking_stream_test.go (about)

     1  package http3
     2  
     3  import (
     4  	"bytes"
     5  	"context"
     6  	"errors"
     7  	"io"
     8  	"os"
     9  
    10  	"github.com/quic-go/quic-go"
    11  	mockquic "github.com/quic-go/quic-go/internal/mocks/quic"
    12  
    13  	. "github.com/onsi/ginkgo/v2"
    14  	. "github.com/onsi/gomega"
    15  	"go.uber.org/mock/gomock"
    16  )
    17  
    18  var someStreamID = quic.StreamID(12)
    19  
    20  var _ = Describe("State Tracking Stream", func() {
    21  	It("recognizes when the receive side is closed", func() {
    22  		qstr := mockquic.NewMockStream(mockCtrl)
    23  		qstr.EXPECT().StreamID().AnyTimes().Return(someStreamID)
    24  		qstr.EXPECT().Context().Return(context.Background()).AnyTimes()
    25  
    26  		var (
    27  			clearer mockStreamClearer
    28  			setter  mockErrorSetter
    29  			str     = newStateTrackingStream(qstr, &clearer, &setter)
    30  		)
    31  
    32  		buf := bytes.NewBuffer([]byte("foobar"))
    33  		qstr.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes()
    34  		for i := 0; i < 3; i++ {
    35  			_, err := str.Read([]byte{0})
    36  			Expect(err).ToNot(HaveOccurred())
    37  			Expect(clearer.cleared).To(BeNil())
    38  			Expect(setter.recvErrs).To(BeEmpty())
    39  			Expect(setter.sendErrs).To(BeEmpty())
    40  		}
    41  		_, err := io.ReadAll(str)
    42  		Expect(err).ToNot(HaveOccurred())
    43  		Expect(clearer.cleared).To(BeNil())
    44  		Expect(setter.recvErrs).To(HaveLen(1))
    45  		Expect(setter.recvErrs[0]).To(Equal(io.EOF))
    46  		Expect(setter.sendErrs).To(BeEmpty())
    47  	})
    48  
    49  	It("recognizes local read cancellations", func() {
    50  		qstr := mockquic.NewMockStream(mockCtrl)
    51  		qstr.EXPECT().StreamID().AnyTimes().Return(someStreamID)
    52  		qstr.EXPECT().Context().Return(context.Background()).AnyTimes()
    53  
    54  		var (
    55  			clearer mockStreamClearer
    56  			setter  mockErrorSetter
    57  			str     = newStateTrackingStream(qstr, &clearer, &setter)
    58  		)
    59  
    60  		buf := bytes.NewBuffer([]byte("foobar"))
    61  		qstr.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes()
    62  		qstr.EXPECT().CancelRead(quic.StreamErrorCode(1337))
    63  		_, err := str.Read(make([]byte, 3))
    64  		Expect(err).ToNot(HaveOccurred())
    65  		Expect(clearer.cleared).To(BeNil())
    66  		Expect(setter.recvErrs).To(BeEmpty())
    67  		Expect(setter.sendErrs).To(BeEmpty())
    68  
    69  		str.CancelRead(1337)
    70  		Expect(clearer.cleared).To(BeNil())
    71  		Expect(setter.recvErrs).To(HaveLen(1))
    72  		Expect(setter.recvErrs[0]).To(Equal(&quic.StreamError{StreamID: someStreamID, ErrorCode: 1337}))
    73  		Expect(setter.sendErrs).To(BeEmpty())
    74  	})
    75  
    76  	It("recognizes remote cancellations", func() {
    77  		qstr := mockquic.NewMockStream(mockCtrl)
    78  		qstr.EXPECT().StreamID().AnyTimes().Return(someStreamID)
    79  		qstr.EXPECT().Context().Return(context.Background()).AnyTimes()
    80  
    81  		var (
    82  			clearer mockStreamClearer
    83  			setter  mockErrorSetter
    84  			str     = newStateTrackingStream(qstr, &clearer, &setter)
    85  		)
    86  
    87  		testErr := errors.New("test error")
    88  		qstr.EXPECT().Read(gomock.Any()).Return(0, testErr)
    89  		_, err := str.Read(make([]byte, 3))
    90  		Expect(err).To(MatchError(testErr))
    91  		Expect(clearer.cleared).To(BeNil())
    92  		Expect(setter.recvErrs).To(HaveLen(1))
    93  		Expect(setter.recvErrs[0]).To(Equal(testErr))
    94  		Expect(setter.sendErrs).To(BeEmpty())
    95  	})
    96  
    97  	It("doesn't misinterpret read deadline errors", func() {
    98  		qstr := mockquic.NewMockStream(mockCtrl)
    99  		qstr.EXPECT().StreamID().AnyTimes().Return(someStreamID)
   100  		qstr.EXPECT().Context().Return(context.Background()).AnyTimes()
   101  
   102  		var (
   103  			clearer mockStreamClearer
   104  			setter  mockErrorSetter
   105  			str     = newStateTrackingStream(qstr, &clearer, &setter)
   106  		)
   107  
   108  		qstr.EXPECT().Read(gomock.Any()).Return(0, os.ErrDeadlineExceeded)
   109  		_, err := str.Read(make([]byte, 3))
   110  		Expect(err).To(MatchError(os.ErrDeadlineExceeded))
   111  		Expect(clearer.cleared).To(BeNil())
   112  		Expect(setter.recvErrs).To(BeEmpty())
   113  		Expect(setter.sendErrs).To(BeEmpty())
   114  	})
   115  
   116  	It("recognizes when the send side is closed, when write errors", func() {
   117  		qstr := mockquic.NewMockStream(mockCtrl)
   118  		qstr.EXPECT().StreamID().AnyTimes().Return(someStreamID)
   119  		qstr.EXPECT().Context().Return(context.Background()).AnyTimes()
   120  
   121  		var (
   122  			clearer mockStreamClearer
   123  			setter  mockErrorSetter
   124  			str     = newStateTrackingStream(qstr, &clearer, &setter)
   125  		)
   126  
   127  		testErr := errors.New("test error")
   128  		qstr.EXPECT().Write([]byte("foo")).Return(3, nil)
   129  		qstr.EXPECT().Write([]byte("bar")).Return(0, testErr)
   130  
   131  		_, err := str.Write([]byte("foo"))
   132  		Expect(err).ToNot(HaveOccurred())
   133  		Expect(clearer.cleared).To(BeNil())
   134  		Expect(setter.recvErrs).To(BeEmpty())
   135  		Expect(setter.sendErrs).To(BeEmpty())
   136  
   137  		_, err = str.Write([]byte("bar"))
   138  		Expect(err).To(MatchError(testErr))
   139  		Expect(clearer.cleared).To(BeNil())
   140  		Expect(setter.recvErrs).To(BeEmpty())
   141  		Expect(setter.sendErrs).To(HaveLen(1))
   142  		Expect(setter.sendErrs[0]).To(Equal(testErr))
   143  	})
   144  
   145  	It("recognizes when the send side is closed, when write errors", func() {
   146  		qstr := mockquic.NewMockStream(mockCtrl)
   147  		qstr.EXPECT().StreamID().AnyTimes().Return(someStreamID)
   148  		qstr.EXPECT().Context().Return(context.Background()).AnyTimes()
   149  
   150  		var (
   151  			clearer mockStreamClearer
   152  			setter  mockErrorSetter
   153  			str     = newStateTrackingStream(qstr, &clearer, &setter)
   154  		)
   155  
   156  		qstr.EXPECT().Write([]byte("foo")).Return(0, os.ErrDeadlineExceeded)
   157  		Expect(clearer.cleared).To(BeNil())
   158  		Expect(setter.recvErrs).To(BeEmpty())
   159  		Expect(setter.sendErrs).To(BeEmpty())
   160  
   161  		_, err := str.Write([]byte("foo"))
   162  		Expect(err).To(MatchError(os.ErrDeadlineExceeded))
   163  		Expect(clearer.cleared).To(BeNil())
   164  		Expect(setter.recvErrs).To(BeEmpty())
   165  		Expect(setter.sendErrs).To(BeEmpty())
   166  	})
   167  
   168  	It("recognizes when the send side is closed, when CancelWrite is called", func() {
   169  		qstr := mockquic.NewMockStream(mockCtrl)
   170  		qstr.EXPECT().StreamID().AnyTimes().Return(someStreamID)
   171  		qstr.EXPECT().Context().Return(context.Background()).AnyTimes()
   172  
   173  		var (
   174  			clearer mockStreamClearer
   175  			setter  mockErrorSetter
   176  			str     = newStateTrackingStream(qstr, &clearer, &setter)
   177  		)
   178  
   179  		qstr.EXPECT().Write(gomock.Any())
   180  		qstr.EXPECT().CancelWrite(quic.StreamErrorCode(1337))
   181  		_, err := str.Write([]byte("foobar"))
   182  		Expect(err).ToNot(HaveOccurred())
   183  		Expect(clearer.cleared).To(BeNil())
   184  		Expect(setter.recvErrs).To(BeEmpty())
   185  		Expect(setter.sendErrs).To(BeEmpty())
   186  
   187  		str.CancelWrite(1337)
   188  		Expect(clearer.cleared).To(BeNil())
   189  		Expect(setter.recvErrs).To(BeEmpty())
   190  		Expect(setter.sendErrs).To(HaveLen(1))
   191  		Expect(setter.sendErrs[0]).To(Equal(&quic.StreamError{StreamID: someStreamID, ErrorCode: 1337}))
   192  	})
   193  
   194  	It("recognizes when the send side is closed, when the stream context is canceled", func() {
   195  		qstr := mockquic.NewMockStream(mockCtrl)
   196  		qstr.EXPECT().StreamID().AnyTimes()
   197  		ctx, cancel := context.WithCancelCause(context.Background())
   198  		qstr.EXPECT().Context().Return(ctx).AnyTimes()
   199  
   200  		var (
   201  			clearer mockStreamClearer
   202  			setter  = mockErrorSetter{
   203  				sendSent: make(chan struct{}),
   204  			}
   205  		)
   206  
   207  		_ = newStateTrackingStream(qstr, &clearer, &setter)
   208  		Expect(clearer.cleared).To(BeNil())
   209  		Expect(setter.recvErrs).To(BeEmpty())
   210  		Expect(setter.sendErrs).To(BeEmpty())
   211  
   212  		testErr := errors.New("test error")
   213  		cancel(testErr)
   214  		Eventually(setter.sendSent).Should(BeClosed())
   215  		Expect(clearer.cleared).To(BeNil())
   216  		Expect(setter.recvErrs).To(BeEmpty())
   217  		Expect(setter.sendErrs).To(HaveLen(1))
   218  		Expect(setter.sendErrs[0]).To(Equal(testErr))
   219  	})
   220  
   221  	It("clears the stream when receive is closed followed by send is closed", func() {
   222  		qstr := mockquic.NewMockStream(mockCtrl)
   223  		qstr.EXPECT().StreamID().AnyTimes().Return(someStreamID)
   224  		qstr.EXPECT().Context().Return(context.Background()).AnyTimes()
   225  
   226  		var (
   227  			clearer mockStreamClearer
   228  			setter  mockErrorSetter
   229  			str     = newStateTrackingStream(qstr, &clearer, &setter)
   230  		)
   231  
   232  		buf := bytes.NewBuffer([]byte("foobar"))
   233  		qstr.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes()
   234  		_, err := io.ReadAll(str)
   235  		Expect(err).ToNot(HaveOccurred())
   236  
   237  		Expect(clearer.cleared).To(BeNil())
   238  		Expect(setter.recvErrs).To(HaveLen(1))
   239  		Expect(setter.recvErrs[0]).To(Equal(io.EOF))
   240  
   241  		testErr := errors.New("test error")
   242  		qstr.EXPECT().Write([]byte("bar")).Return(0, testErr)
   243  
   244  		_, err = str.Write([]byte("bar"))
   245  		Expect(err).To(MatchError(testErr))
   246  		Expect(setter.sendErrs).To(HaveLen(1))
   247  		Expect(setter.sendErrs[0]).To(Equal(testErr))
   248  
   249  		Expect(clearer.cleared).To(Equal(&someStreamID))
   250  	})
   251  
   252  	It("clears the stream when send is closed followed by receive is closed", func() {
   253  		qstr := mockquic.NewMockStream(mockCtrl)
   254  		qstr.EXPECT().StreamID().AnyTimes().Return(someStreamID)
   255  		qstr.EXPECT().Context().Return(context.Background()).AnyTimes()
   256  
   257  		var (
   258  			clearer mockStreamClearer
   259  			setter  mockErrorSetter
   260  			str     = newStateTrackingStream(qstr, &clearer, &setter)
   261  		)
   262  
   263  		testErr := errors.New("test error")
   264  		qstr.EXPECT().Write([]byte("bar")).Return(0, testErr)
   265  
   266  		_, err := str.Write([]byte("bar"))
   267  		Expect(err).To(MatchError(testErr))
   268  		Expect(clearer.cleared).To(BeNil())
   269  		Expect(setter.sendErrs).To(HaveLen(1))
   270  		Expect(setter.sendErrs[0]).To(Equal(testErr))
   271  
   272  		buf := bytes.NewBuffer([]byte("foobar"))
   273  		qstr.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes()
   274  
   275  		_, err = io.ReadAll(str)
   276  		Expect(err).ToNot(HaveOccurred())
   277  		Expect(setter.recvErrs).To(HaveLen(1))
   278  		Expect(setter.recvErrs[0]).To(Equal(io.EOF))
   279  
   280  		Expect(clearer.cleared).To(Equal(&someStreamID))
   281  	})
   282  })
   283  
   284  type mockStreamClearer struct {
   285  	cleared *quic.StreamID
   286  }
   287  
   288  func (s *mockStreamClearer) clearStream(id quic.StreamID) {
   289  	s.cleared = &id
   290  }
   291  
   292  type mockErrorSetter struct {
   293  	sendErrs []error
   294  	recvErrs []error
   295  
   296  	sendSent chan struct{}
   297  }
   298  
   299  func (e *mockErrorSetter) SetSendError(err error) {
   300  	e.sendErrs = append(e.sendErrs, err)
   301  
   302  	if e.sendSent != nil {
   303  		close(e.sendSent)
   304  	}
   305  }
   306  
   307  func (e *mockErrorSetter) SetReceiveError(err error) {
   308  	e.recvErrs = append(e.recvErrs, err)
   309  }