github.com/daeuniverse/quic-go@v0.0.0-20240413031024-943f218e0810/http3/response_writer_test.go (about)

     1  package http3
     2  
     3  import (
     4  	"bytes"
     5  	"io"
     6  	"net/http"
     7  	"time"
     8  
     9  	mockquic "github.com/daeuniverse/quic-go/internal/mocks/quic"
    10  	"github.com/daeuniverse/quic-go/internal/utils"
    11  
    12  	"github.com/quic-go/qpack"
    13  
    14  	. "github.com/onsi/ginkgo/v2"
    15  	. "github.com/onsi/gomega"
    16  	"go.uber.org/mock/gomock"
    17  )
    18  
    19  var _ = Describe("Response Writer", func() {
    20  	var (
    21  		rw     *responseWriter
    22  		strBuf *bytes.Buffer
    23  	)
    24  
    25  	BeforeEach(func() {
    26  		strBuf = &bytes.Buffer{}
    27  		str := mockquic.NewMockStream(mockCtrl)
    28  		str.EXPECT().Write(gomock.Any()).DoAndReturn(strBuf.Write).AnyTimes()
    29  		str.EXPECT().SetReadDeadline(gomock.Any()).Return(nil).AnyTimes()
    30  		str.EXPECT().SetWriteDeadline(gomock.Any()).Return(nil).AnyTimes()
    31  		rw = newResponseWriter(str, nil, utils.DefaultLogger)
    32  	})
    33  
    34  	decodeHeader := func(str io.Reader) map[string][]string {
    35  		rw.Flush()
    36  		fields := make(map[string][]string)
    37  		decoder := qpack.NewDecoder(nil)
    38  
    39  		frame, err := parseNextFrame(str, nil)
    40  		Expect(err).ToNot(HaveOccurred())
    41  		Expect(frame).To(BeAssignableToTypeOf(&headersFrame{}))
    42  		headersFrame := frame.(*headersFrame)
    43  		data := make([]byte, headersFrame.Length)
    44  		_, err = io.ReadFull(str, data)
    45  		Expect(err).ToNot(HaveOccurred())
    46  		hfs, err := decoder.DecodeFull(data)
    47  		Expect(err).ToNot(HaveOccurred())
    48  		for _, p := range hfs {
    49  			fields[p.Name] = append(fields[p.Name], p.Value)
    50  		}
    51  		return fields
    52  	}
    53  
    54  	getData := func(str io.Reader) []byte {
    55  		frame, err := parseNextFrame(str, nil)
    56  		Expect(err).ToNot(HaveOccurred())
    57  		Expect(frame).To(BeAssignableToTypeOf(&dataFrame{}))
    58  		df := frame.(*dataFrame)
    59  		data := make([]byte, df.Length)
    60  		_, err = io.ReadFull(str, data)
    61  		Expect(err).ToNot(HaveOccurred())
    62  		return data
    63  	}
    64  
    65  	It("writes status", func() {
    66  		rw.WriteHeader(http.StatusTeapot)
    67  		fields := decodeHeader(strBuf)
    68  		Expect(fields).To(HaveLen(2))
    69  		Expect(fields).To(HaveKeyWithValue(":status", []string{"418"}))
    70  		Expect(fields).To(HaveKey("date"))
    71  	})
    72  
    73  	It("writes headers", func() {
    74  		rw.Header().Add("content-length", "42")
    75  		rw.WriteHeader(http.StatusTeapot)
    76  		fields := decodeHeader(strBuf)
    77  		Expect(fields).To(HaveKeyWithValue("content-length", []string{"42"}))
    78  	})
    79  
    80  	It("writes multiple headers with the same name", func() {
    81  		const cookie1 = "test1=1; Max-Age=7200; path=/"
    82  		const cookie2 = "test2=2; Max-Age=7200; path=/"
    83  		rw.Header().Add("set-cookie", cookie1)
    84  		rw.Header().Add("set-cookie", cookie2)
    85  		rw.WriteHeader(http.StatusTeapot)
    86  		fields := decodeHeader(strBuf)
    87  		Expect(fields).To(HaveKey("set-cookie"))
    88  		cookies := fields["set-cookie"]
    89  		Expect(cookies).To(ContainElement(cookie1))
    90  		Expect(cookies).To(ContainElement(cookie2))
    91  	})
    92  
    93  	It("writes data", func() {
    94  		n, err := rw.Write([]byte("foobar"))
    95  		Expect(n).To(Equal(6))
    96  		Expect(err).ToNot(HaveOccurred())
    97  		// Should have written 200 on the header stream
    98  		fields := decodeHeader(strBuf)
    99  		Expect(fields).To(HaveKeyWithValue(":status", []string{"200"}))
   100  		// And foobar on the data stream
   101  		Expect(getData(strBuf)).To(Equal([]byte("foobar")))
   102  	})
   103  
   104  	It("writes data after WriteHeader is called", func() {
   105  		rw.WriteHeader(http.StatusTeapot)
   106  		n, err := rw.Write([]byte("foobar"))
   107  		Expect(n).To(Equal(6))
   108  		Expect(err).ToNot(HaveOccurred())
   109  		// Should have written 418 on the header stream
   110  		fields := decodeHeader(strBuf)
   111  		Expect(fields).To(HaveKeyWithValue(":status", []string{"418"}))
   112  		// And foobar on the data stream
   113  		Expect(getData(strBuf)).To(Equal([]byte("foobar")))
   114  	})
   115  
   116  	It("does not WriteHeader() twice", func() {
   117  		rw.WriteHeader(http.StatusOK)
   118  		rw.WriteHeader(http.StatusInternalServerError)
   119  		fields := decodeHeader(strBuf)
   120  		Expect(fields).To(HaveLen(2))
   121  		Expect(fields).To(HaveKeyWithValue(":status", []string{"200"}))
   122  		Expect(fields).To(HaveKey("date"))
   123  	})
   124  
   125  	It("allows calling WriteHeader() several times when using the 103 status code", func() {
   126  		rw.Header().Add("Link", "</style.css>; rel=preload; as=style")
   127  		rw.Header().Add("Link", "</script.js>; rel=preload; as=script")
   128  		rw.WriteHeader(http.StatusEarlyHints)
   129  
   130  		n, err := rw.Write([]byte("foobar"))
   131  		Expect(n).To(Equal(6))
   132  		Expect(err).ToNot(HaveOccurred())
   133  
   134  		// Early Hints must have been received
   135  		fields := decodeHeader(strBuf)
   136  		Expect(fields).To(HaveLen(2))
   137  		Expect(fields).To(HaveKeyWithValue(":status", []string{"103"}))
   138  		Expect(fields).To(HaveKeyWithValue("link", []string{"</style.css>; rel=preload; as=style", "</script.js>; rel=preload; as=script"}))
   139  
   140  		// According to the spec, headers sent in the informational response must also be included in the final response
   141  		fields = decodeHeader(strBuf)
   142  		Expect(fields).To(HaveLen(3))
   143  		Expect(fields).To(HaveKeyWithValue(":status", []string{"200"}))
   144  		Expect(fields).To(HaveKey("date"))
   145  		Expect(fields).To(HaveKeyWithValue("link", []string{"</style.css>; rel=preload; as=style", "</script.js>; rel=preload; as=script"}))
   146  
   147  		Expect(getData(strBuf)).To(Equal([]byte("foobar")))
   148  	})
   149  
   150  	It("doesn't allow writes if the status code doesn't allow a body", func() {
   151  		rw.WriteHeader(304)
   152  		n, err := rw.Write([]byte("foobar"))
   153  		Expect(n).To(BeZero())
   154  		Expect(err).To(MatchError(http.ErrBodyNotAllowed))
   155  	})
   156  
   157  	It("first call to Write sniffs if Content-Type is not set", func() {
   158  		n, err := rw.Write([]byte("<html></html>"))
   159  		Expect(n).To(Equal(13))
   160  		Expect(err).ToNot(HaveOccurred())
   161  
   162  		fields := decodeHeader(strBuf)
   163  		Expect(fields).To(HaveKeyWithValue("content-type", []string{"text/html; charset=utf-8"}))
   164  	})
   165  
   166  	It(`is compatible with "net/http".ResponseController`, func() {
   167  		Expect(rw.SetReadDeadline(time.Now().Add(1 * time.Second))).To(BeNil())
   168  		Expect(rw.SetWriteDeadline(time.Now().Add(1 * time.Second))).To(BeNil())
   169  	})
   170  
   171  	It(`checks Content-Length header`, func() {
   172  		rw.Header().Set("Content-Length", "6")
   173  		n, err := rw.Write([]byte("foobar"))
   174  		Expect(n).To(Equal(6))
   175  		Expect(err).To(BeNil())
   176  
   177  		n, err = rw.Write([]byte("foobar"))
   178  		Expect(n).To(Equal(0))
   179  		Expect(err).To(Equal(http.ErrContentLength))
   180  	})
   181  
   182  	It(`panics when writing invalid status`, func() {
   183  		Expect(func() { rw.WriteHeader(99) }).To(Panic())
   184  		Expect(func() { rw.WriteHeader(1000) }).To(Panic())
   185  	})
   186  })