github.com/ooni/oohttp@v0.7.2/responsecontroller_test.go (about)

     1  package http_test
     2  
     3  import (
     4  	"errors"
     5  	"fmt"
     6  	"io"
     7  	"os"
     8  	"sync"
     9  	"testing"
    10  	"time"
    11  
    12  	. "github.com/ooni/oohttp"
    13  )
    14  
    15  func TestResponseControllerFlush(t *testing.T) { run(t, testResponseControllerFlush) }
    16  func testResponseControllerFlush(t *testing.T, mode testMode) {
    17  	continuec := make(chan struct{})
    18  	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
    19  		ctl := NewResponseController(w)
    20  		w.Write([]byte("one"))
    21  		if err := ctl.Flush(); err != nil {
    22  			t.Errorf("ctl.Flush() = %v, want nil", err)
    23  			return
    24  		}
    25  		<-continuec
    26  		w.Write([]byte("two"))
    27  	}))
    28  
    29  	res, err := cst.c.Get(cst.ts.URL)
    30  	if err != nil {
    31  		t.Fatalf("unexpected connection error: %v", err)
    32  	}
    33  	defer res.Body.Close()
    34  
    35  	buf := make([]byte, 16)
    36  	n, err := res.Body.Read(buf)
    37  	close(continuec)
    38  	if err != nil || string(buf[:n]) != "one" {
    39  		t.Fatalf("Body.Read = %q, %v, want %q, nil", string(buf[:n]), err, "one")
    40  	}
    41  
    42  	got, err := io.ReadAll(res.Body)
    43  	if err != nil || string(got) != "two" {
    44  		t.Fatalf("Body.Read = %q, %v, want %q, nil", string(got), err, "two")
    45  	}
    46  }
    47  
    48  func TestResponseControllerHijack(t *testing.T) { run(t, testResponseControllerHijack) }
    49  func testResponseControllerHijack(t *testing.T, mode testMode) {
    50  	const header = "X-Header"
    51  	const value = "set"
    52  	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
    53  		ctl := NewResponseController(w)
    54  		c, _, err := ctl.Hijack()
    55  		if mode == http2Mode {
    56  			if err == nil {
    57  				t.Errorf("ctl.Hijack = nil, want error")
    58  			}
    59  			w.Header().Set(header, value)
    60  			return
    61  		}
    62  		if err != nil {
    63  			t.Errorf("ctl.Hijack = _, _, %v, want _, _, nil", err)
    64  			return
    65  		}
    66  		fmt.Fprintf(c, "HTTP/1.0 200 OK\r\n%v: %v\r\nContent-Length: 0\r\n\r\n", header, value)
    67  	}))
    68  	res, err := cst.c.Get(cst.ts.URL)
    69  	if err != nil {
    70  		t.Fatal(err)
    71  	}
    72  	if got, want := res.Header.Get(header), value; got != want {
    73  		t.Errorf("response header %q = %q, want %q", header, got, want)
    74  	}
    75  }
    76  
    77  func TestResponseControllerSetPastWriteDeadline(t *testing.T) {
    78  	run(t, testResponseControllerSetPastWriteDeadline)
    79  }
    80  func testResponseControllerSetPastWriteDeadline(t *testing.T, mode testMode) {
    81  	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
    82  		ctl := NewResponseController(w)
    83  		w.Write([]byte("one"))
    84  		if err := ctl.Flush(); err != nil {
    85  			t.Errorf("before setting deadline: ctl.Flush() = %v, want nil", err)
    86  		}
    87  		if err := ctl.SetWriteDeadline(time.Now().Add(-10 * time.Second)); err != nil {
    88  			t.Errorf("ctl.SetWriteDeadline() = %v, want nil", err)
    89  		}
    90  
    91  		w.Write([]byte("two"))
    92  		if err := ctl.Flush(); err == nil {
    93  			t.Errorf("after setting deadline: ctl.Flush() = nil, want non-nil")
    94  		}
    95  		// Connection errors are sticky, so resetting the deadline does not permit
    96  		// making more progress. We might want to change this in the future, but verify
    97  		// the current behavior for now. If we do change this, we'll want to make sure
    98  		// to do so only for writing the response body, not headers.
    99  		if err := ctl.SetWriteDeadline(time.Now().Add(1 * time.Hour)); err != nil {
   100  			t.Errorf("ctl.SetWriteDeadline() = %v, want nil", err)
   101  		}
   102  		w.Write([]byte("three"))
   103  		if err := ctl.Flush(); err == nil {
   104  			t.Errorf("after resetting deadline: ctl.Flush() = nil, want non-nil")
   105  		}
   106  	}))
   107  
   108  	res, err := cst.c.Get(cst.ts.URL)
   109  	if err != nil {
   110  		t.Fatalf("unexpected connection error: %v", err)
   111  	}
   112  	defer res.Body.Close()
   113  	b, _ := io.ReadAll(res.Body)
   114  	if string(b) != "one" {
   115  		t.Errorf("unexpected body: %q", string(b))
   116  	}
   117  }
   118  
   119  func TestResponseControllerSetFutureWriteDeadline(t *testing.T) {
   120  	run(t, testResponseControllerSetFutureWriteDeadline)
   121  }
   122  func testResponseControllerSetFutureWriteDeadline(t *testing.T, mode testMode) {
   123  	errc := make(chan error, 1)
   124  	startwritec := make(chan struct{})
   125  	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
   126  		ctl := NewResponseController(w)
   127  		w.WriteHeader(200)
   128  		if err := ctl.Flush(); err != nil {
   129  			t.Errorf("ctl.Flush() = %v, want nil", err)
   130  		}
   131  		<-startwritec // don't set the deadline until the client reads response headers
   132  		if err := ctl.SetWriteDeadline(time.Now().Add(1 * time.Millisecond)); err != nil {
   133  			t.Errorf("ctl.SetWriteDeadline() = %v, want nil", err)
   134  		}
   135  		_, err := io.Copy(w, neverEnding('a'))
   136  		errc <- err
   137  	}))
   138  
   139  	res, err := cst.c.Get(cst.ts.URL)
   140  	close(startwritec)
   141  	if err != nil {
   142  		t.Fatalf("unexpected connection error: %v", err)
   143  	}
   144  	defer res.Body.Close()
   145  	_, err = io.Copy(io.Discard, res.Body)
   146  	if err == nil {
   147  		t.Errorf("client reading from truncated request body: got nil error, want non-nil")
   148  	}
   149  	err = <-errc // io.Copy error
   150  	if !errors.Is(err, os.ErrDeadlineExceeded) {
   151  		t.Errorf("server timed out writing request body: got err %v; want os.ErrDeadlineExceeded", err)
   152  	}
   153  }
   154  
   155  func TestResponseControllerSetPastReadDeadline(t *testing.T) {
   156  	run(t, testResponseControllerSetPastReadDeadline)
   157  }
   158  func testResponseControllerSetPastReadDeadline(t *testing.T, mode testMode) {
   159  	readc := make(chan struct{})
   160  	donec := make(chan struct{})
   161  	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
   162  		defer close(donec)
   163  		ctl := NewResponseController(w)
   164  		b := make([]byte, 3)
   165  		n, err := io.ReadFull(r.Body, b)
   166  		b = b[:n]
   167  		if err != nil || string(b) != "one" {
   168  			t.Errorf("before setting read deadline: Read = %v, %q, want nil, %q", err, string(b), "one")
   169  			return
   170  		}
   171  		if err := ctl.SetReadDeadline(time.Now()); err != nil {
   172  			t.Errorf("ctl.SetReadDeadline() = %v, want nil", err)
   173  			return
   174  		}
   175  		b, err = io.ReadAll(r.Body)
   176  		if err == nil || string(b) != "" {
   177  			t.Errorf("after setting read deadline: Read = %q, nil, want error", string(b))
   178  		}
   179  		close(readc)
   180  		// Connection errors are sticky, so resetting the deadline does not permit
   181  		// making more progress. We might want to change this in the future, but verify
   182  		// the current behavior for now.
   183  		if err := ctl.SetReadDeadline(time.Time{}); err != nil {
   184  			t.Errorf("ctl.SetReadDeadline() = %v, want nil", err)
   185  			return
   186  		}
   187  		b, err = io.ReadAll(r.Body)
   188  		if err == nil {
   189  			t.Errorf("after resetting read deadline: Read = %q, nil, want error", string(b))
   190  		}
   191  	}))
   192  
   193  	pr, pw := io.Pipe()
   194  	var wg sync.WaitGroup
   195  	wg.Add(1)
   196  	go func() {
   197  		defer wg.Done()
   198  		defer pw.Close()
   199  		pw.Write([]byte("one"))
   200  		select {
   201  		case <-readc:
   202  		case <-donec:
   203  			select {
   204  			case <-readc:
   205  			default:
   206  				t.Errorf("server handler unexpectedly exited without closing readc")
   207  				return
   208  			}
   209  		}
   210  		pw.Write([]byte("two"))
   211  	}()
   212  	defer wg.Wait()
   213  	res, err := cst.c.Post(cst.ts.URL, "text/foo", pr)
   214  	if err == nil {
   215  		defer res.Body.Close()
   216  	}
   217  }
   218  
   219  func TestResponseControllerSetFutureReadDeadline(t *testing.T) {
   220  	run(t, testResponseControllerSetFutureReadDeadline)
   221  }
   222  func testResponseControllerSetFutureReadDeadline(t *testing.T, mode testMode) {
   223  	respBody := "response body"
   224  	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, req *Request) {
   225  		ctl := NewResponseController(w)
   226  		if err := ctl.SetReadDeadline(time.Now().Add(1 * time.Millisecond)); err != nil {
   227  			t.Errorf("ctl.SetReadDeadline() = %v, want nil", err)
   228  		}
   229  		_, err := io.Copy(io.Discard, req.Body)
   230  		if !errors.Is(err, os.ErrDeadlineExceeded) {
   231  			t.Errorf("server timed out reading request body: got err %v; want os.ErrDeadlineExceeded", err)
   232  		}
   233  		w.Write([]byte(respBody))
   234  	}))
   235  	pr, pw := io.Pipe()
   236  	res, err := cst.c.Post(cst.ts.URL, "text/apocryphal", pr)
   237  	if err != nil {
   238  		t.Fatal(err)
   239  	}
   240  	defer res.Body.Close()
   241  	got, err := io.ReadAll(res.Body)
   242  	if string(got) != respBody || err != nil {
   243  		t.Errorf("client read response body: %q, %v; want %q, nil", string(got), err, respBody)
   244  	}
   245  	pw.Close()
   246  }
   247  
   248  type wrapWriter struct {
   249  	ResponseWriter
   250  }
   251  
   252  func (w wrapWriter) Unwrap() ResponseWriter {
   253  	return w.ResponseWriter
   254  }
   255  
   256  func TestWrappedResponseController(t *testing.T) { run(t, testWrappedResponseController) }
   257  func testWrappedResponseController(t *testing.T, mode testMode) {
   258  	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
   259  		w = wrapWriter{w}
   260  		ctl := NewResponseController(w)
   261  		if err := ctl.Flush(); err != nil {
   262  			t.Errorf("ctl.Flush() = %v, want nil", err)
   263  		}
   264  		if err := ctl.SetReadDeadline(time.Time{}); err != nil {
   265  			t.Errorf("ctl.SetReadDeadline() = %v, want nil", err)
   266  		}
   267  		if err := ctl.SetWriteDeadline(time.Time{}); err != nil {
   268  			t.Errorf("ctl.SetWriteDeadline() = %v, want nil", err)
   269  		}
   270  	}))
   271  	res, err := cst.c.Get(cst.ts.URL)
   272  	if err != nil {
   273  		t.Fatalf("unexpected connection error: %v", err)
   274  	}
   275  	io.Copy(io.Discard, res.Body)
   276  	defer res.Body.Close()
   277  }
   278  
   279  func TestResponseControllerEnableFullDuplex(t *testing.T) {
   280  	run(t, testResponseControllerEnableFullDuplex)
   281  }
   282  func testResponseControllerEnableFullDuplex(t *testing.T, mode testMode) {
   283  	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, req *Request) {
   284  		ctl := NewResponseController(w)
   285  		if err := ctl.EnableFullDuplex(); err != nil {
   286  			// TODO: Drop test for HTTP/2 when x/net is updated to support
   287  			// EnableFullDuplex. Since HTTP/2 supports full duplex by default,
   288  			// the rest of the test is fine; it's just the EnableFullDuplex call
   289  			// that fails.
   290  			if mode != http2Mode {
   291  				t.Errorf("ctl.EnableFullDuplex() = %v, want nil", err)
   292  			}
   293  		}
   294  		w.WriteHeader(200)
   295  		ctl.Flush()
   296  		for {
   297  			var buf [1]byte
   298  			n, err := req.Body.Read(buf[:])
   299  			if n != 1 || err != nil {
   300  				break
   301  			}
   302  			w.Write(buf[:])
   303  			ctl.Flush()
   304  		}
   305  	}))
   306  	pr, pw := io.Pipe()
   307  	res, err := cst.c.Post(cst.ts.URL, "text/apocryphal", pr)
   308  	if err != nil {
   309  		t.Fatal(err)
   310  	}
   311  	defer res.Body.Close()
   312  	for i := byte(0); i < 10; i++ {
   313  		if _, err := pw.Write([]byte{i}); err != nil {
   314  			t.Fatalf("Write: %v", err)
   315  		}
   316  		var buf [1]byte
   317  		if n, err := res.Body.Read(buf[:]); n != 1 || err != nil {
   318  			t.Fatalf("Read: %v, %v", n, err)
   319  		}
   320  		if buf[0] != i {
   321  			t.Fatalf("read byte %v, want %v", buf[0], i)
   322  		}
   323  	}
   324  	pw.Close()
   325  }