golang.org/x/net@v0.25.1-0.20240516223405-c87a5b62e243/http2/h2c/h2c_test.go (about)

     1  // Copyright 2018 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package h2c
     6  
     7  import (
     8  	"context"
     9  	"crypto/tls"
    10  	"fmt"
    11  	"io"
    12  	"io/ioutil"
    13  	"log"
    14  	"net"
    15  	"net/http"
    16  	"net/http/httptest"
    17  	"strings"
    18  	"testing"
    19  
    20  	"golang.org/x/net/http2"
    21  )
    22  
    23  func ExampleNewHandler() {
    24  	handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
    25  		fmt.Fprint(w, "Hello world")
    26  	})
    27  	h2s := &http2.Server{
    28  		// ...
    29  	}
    30  	h1s := &http.Server{
    31  		Addr:    ":8080",
    32  		Handler: NewHandler(handler, h2s),
    33  	}
    34  	log.Fatal(h1s.ListenAndServe())
    35  }
    36  
    37  func TestContext(t *testing.T) {
    38  	baseCtx := context.WithValue(context.Background(), "testkey", "testvalue")
    39  
    40  	handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
    41  		if r.ProtoMajor != 2 {
    42  			t.Errorf("Request wasn't handled by h2c.  Got ProtoMajor=%v", r.ProtoMajor)
    43  		}
    44  		if r.Context().Value("testkey") != "testvalue" {
    45  			t.Errorf("Request doesn't have expected base context: %v", r.Context())
    46  		}
    47  		fmt.Fprint(w, "Hello world")
    48  	})
    49  
    50  	h2s := &http2.Server{}
    51  	h1s := httptest.NewUnstartedServer(NewHandler(handler, h2s))
    52  	h1s.Config.BaseContext = func(_ net.Listener) context.Context {
    53  		return baseCtx
    54  	}
    55  	h1s.Start()
    56  	defer h1s.Close()
    57  
    58  	client := &http.Client{
    59  		Transport: &http2.Transport{
    60  			AllowHTTP: true,
    61  			DialTLS: func(network, addr string, _ *tls.Config) (net.Conn, error) {
    62  				return net.Dial(network, addr)
    63  			},
    64  		},
    65  	}
    66  
    67  	resp, err := client.Get(h1s.URL)
    68  	if err != nil {
    69  		t.Fatal(err)
    70  	}
    71  	_, err = ioutil.ReadAll(resp.Body)
    72  	if err != nil {
    73  		t.Fatal(err)
    74  	}
    75  	if err := resp.Body.Close(); err != nil {
    76  		t.Fatal(err)
    77  	}
    78  }
    79  
    80  func TestPropagation(t *testing.T) {
    81  	var (
    82  		server *http.Server
    83  		// double the limit because http2 will compress header
    84  		headerSize  = 1 << 11
    85  		headerLimit = 1 << 10
    86  	)
    87  
    88  	handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
    89  		if r.ProtoMajor != 2 {
    90  			t.Errorf("Request wasn't handled by h2c.  Got ProtoMajor=%v", r.ProtoMajor)
    91  		}
    92  		if r.Context().Value(http.ServerContextKey).(*http.Server) != server {
    93  			t.Errorf("Request doesn't have expected http server: %v", r.Context())
    94  		}
    95  		if len(r.Header.Get("Long-Header")) != headerSize {
    96  			t.Errorf("Request doesn't have expected http header length: %v", len(r.Header.Get("Long-Header")))
    97  		}
    98  		fmt.Fprint(w, "Hello world")
    99  	})
   100  
   101  	h2s := &http2.Server{}
   102  	h1s := httptest.NewUnstartedServer(NewHandler(handler, h2s))
   103  
   104  	server = h1s.Config
   105  	server.MaxHeaderBytes = headerLimit
   106  	server.ConnState = func(conn net.Conn, state http.ConnState) {
   107  		t.Logf("server conn state: conn %s -> %s, status changed to %s", conn.RemoteAddr(), conn.LocalAddr(), state)
   108  	}
   109  
   110  	h1s.Start()
   111  	defer h1s.Close()
   112  
   113  	client := &http.Client{
   114  		Transport: &http2.Transport{
   115  			AllowHTTP: true,
   116  			DialTLS: func(network, addr string, _ *tls.Config) (net.Conn, error) {
   117  				conn, err := net.Dial(network, addr)
   118  				if conn != nil {
   119  					t.Logf("client dial tls: %s -> %s", conn.RemoteAddr(), conn.LocalAddr())
   120  				}
   121  				return conn, err
   122  			},
   123  		},
   124  	}
   125  
   126  	req, err := http.NewRequest("GET", h1s.URL, nil)
   127  	if err != nil {
   128  		t.Fatal(err)
   129  	}
   130  
   131  	req.Header.Set("Long-Header", strings.Repeat("A", headerSize))
   132  
   133  	_, err = client.Do(req)
   134  	if err == nil {
   135  		t.Fatal("expected server err, got nil")
   136  	}
   137  }
   138  
   139  func TestMaxBytesHandler(t *testing.T) {
   140  	const bodyLimit = 10
   141  	handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   142  		t.Errorf("got request, expected to be blocked by body limit")
   143  	})
   144  
   145  	h2s := &http2.Server{}
   146  	h1s := httptest.NewUnstartedServer(http.MaxBytesHandler(NewHandler(handler, h2s), bodyLimit))
   147  	h1s.Start()
   148  	defer h1s.Close()
   149  
   150  	// Wrap the body in a struct{io.Reader} to prevent it being rewound and resent.
   151  	body := "0123456789abcdef"
   152  	req, err := http.NewRequest("POST", h1s.URL, struct{ io.Reader }{strings.NewReader(body)})
   153  	if err != nil {
   154  		t.Fatal(err)
   155  	}
   156  	req.Header.Set("Http2-Settings", "")
   157  	req.Header.Set("Upgrade", "h2c")
   158  	req.Header.Set("Connection", "Upgrade, HTTP2-Settings")
   159  
   160  	resp, err := h1s.Client().Do(req)
   161  	if err != nil {
   162  		t.Fatal(err)
   163  	}
   164  	defer resp.Body.Close()
   165  	_, err = ioutil.ReadAll(resp.Body)
   166  	if err != nil {
   167  		t.Fatal(err)
   168  	}
   169  	if got, want := resp.StatusCode, http.StatusInternalServerError; got != want {
   170  		t.Errorf("resp.StatusCode = %v, want %v", got, want)
   171  	}
   172  }