github.com/twelsh-aw/go/src@v0.0.0-20230516233729-a56fe86a7c81/net/http/responsecontroller_test.go (about)

     1  package http_test
     2  
     3  import (
     4  	"errors"
     5  	"fmt"
     6  	"io"
     7  	. "net/http"
     8  	"os"
     9  	"sync"
    10  	"testing"
    11  	"time"
    12  )
    13  
    14  func TestResponseControllerFlush(t *testing.T) { run(t, testResponseControllerFlush) }
    15  func testResponseControllerFlush(t *testing.T, mode testMode) {
    16  	continuec := make(chan struct{})
    17  	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
    18  		ctl := NewResponseController(w)
    19  		w.Write([]byte("one"))
    20  		if err := ctl.Flush(); err != nil {
    21  			t.Errorf("ctl.Flush() = %v, want nil", err)
    22  			return
    23  		}
    24  		<-continuec
    25  		w.Write([]byte("two"))
    26  	}))
    27  
    28  	res, err := cst.c.Get(cst.ts.URL)
    29  	if err != nil {
    30  		t.Fatalf("unexpected connection error: %v", err)
    31  	}
    32  	defer res.Body.Close()
    33  
    34  	buf := make([]byte, 16)
    35  	n, err := res.Body.Read(buf)
    36  	close(continuec)
    37  	if err != nil || string(buf[:n]) != "one" {
    38  		t.Fatalf("Body.Read = %q, %v, want %q, nil", string(buf[:n]), err, "one")
    39  	}
    40  
    41  	got, err := io.ReadAll(res.Body)
    42  	if err != nil || string(got) != "two" {
    43  		t.Fatalf("Body.Read = %q, %v, want %q, nil", string(got), err, "two")
    44  	}
    45  }
    46  
    47  func TestResponseControllerHijack(t *testing.T) { run(t, testResponseControllerHijack) }
    48  func testResponseControllerHijack(t *testing.T, mode testMode) {
    49  	const header = "X-Header"
    50  	const value = "set"
    51  	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
    52  		ctl := NewResponseController(w)
    53  		c, _, err := ctl.Hijack()
    54  		if mode == http2Mode {
    55  			if err == nil {
    56  				t.Errorf("ctl.Hijack = nil, want error")
    57  			}
    58  			w.Header().Set(header, value)
    59  			return
    60  		}
    61  		if err != nil {
    62  			t.Errorf("ctl.Hijack = _, _, %v, want _, _, nil", err)
    63  			return
    64  		}
    65  		fmt.Fprintf(c, "HTTP/1.0 200 OK\r\n%v: %v\r\nContent-Length: 0\r\n\r\n", header, value)
    66  	}))
    67  	res, err := cst.c.Get(cst.ts.URL)
    68  	if err != nil {
    69  		t.Fatal(err)
    70  	}
    71  	if got, want := res.Header.Get(header), value; got != want {
    72  		t.Errorf("response header %q = %q, want %q", header, got, want)
    73  	}
    74  }
    75  
    76  func TestResponseControllerSetPastWriteDeadline(t *testing.T) {
    77  	run(t, testResponseControllerSetPastWriteDeadline)
    78  }
    79  func testResponseControllerSetPastWriteDeadline(t *testing.T, mode testMode) {
    80  	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
    81  		ctl := NewResponseController(w)
    82  		w.Write([]byte("one"))
    83  		if err := ctl.Flush(); err != nil {
    84  			t.Errorf("before setting deadline: ctl.Flush() = %v, want nil", err)
    85  		}
    86  		if err := ctl.SetWriteDeadline(time.Now().Add(-10 * time.Second)); err != nil {
    87  			t.Errorf("ctl.SetWriteDeadline() = %v, want nil", err)
    88  		}
    89  
    90  		w.Write([]byte("two"))
    91  		if err := ctl.Flush(); err == nil {
    92  			t.Errorf("after setting deadline: ctl.Flush() = nil, want non-nil")
    93  		}
    94  		// Connection errors are sticky, so resetting the deadline does not permit
    95  		// making more progress. We might want to change this in the future, but verify
    96  		// the current behavior for now. If we do change this, we'll want to make sure
    97  		// to do so only for writing the response body, not headers.
    98  		if err := ctl.SetWriteDeadline(time.Now().Add(1 * time.Hour)); err != nil {
    99  			t.Errorf("ctl.SetWriteDeadline() = %v, want nil", err)
   100  		}
   101  		w.Write([]byte("three"))
   102  		if err := ctl.Flush(); err == nil {
   103  			t.Errorf("after resetting deadline: ctl.Flush() = nil, want non-nil")
   104  		}
   105  	}))
   106  
   107  	res, err := cst.c.Get(cst.ts.URL)
   108  	if err != nil {
   109  		t.Fatalf("unexpected connection error: %v", err)
   110  	}
   111  	defer res.Body.Close()
   112  	b, _ := io.ReadAll(res.Body)
   113  	if string(b) != "one" {
   114  		t.Errorf("unexpected body: %q", string(b))
   115  	}
   116  }
   117  
   118  func TestResponseControllerSetFutureWriteDeadline(t *testing.T) {
   119  	run(t, testResponseControllerSetFutureWriteDeadline)
   120  }
   121  func testResponseControllerSetFutureWriteDeadline(t *testing.T, mode testMode) {
   122  	errc := make(chan error, 1)
   123  	startwritec := make(chan struct{})
   124  	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
   125  		ctl := NewResponseController(w)
   126  		w.WriteHeader(200)
   127  		if err := ctl.Flush(); err != nil {
   128  			t.Errorf("ctl.Flush() = %v, want nil", err)
   129  		}
   130  		<-startwritec // don't set the deadline until the client reads response headers
   131  		if err := ctl.SetWriteDeadline(time.Now().Add(1 * time.Millisecond)); err != nil {
   132  			t.Errorf("ctl.SetWriteDeadline() = %v, want nil", err)
   133  		}
   134  		_, err := io.Copy(w, neverEnding('a'))
   135  		errc <- err
   136  	}))
   137  
   138  	res, err := cst.c.Get(cst.ts.URL)
   139  	close(startwritec)
   140  	if err != nil {
   141  		t.Fatalf("unexpected connection error: %v", err)
   142  	}
   143  	defer res.Body.Close()
   144  	_, err = io.Copy(io.Discard, res.Body)
   145  	if err == nil {
   146  		t.Errorf("client reading from truncated request body: got nil error, want non-nil")
   147  	}
   148  	err = <-errc // io.Copy error
   149  	if !errors.Is(err, os.ErrDeadlineExceeded) {
   150  		t.Errorf("server timed out writing request body: got err %v; want os.ErrDeadlineExceeded", err)
   151  	}
   152  }
   153  
   154  func TestResponseControllerSetPastReadDeadline(t *testing.T) {
   155  	run(t, testResponseControllerSetPastReadDeadline)
   156  }
   157  func testResponseControllerSetPastReadDeadline(t *testing.T, mode testMode) {
   158  	readc := make(chan struct{})
   159  	donec := make(chan struct{})
   160  	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
   161  		defer close(donec)
   162  		ctl := NewResponseController(w)
   163  		b := make([]byte, 3)
   164  		n, err := io.ReadFull(r.Body, b)
   165  		b = b[:n]
   166  		if err != nil || string(b) != "one" {
   167  			t.Errorf("before setting read deadline: Read = %v, %q, want nil, %q", err, string(b), "one")
   168  			return
   169  		}
   170  		if err := ctl.SetReadDeadline(time.Now()); err != nil {
   171  			t.Errorf("ctl.SetReadDeadline() = %v, want nil", err)
   172  			return
   173  		}
   174  		b, err = io.ReadAll(r.Body)
   175  		if err == nil || string(b) != "" {
   176  			t.Errorf("after setting read deadline: Read = %q, nil, want error", string(b))
   177  		}
   178  		close(readc)
   179  		// Connection errors are sticky, so resetting the deadline does not permit
   180  		// making more progress. We might want to change this in the future, but verify
   181  		// the current behavior for now.
   182  		if err := ctl.SetReadDeadline(time.Time{}); err != nil {
   183  			t.Errorf("ctl.SetReadDeadline() = %v, want nil", err)
   184  			return
   185  		}
   186  		b, err = io.ReadAll(r.Body)
   187  		if err == nil {
   188  			t.Errorf("after resetting read deadline: Read = %q, nil, want error", string(b))
   189  		}
   190  	}))
   191  
   192  	pr, pw := io.Pipe()
   193  	var wg sync.WaitGroup
   194  	wg.Add(1)
   195  	go func() {
   196  		defer wg.Done()
   197  		defer pw.Close()
   198  		pw.Write([]byte("one"))
   199  		select {
   200  		case <-readc:
   201  		case <-donec:
   202  			select {
   203  			case <-readc:
   204  			default:
   205  				t.Errorf("server handler unexpectedly exited without closing readc")
   206  				return
   207  			}
   208  		}
   209  		pw.Write([]byte("two"))
   210  	}()
   211  	defer wg.Wait()
   212  	res, err := cst.c.Post(cst.ts.URL, "text/foo", pr)
   213  	if err == nil {
   214  		defer res.Body.Close()
   215  	}
   216  }
   217  
   218  func TestResponseControllerSetFutureReadDeadline(t *testing.T) {
   219  	run(t, testResponseControllerSetFutureReadDeadline)
   220  }
   221  func testResponseControllerSetFutureReadDeadline(t *testing.T, mode testMode) {
   222  	respBody := "response body"
   223  	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, req *Request) {
   224  		ctl := NewResponseController(w)
   225  		if err := ctl.SetReadDeadline(time.Now().Add(1 * time.Millisecond)); err != nil {
   226  			t.Errorf("ctl.SetReadDeadline() = %v, want nil", err)
   227  		}
   228  		_, err := io.Copy(io.Discard, req.Body)
   229  		if !errors.Is(err, os.ErrDeadlineExceeded) {
   230  			t.Errorf("server timed out reading request body: got err %v; want os.ErrDeadlineExceeded", err)
   231  		}
   232  		w.Write([]byte(respBody))
   233  	}))
   234  	pr, pw := io.Pipe()
   235  	res, err := cst.c.Post(cst.ts.URL, "text/apocryphal", pr)
   236  	if err != nil {
   237  		t.Fatal(err)
   238  	}
   239  	defer res.Body.Close()
   240  	got, err := io.ReadAll(res.Body)
   241  	if string(got) != respBody || err != nil {
   242  		t.Errorf("client read response body: %q, %v; want %q, nil", string(got), err, respBody)
   243  	}
   244  	pw.Close()
   245  }
   246  
   247  type wrapWriter struct {
   248  	ResponseWriter
   249  }
   250  
   251  func (w wrapWriter) Unwrap() ResponseWriter {
   252  	return w.ResponseWriter
   253  }
   254  
   255  func TestWrappedResponseController(t *testing.T) { run(t, testWrappedResponseController) }
   256  func testWrappedResponseController(t *testing.T, mode testMode) {
   257  	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
   258  		w = wrapWriter{w}
   259  		ctl := NewResponseController(w)
   260  		if err := ctl.Flush(); err != nil {
   261  			t.Errorf("ctl.Flush() = %v, want nil", err)
   262  		}
   263  		if err := ctl.SetReadDeadline(time.Time{}); err != nil {
   264  			t.Errorf("ctl.SetReadDeadline() = %v, want nil", err)
   265  		}
   266  		if err := ctl.SetWriteDeadline(time.Time{}); err != nil {
   267  			t.Errorf("ctl.SetWriteDeadline() = %v, want nil", err)
   268  		}
   269  	}))
   270  	res, err := cst.c.Get(cst.ts.URL)
   271  	if err != nil {
   272  		t.Fatalf("unexpected connection error: %v", err)
   273  	}
   274  	io.Copy(io.Discard, res.Body)
   275  	defer res.Body.Close()
   276  }
   277  
   278  func TestResponseControllerEnableFullDuplex(t *testing.T) {
   279  	run(t, testResponseControllerEnableFullDuplex)
   280  }
   281  func testResponseControllerEnableFullDuplex(t *testing.T, mode testMode) {
   282  	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, req *Request) {
   283  		ctl := NewResponseController(w)
   284  		if err := ctl.EnableFullDuplex(); err != nil {
   285  			// TODO: Drop test for HTTP/2 when x/net is updated to support
   286  			// EnableFullDuplex. Since HTTP/2 supports full duplex by default,
   287  			// the rest of the test is fine; it's just the EnableFullDuplex call
   288  			// that fails.
   289  			if mode != http2Mode {
   290  				t.Errorf("ctl.EnableFullDuplex() = %v, want nil", err)
   291  			}
   292  		}
   293  		w.WriteHeader(200)
   294  		ctl.Flush()
   295  		for {
   296  			var buf [1]byte
   297  			n, err := req.Body.Read(buf[:])
   298  			if n != 1 || err != nil {
   299  				break
   300  			}
   301  			w.Write(buf[:])
   302  			ctl.Flush()
   303  		}
   304  	}))
   305  	pr, pw := io.Pipe()
   306  	res, err := cst.c.Post(cst.ts.URL, "text/apocryphal", pr)
   307  	if err != nil {
   308  		t.Fatal(err)
   309  	}
   310  	defer res.Body.Close()
   311  	for i := byte(0); i < 10; i++ {
   312  		if _, err := pw.Write([]byte{i}); err != nil {
   313  			t.Fatalf("Write: %v", err)
   314  		}
   315  		var buf [1]byte
   316  		if n, err := res.Body.Read(buf[:]); n != 1 || err != nil {
   317  			t.Fatalf("Read: %v, %v", n, err)
   318  		}
   319  		if buf[0] != i {
   320  			t.Fatalf("read byte %v, want %v", buf[0], i)
   321  		}
   322  	}
   323  	pw.Close()
   324  }