github.com/simonmittag/ws@v1.1.0-rc.5.0.20210419231947-82b846128245/server_test.go (about)

     1  package ws
     2  
     3  import (
     4  	"bufio"
     5  	"bytes"
     6  	"fmt"
     7  	"io"
     8  	"io/ioutil"
     9  	"math/rand"
    10  	"net"
    11  	"net/http"
    12  	"net/http/httptest"
    13  	"net/http/httputil"
    14  	"reflect"
    15  	"sort"
    16  	"strconv"
    17  	"strings"
    18  	"sync/atomic"
    19  	"testing"
    20  	_ "unsafe" // for go:linkname
    21  
    22  	"github.com/gobwas/httphead"
    23  )
    24  
    25  // TODO(gobwas): upgradeGenericCase with methods like configureUpgrader,
    26  // configureHTTPUpgrader.
    27  type upgradeCase struct {
    28  	label string
    29  
    30  	protocol  func(string) bool
    31  	negotiate func(httphead.Option) (httphead.Option, error)
    32  	onRequest func(u []byte) error
    33  	onHost    func(h []byte) error
    34  	onHeader  func(k, v []byte) error
    35  
    36  	nonce        []byte
    37  	removeSecKey bool
    38  	badSecKey    bool
    39  	secKeyHeader string
    40  
    41  	req *http.Request
    42  	res *http.Response
    43  	hs  Handshake
    44  	err error
    45  }
    46  
    47  var upgradeCases = []upgradeCase{
    48  	{
    49  		label: "base",
    50  		nonce: mustMakeNonce(),
    51  		req: mustMakeRequest("GET", "ws://example.org", http.Header{
    52  			headerUpgrade:    []string{"websocket"},
    53  			headerConnection: []string{"Upgrade"},
    54  			headerSecVersion: []string{"13"},
    55  		}),
    56  		res: mustMakeResponse(101, http.Header{
    57  			headerUpgrade:    []string{"websocket"},
    58  			headerConnection: []string{"Upgrade"},
    59  		}),
    60  	},
    61  	{
    62  		label:        "base_canonical",
    63  		nonce:        mustMakeNonce(),
    64  		secKeyHeader: headerSecKeyCanonical,
    65  		req: mustMakeRequest("GET", "ws://example.org", http.Header{
    66  			headerUpgrade:             []string{"websocket"},
    67  			headerConnection:          []string{"Upgrade"},
    68  			headerSecVersionCanonical: []string{"13"},
    69  		}),
    70  		res: mustMakeResponse(101, http.Header{
    71  			headerUpgrade:    []string{"websocket"},
    72  			headerConnection: []string{"Upgrade"},
    73  		}),
    74  	},
    75  	{
    76  		label:        "lowercase_headers",
    77  		nonce:        mustMakeNonce(),
    78  		secKeyHeader: strings.ToLower(headerSecKey),
    79  		req: mustMakeRequest("GET", "ws://example.org", http.Header{
    80  			strings.ToLower(headerUpgrade):    []string{"websocket"},
    81  			strings.ToLower(headerConnection): []string{"Upgrade"},
    82  			strings.ToLower(headerSecVersion): []string{"13"},
    83  		}),
    84  		res: mustMakeResponse(101, http.Header{
    85  			headerUpgrade:    []string{"websocket"},
    86  			headerConnection: []string{"Upgrade"},
    87  		}),
    88  	},
    89  	{
    90  		label:    "uppercase",
    91  		protocol: func(sub string) bool { return true },
    92  		nonce:    mustMakeNonce(),
    93  		req: mustMakeRequest("GET", "ws://example.org", http.Header{
    94  			headerUpgrade:    []string{"WEBSOCKET"},
    95  			headerConnection: []string{"UPGRADE"},
    96  			headerSecVersion: []string{"13"},
    97  		}),
    98  		res: mustMakeResponse(101, http.Header{
    99  			headerUpgrade:    []string{"websocket"},
   100  			headerConnection: []string{"Upgrade"},
   101  		}),
   102  	},
   103  	{
   104  		label:    "subproto",
   105  		protocol: SelectFromSlice([]string{"b", "d"}),
   106  		nonce:    mustMakeNonce(),
   107  		req: mustMakeRequest("GET", "ws://example.org", http.Header{
   108  			headerUpgrade:     []string{"websocket"},
   109  			headerConnection:  []string{"Upgrade"},
   110  			headerSecVersion:  []string{"13"},
   111  			headerSecProtocol: []string{"a", "b", "c", "d"},
   112  		}),
   113  		res: mustMakeResponse(101, http.Header{
   114  			headerUpgrade:     []string{"websocket"},
   115  			headerConnection:  []string{"Upgrade"},
   116  			headerSecProtocol: []string{"b"},
   117  		}),
   118  		hs: Handshake{Protocol: "b"},
   119  	},
   120  	{
   121  		label:        "subproto_lowercase_headers",
   122  		protocol:     SelectFromSlice([]string{"b", "d"}),
   123  		nonce:        mustMakeNonce(),
   124  		secKeyHeader: strings.ToLower(headerSecKey),
   125  		req: mustMakeRequest("GET", "ws://example.org", http.Header{
   126  			strings.ToLower(headerUpgrade):     []string{"websocket"},
   127  			strings.ToLower(headerConnection):  []string{"Upgrade"},
   128  			strings.ToLower(headerSecVersion):  []string{"13"},
   129  			strings.ToLower(headerSecProtocol): []string{"a", "b", "c", "d"},
   130  		}),
   131  		res: mustMakeResponse(101, http.Header{
   132  			headerUpgrade:     []string{"websocket"},
   133  			headerConnection:  []string{"Upgrade"},
   134  			headerSecProtocol: []string{"b"},
   135  		}),
   136  		hs: Handshake{Protocol: "b"},
   137  	},
   138  	{
   139  		label:    "subproto_comma",
   140  		protocol: SelectFromSlice([]string{"b", "d"}),
   141  		nonce:    mustMakeNonce(),
   142  		req: mustMakeRequest("GET", "ws://example.org", http.Header{
   143  			headerUpgrade:     []string{"websocket"},
   144  			headerConnection:  []string{"Upgrade"},
   145  			headerSecVersion:  []string{"13"},
   146  			headerSecProtocol: []string{"a, b, c, d"},
   147  		}),
   148  		res: mustMakeResponse(101, http.Header{
   149  			headerUpgrade:     []string{"websocket"},
   150  			headerConnection:  []string{"Upgrade"},
   151  			headerSecProtocol: []string{"b"},
   152  		}),
   153  		hs: Handshake{Protocol: "b"},
   154  	},
   155  	{
   156  		negotiate: func(opt httphead.Option) (ret httphead.Option, err error) {
   157  			switch string(opt.Name) {
   158  			case "b", "d":
   159  				return opt.Clone(), nil
   160  			default:
   161  				return ret, nil
   162  			}
   163  		},
   164  		nonce: mustMakeNonce(),
   165  		req: mustMakeRequest("GET", "ws://example.org", http.Header{
   166  			headerUpgrade:       []string{"websocket"},
   167  			headerConnection:    []string{"Upgrade"},
   168  			headerSecVersion:    []string{"13"},
   169  			headerSecExtensions: []string{"a;foo=1", "b;bar=2", "c", "d;baz=3"},
   170  		}),
   171  		res: mustMakeResponse(101, http.Header{
   172  			headerUpgrade:       []string{"websocket"},
   173  			headerConnection:    []string{"Upgrade"},
   174  			headerSecExtensions: []string{"b;bar=2,d;baz=3"},
   175  		}),
   176  		hs: Handshake{
   177  			Extensions: []httphead.Option{
   178  				httphead.NewOption("b", map[string]string{
   179  					"bar": "2",
   180  				}),
   181  				httphead.NewOption("d", map[string]string{
   182  					"baz": "3",
   183  				}),
   184  			},
   185  		},
   186  	},
   187  
   188  	// Error cases.
   189  	// ------------
   190  
   191  	{
   192  		label: "bad_http_method",
   193  		nonce: mustMakeNonce(),
   194  		req: mustMakeRequest("POST", "ws://example.org", http.Header{
   195  			headerUpgrade:    []string{"websocket"},
   196  			headerConnection: []string{"Upgrade"},
   197  			headerSecVersion: []string{"13"},
   198  		}),
   199  		res: mustMakeErrResponse(405, ErrHandshakeBadMethod, nil),
   200  		err: ErrHandshakeBadMethod,
   201  	},
   202  	{
   203  		label: "bad_http_proto",
   204  		nonce: mustMakeNonce(),
   205  		req: setProto(1, 0, mustMakeRequest("GET", "ws://example.org", http.Header{
   206  			headerUpgrade:    []string{"websocket"},
   207  			headerConnection: []string{"Upgrade"},
   208  			headerSecVersion: []string{"13"},
   209  		})),
   210  		res: mustMakeErrResponse(505, ErrHandshakeBadProtocol, nil),
   211  		err: ErrHandshakeBadProtocol,
   212  	},
   213  	{
   214  		label: "bad_host",
   215  		nonce: mustMakeNonce(),
   216  		req: withoutHeader("Host", mustMakeRequest("GET", "ws://example.org", http.Header{
   217  			headerUpgrade:    []string{"websocket"},
   218  			headerConnection: []string{"Upgrade"},
   219  			headerSecVersion: []string{"13"},
   220  		})),
   221  		res: mustMakeErrResponse(400, ErrHandshakeBadHost, nil),
   222  		err: ErrHandshakeBadHost,
   223  	},
   224  	{
   225  		label: "bad_upgrade",
   226  		nonce: mustMakeNonce(),
   227  		req: mustMakeRequest("GET", "ws://example.org", http.Header{
   228  			headerConnection: []string{"Upgrade"},
   229  			headerSecVersion: []string{"13"},
   230  		}),
   231  		res: mustMakeErrResponse(400, ErrHandshakeBadUpgrade, nil),
   232  		err: ErrHandshakeBadUpgrade,
   233  	},
   234  	{
   235  		label: "bad_upgrade",
   236  		nonce: mustMakeNonce(),
   237  		req: mustMakeRequest("GET", "ws://example.org", http.Header{
   238  			"X-Custom-Header": []string{"value"},
   239  			headerConnection:  []string{"Upgrade"},
   240  			headerSecVersion:  []string{"13"},
   241  		}),
   242  
   243  		onRequest: func([]byte) error { return nil },
   244  		onHost:    func([]byte) error { return nil },
   245  		onHeader:  func(k, v []byte) error { return nil },
   246  
   247  		res: mustMakeErrResponse(400, ErrHandshakeBadUpgrade, nil),
   248  		err: ErrHandshakeBadUpgrade,
   249  	},
   250  	{
   251  		label: "bad_upgrade",
   252  		nonce: mustMakeNonce(),
   253  		req: mustMakeRequest("GET", "ws://example.org", http.Header{
   254  			headerUpgrade:    []string{"not-websocket"},
   255  			headerConnection: []string{"Upgrade"},
   256  			headerSecVersion: []string{"13"},
   257  		}),
   258  		res: mustMakeErrResponse(400, ErrHandshakeBadUpgrade, nil),
   259  		err: ErrHandshakeBadUpgrade,
   260  	},
   261  	{
   262  		label: "bad_connection",
   263  		nonce: mustMakeNonce(),
   264  		req: mustMakeRequest("GET", "ws://example.org", http.Header{
   265  			headerUpgrade:    []string{"websocket"},
   266  			headerSecVersion: []string{"13"},
   267  		}),
   268  		res: mustMakeErrResponse(400, ErrHandshakeBadConnection, nil),
   269  		err: ErrHandshakeBadConnection,
   270  	},
   271  	{
   272  		label: "bad_connection",
   273  		nonce: mustMakeNonce(),
   274  		req: mustMakeRequest("GET", "ws://example.org", http.Header{
   275  			headerUpgrade:    []string{"websocket"},
   276  			headerConnection: []string{"not-upgrade"},
   277  			headerSecVersion: []string{"13"},
   278  		}),
   279  		res: mustMakeErrResponse(400, ErrHandshakeBadConnection, nil),
   280  		err: ErrHandshakeBadConnection,
   281  	},
   282  	{
   283  		label: "bad_sec_version_x",
   284  		nonce: mustMakeNonce(),
   285  		req: mustMakeRequest("GET", "ws://example.org", http.Header{
   286  			headerUpgrade:    []string{"websocket"},
   287  			headerConnection: []string{"Upgrade"},
   288  		}),
   289  		res: mustMakeErrResponse(400, ErrHandshakeBadSecVersion, nil),
   290  		err: ErrHandshakeBadSecVersion,
   291  	},
   292  	{
   293  		label: "bad_sec_version",
   294  		nonce: mustMakeNonce(),
   295  		req: mustMakeRequest("GET", "ws://example.org", http.Header{
   296  			headerUpgrade:    []string{"websocket"},
   297  			headerConnection: []string{"upgrade"},
   298  			headerSecVersion: []string{"15"},
   299  		}),
   300  		res: mustMakeErrResponse(426, ErrHandshakeBadSecVersion, http.Header{
   301  			headerSecVersion: []string{"13"},
   302  		}),
   303  		err: ErrHandshakeUpgradeRequired,
   304  	},
   305  	{
   306  		label:        "bad_sec_key",
   307  		nonce:        mustMakeNonce(),
   308  		removeSecKey: true,
   309  		req: mustMakeRequest("GET", "ws://example.org", http.Header{
   310  			headerUpgrade:    []string{"websocket"},
   311  			headerConnection: []string{"Upgrade"},
   312  			headerSecVersion: []string{"13"},
   313  		}),
   314  		res: mustMakeErrResponse(400, ErrHandshakeBadSecKey, nil),
   315  		err: ErrHandshakeBadSecKey,
   316  	},
   317  	{
   318  		label:     "bad_sec_key",
   319  		nonce:     mustMakeNonce(),
   320  		badSecKey: true,
   321  		req: mustMakeRequest("GET", "ws://example.org", http.Header{
   322  			headerUpgrade:    []string{"websocket"},
   323  			headerConnection: []string{"Upgrade"},
   324  			headerSecVersion: []string{"13"},
   325  		}),
   326  		res: mustMakeErrResponse(400, ErrHandshakeBadSecKey, nil),
   327  		err: ErrHandshakeBadSecKey,
   328  	},
   329  	{
   330  		label: "bad_ws_extension",
   331  		nonce: mustMakeNonce(),
   332  		req: mustMakeRequest("GET", "ws://example.org", http.Header{
   333  			headerUpgrade:       []string{"websocket"},
   334  			headerConnection:    []string{"Upgrade"},
   335  			headerSecVersion:    []string{"13"},
   336  			headerSecExtensions: []string{"=["},
   337  		}),
   338  		negotiate: func(opt httphead.Option) (ret httphead.Option, err error) {
   339  			return ret, nil
   340  		},
   341  		res: mustMakeErrResponse(400, ErrMalformedRequest, nil),
   342  		err: ErrMalformedRequest,
   343  	},
   344  	{
   345  		label: "bad_subprotocol",
   346  		nonce: mustMakeNonce(),
   347  		req: mustMakeRequest("GET", "ws://example.org", http.Header{
   348  			headerUpgrade:     []string{"websocket"},
   349  			headerConnection:  []string{"Upgrade"},
   350  			headerSecVersion:  []string{"13"},
   351  			headerSecProtocol: []string{"=["},
   352  		}),
   353  		protocol: func(string) bool {
   354  			return false
   355  		},
   356  		res: mustMakeErrResponse(400, ErrMalformedRequest, nil),
   357  		err: ErrMalformedRequest,
   358  	},
   359  }
   360  
   361  func TestHTTPUpgrader(t *testing.T) {
   362  	for _, test := range upgradeCases {
   363  		t.Run(test.label, func(t *testing.T) {
   364  			if !test.removeSecKey {
   365  				nonce := test.nonce
   366  				if test.badSecKey {
   367  					nonce = nonce[:nonceSize-1]
   368  				}
   369  				if test.secKeyHeader == "" {
   370  					test.secKeyHeader = headerSecKey
   371  				}
   372  				test.req.Header[test.secKeyHeader] = []string{string(nonce)}
   373  			}
   374  			if test.err == nil {
   375  				test.res.Header[headerSecAccept] = []string{string(makeAccept(test.nonce))}
   376  			}
   377  
   378  			// Need to emulate http server read request for truth test.
   379  			//
   380  			// We use dumpRequest here because test.req.Write is always send
   381  			// http/1.1 proto version, that does not fits all our testing
   382  			// cases.
   383  			reqBytes := dumpRequest(test.req)
   384  			req, err := http.ReadRequest(bufio.NewReader(bytes.NewReader(reqBytes)))
   385  			if err != nil {
   386  				t.Fatal(err)
   387  			}
   388  
   389  			res := newRecorder()
   390  
   391  			u := HTTPUpgrader{
   392  				Protocol:  test.protocol,
   393  				Negotiate: test.negotiate,
   394  			}
   395  			_, _, hs, err := u.Upgrade(req, res)
   396  			if test.err != err {
   397  				t.Errorf(
   398  					"expected error to be '%v', got '%v';\non request:\n====\n%s\n====",
   399  					test.err, err, dumpRequest(req),
   400  				)
   401  				return
   402  			}
   403  
   404  			actRespBts := sortHeaders(res.Bytes())
   405  			expRespBts := sortHeaders(dumpResponse(test.res))
   406  			if !bytes.Equal(actRespBts, expRespBts) {
   407  				t.Errorf(
   408  					"unexpected http response:\n---- act:\n%s\n---- want:\n%s\n==== on request:\n%s\n====",
   409  					actRespBts, expRespBts, dumpRequest(test.req),
   410  				)
   411  				return
   412  			}
   413  
   414  			if act, exp := hs.Protocol, test.hs.Protocol; act != exp {
   415  				t.Errorf("handshake protocol is %q want %q", act, exp)
   416  			}
   417  			if act, exp := len(hs.Extensions), len(test.hs.Extensions); act != exp {
   418  				t.Errorf("handshake got %d extensions; want %d", act, exp)
   419  			} else {
   420  				for i := 0; i < act; i++ {
   421  					if act, exp := hs.Extensions[i], test.hs.Extensions[i]; !act.Equal(exp) {
   422  						t.Errorf("handshake %d-th extension is %s; want %s", i, act, exp)
   423  					}
   424  				}
   425  			}
   426  		})
   427  	}
   428  }
   429  
   430  func TestUpgrader(t *testing.T) {
   431  	for _, test := range upgradeCases {
   432  		t.Run(test.label, func(t *testing.T) {
   433  			if !test.removeSecKey {
   434  				nonce := test.nonce[:]
   435  				if test.badSecKey {
   436  					nonce = nonce[:nonceSize-1]
   437  				}
   438  				test.req.Header[headerSecKey] = []string{string(nonce)}
   439  			}
   440  			if test.err == nil {
   441  				test.res.Header[headerSecAccept] = []string{string(makeAccept(test.nonce))}
   442  			}
   443  
   444  			u := Upgrader{
   445  				Protocol: func(p []byte) bool {
   446  					return test.protocol(string(p))
   447  				},
   448  				Negotiate: test.negotiate,
   449  				OnHeader:  test.onHeader,
   450  				OnRequest: test.onRequest,
   451  			}
   452  
   453  			// We use dumpRequest here because test.req.Write is always send
   454  			// http/1.1 proto version, that does not fits all our testing
   455  			// cases.
   456  			reqBytes := dumpRequest(test.req)
   457  			conn := bytes.NewBuffer(reqBytes)
   458  
   459  			hs, err := u.Upgrade(conn)
   460  			if test.err != err {
   461  
   462  				t.Errorf("expected error to be '%v', got '%v'", test.err, err)
   463  				return
   464  			}
   465  
   466  			actRespBts := sortHeaders(conn.Bytes())
   467  			expRespBts := sortHeaders(dumpResponse(test.res))
   468  			if !bytes.Equal(actRespBts, expRespBts) {
   469  				t.Errorf(
   470  					"unexpected http response:\n---- act:\n%s\n---- want:\n%s\n==== on request:\n%s\n====",
   471  					actRespBts, expRespBts, dumpRequest(test.req),
   472  				)
   473  				return
   474  			}
   475  
   476  			if act, exp := hs.Protocol, test.hs.Protocol; act != exp {
   477  				t.Errorf("handshake protocol is %q want %q", act, exp)
   478  			}
   479  			if act, exp := len(hs.Extensions), len(test.hs.Extensions); act != exp {
   480  				t.Errorf("handshake got %d extensions; want %d", act, exp)
   481  			} else {
   482  				for i := 0; i < act; i++ {
   483  					if act, exp := hs.Extensions[i], test.hs.Extensions[i]; !act.Equal(exp) {
   484  						t.Errorf("handshake %d-th extension is %s; want %s", i, act, exp)
   485  					}
   486  				}
   487  			}
   488  		})
   489  	}
   490  }
   491  
   492  func BenchmarkHTTPUpgrader(b *testing.B) {
   493  	for _, bench := range upgradeCases {
   494  		bench.req.Header.Set(headerSecKey, string(bench.nonce[:]))
   495  
   496  		u := HTTPUpgrader{
   497  			Protocol:  bench.protocol,
   498  			Negotiate: bench.negotiate,
   499  		}
   500  
   501  		b.Run(bench.label, func(b *testing.B) {
   502  			res := make([]http.ResponseWriter, b.N)
   503  			for i := 0; i < b.N; i++ {
   504  				res[i] = newRecorder()
   505  			}
   506  
   507  			i := new(int64)
   508  
   509  			b.ResetTimer()
   510  			b.ReportAllocs()
   511  			b.RunParallel(func(pb *testing.PB) {
   512  				for pb.Next() {
   513  					w := res[atomic.AddInt64(i, 1)-1]
   514  					u.Upgrade(bench.req, w)
   515  				}
   516  			})
   517  		})
   518  	}
   519  }
   520  
   521  func BenchmarkUpgrader(b *testing.B) {
   522  	for _, bench := range upgradeCases {
   523  		bench.req.Header.Set(headerSecKey, string(bench.nonce[:]))
   524  
   525  		u := Upgrader{
   526  			Protocol: func(p []byte) bool {
   527  				return bench.protocol(btsToString(p))
   528  			},
   529  			Negotiate: bench.negotiate,
   530  		}
   531  
   532  		reqBytes := dumpRequest(bench.req)
   533  
   534  		type benchReadWriter struct {
   535  			io.Reader
   536  			io.Writer
   537  		}
   538  
   539  		b.Run(bench.label, func(b *testing.B) {
   540  			conn := make([]io.ReadWriter, b.N)
   541  			for i := 0; i < b.N; i++ {
   542  				conn[i] = benchReadWriter{bytes.NewReader(reqBytes), ioutil.Discard}
   543  			}
   544  
   545  			i := new(int64)
   546  
   547  			b.ResetTimer()
   548  			b.ReportAllocs()
   549  			b.RunParallel(func(pb *testing.PB) {
   550  				for pb.Next() {
   551  					c := conn[atomic.AddInt64(i, 1)-1]
   552  					u.Upgrade(c)
   553  				}
   554  			})
   555  		})
   556  	}
   557  }
   558  
   559  func TestHttpStrSelectProtocol(t *testing.T) {
   560  	for i, test := range []struct {
   561  		header string
   562  	}{
   563  		{"jsonrpc, soap, grpc"},
   564  	} {
   565  		t.Run(fmt.Sprintf("#%d", i), func(t *testing.T) {
   566  			exp := strings.Split(test.header, ",")
   567  			for i, p := range exp {
   568  				exp[i] = strings.TrimSpace(p)
   569  			}
   570  
   571  			var calls []string
   572  			strSelectProtocol(test.header, func(s string) bool {
   573  				calls = append(calls, s)
   574  				return false
   575  			})
   576  
   577  			if !reflect.DeepEqual(calls, exp) {
   578  				t.Errorf("selectProtocol(%q, fn); called fn with %v; want %v", test.header, calls, exp)
   579  			}
   580  		})
   581  	}
   582  }
   583  
   584  func BenchmarkSelectProtocol(b *testing.B) {
   585  	for _, bench := range []struct {
   586  		label     string
   587  		header    string
   588  		acceptStr func(string) bool
   589  		acceptBts func([]byte) bool
   590  	}{
   591  		{
   592  			label:  "never accept",
   593  			header: "jsonrpc, soap, grpc",
   594  			acceptStr: func(s string) bool {
   595  				return len(s)%2 == 2 // never ok
   596  			},
   597  			acceptBts: func(v []byte) bool {
   598  				return len(v)%2 == 2 // never ok
   599  			},
   600  		},
   601  		{
   602  			label:     "from slice",
   603  			header:    "a, b, c, d, e, f, g",
   604  			acceptStr: SelectFromSlice([]string{"g", "f", "e", "d"}),
   605  		},
   606  		{
   607  			label:     "uniq 1024 from slise",
   608  			header:    strings.Join(randProtocols(1024, 16), ", "),
   609  			acceptStr: SelectFromSlice(randProtocols(1024, 17)),
   610  		},
   611  	} {
   612  		b.Run(fmt.Sprintf("String/%s", bench.label), func(b *testing.B) {
   613  			for i := 0; i < b.N; i++ {
   614  				strSelectProtocol(bench.header, bench.acceptStr)
   615  			}
   616  		})
   617  		if bench.acceptBts != nil {
   618  			b.Run(fmt.Sprintf("Bytes/%s", bench.label), func(b *testing.B) {
   619  				h := []byte(bench.header)
   620  				b.StartTimer()
   621  
   622  				for i := 0; i < b.N; i++ {
   623  					btsSelectProtocol(h, bench.acceptBts)
   624  				}
   625  			})
   626  		}
   627  	}
   628  }
   629  
   630  func randProtocols(n, m int) []string {
   631  	ret := make([]string, n)
   632  	bts := make([]byte, m)
   633  	uniq := map[string]bool{}
   634  	for i := 0; i < n; i++ {
   635  		for {
   636  			for j := 0; j < m; j++ {
   637  				bts[j] = byte(rand.Intn('x'-'a') + 'a')
   638  			}
   639  			str := string(bts)
   640  			if _, has := uniq[str]; !has {
   641  				ret[i] = str
   642  				break
   643  			}
   644  		}
   645  	}
   646  	return ret
   647  }
   648  
   649  func dumpRequest(req *http.Request) []byte {
   650  	bts, err := httputil.DumpRequest(req, true)
   651  	if err != nil {
   652  		panic(err)
   653  	}
   654  	return bts
   655  }
   656  
   657  func dumpResponse(res *http.Response) []byte {
   658  	if !res.Close {
   659  		for _, v := range res.Header[headerConnection] {
   660  			if v == "close" {
   661  				res.Close = true
   662  				break
   663  			}
   664  		}
   665  	}
   666  	bts, err := httputil.DumpResponse(res, true)
   667  	if err != nil {
   668  		panic(err)
   669  	}
   670  	if !res.Close {
   671  		bts = bytes.Replace(bts, []byte("Connection: close\r\n"), nil, -1)
   672  	}
   673  
   674  	return bts
   675  }
   676  
   677  type headersBytes [][]byte
   678  
   679  func (h headersBytes) Len() int           { return len(h) }
   680  func (h headersBytes) Swap(i, j int)      { h[i], h[j] = h[j], h[i] }
   681  func (h headersBytes) Less(i, j int) bool { return bytes.Compare(h[i], h[j]) == -1 }
   682  
   683  func maskHeader(bts []byte, key, mask string) []byte {
   684  	lines := bytes.Split(bts, []byte("\r\n"))
   685  	for i, line := range lines {
   686  		pair := bytes.Split(line, []byte(": "))
   687  		if string(pair[0]) == key {
   688  			lines[i] = []byte(key + ": " + mask)
   689  		}
   690  	}
   691  	return bytes.Join(lines, []byte("\r\n"))
   692  }
   693  
   694  func sortHeaders(bts []byte) []byte {
   695  	lines := bytes.Split(bts, []byte("\r\n"))
   696  	if len(lines) <= 1 {
   697  		return bts
   698  	}
   699  	sort.Sort(headersBytes(lines[1 : len(lines)-2]))
   700  	return bytes.Join(lines, []byte("\r\n"))
   701  }
   702  
   703  //go:linkname httpPutBufioReader net/http.putBufioReader
   704  func httpPutBufioReader(*bufio.Reader)
   705  
   706  //go:linkname httpPutBufioWriter net/http.putBufioWriter
   707  func httpPutBufioWriter(*bufio.Writer)
   708  
   709  //go:linkname httpNewBufioReader net/http.newBufioReader
   710  func httpNewBufioReader(io.Reader) *bufio.Reader
   711  
   712  //go:linkname httpNewBufioWriterSize net/http.newBufioWriterSize
   713  func httpNewBufioWriterSize(io.Writer, int) *bufio.Writer
   714  
   715  type recorder struct {
   716  	*httptest.ResponseRecorder
   717  	hijacked bool
   718  	conn     func(*bytes.Buffer) net.Conn
   719  }
   720  
   721  func newRecorder() *recorder {
   722  	return &recorder{
   723  		ResponseRecorder: httptest.NewRecorder(),
   724  	}
   725  }
   726  
   727  func (r *recorder) Bytes() []byte {
   728  	if r.hijacked {
   729  		return r.ResponseRecorder.Body.Bytes()
   730  	}
   731  
   732  	// TODO(gobwas): remove this when support for go 1.7 will end.
   733  	resp := r.Result()
   734  	cs := strings.TrimSpace(resp.Header.Get("Content-Length"))
   735  	if n, err := strconv.ParseInt(cs, 10, 64); err == nil {
   736  		resp.ContentLength = n
   737  	} else {
   738  		resp.ContentLength = -1
   739  	}
   740  
   741  	return dumpResponse(resp)
   742  }
   743  
   744  func (r *recorder) Hijack() (conn net.Conn, brw *bufio.ReadWriter, err error) {
   745  	if r.hijacked {
   746  		err = fmt.Errorf("already hijacked")
   747  		return
   748  	}
   749  
   750  	r.hijacked = true
   751  
   752  	var buf *bytes.Buffer
   753  	if r.ResponseRecorder != nil {
   754  		buf = r.ResponseRecorder.Body
   755  	}
   756  
   757  	if r.conn != nil {
   758  		conn = r.conn(buf)
   759  	} else {
   760  		conn = stubConn{
   761  			read:  buf.Read,
   762  			write: buf.Write,
   763  			close: func() error { return nil },
   764  		}
   765  	}
   766  
   767  	// Use httpNewBufio* linked functions here to make
   768  	// benchmark more closer to real life usage.
   769  	br := httpNewBufioReader(conn)
   770  	bw := httpNewBufioWriterSize(conn, 4<<10)
   771  
   772  	brw = bufio.NewReadWriter(br, bw)
   773  
   774  	return
   775  }
   776  
   777  func mustMakeRequest(method, url string, headers http.Header) *http.Request {
   778  	req, err := http.NewRequest(method, url, nil)
   779  	if err != nil {
   780  		panic(err)
   781  	}
   782  	req.Header = headers
   783  	return req
   784  }
   785  
   786  func setProto(major, minor int, req *http.Request) *http.Request {
   787  	req.ProtoMajor = major
   788  	req.ProtoMinor = minor
   789  	return req
   790  }
   791  
   792  func withoutHeader(header string, req *http.Request) *http.Request {
   793  	if strings.EqualFold(header, "Host") {
   794  		req.URL.Host = ""
   795  		req.Host = ""
   796  	} else {
   797  		delete(req.Header, header)
   798  	}
   799  	return req
   800  }
   801  
   802  func mustMakeResponse(code int, headers http.Header) *http.Response {
   803  	res := &http.Response{
   804  		StatusCode:    code,
   805  		Status:        http.StatusText(code),
   806  		Header:        headers,
   807  		ProtoMajor:    1,
   808  		ProtoMinor:    1,
   809  		ContentLength: -1,
   810  	}
   811  	return res
   812  }
   813  
   814  func mustMakeErrResponse(code int, err error, headers http.Header) *http.Response {
   815  	// Body text.
   816  	body := err.Error()
   817  
   818  	res := &http.Response{
   819  		StatusCode: code,
   820  		Status:     http.StatusText(code),
   821  		Header: http.Header{
   822  			"Content-Type": []string{"text/plain; charset=utf-8"},
   823  		},
   824  		ProtoMajor:    1,
   825  		ProtoMinor:    1,
   826  		ContentLength: int64(len(body)),
   827  	}
   828  	res.Body = ioutil.NopCloser(
   829  		strings.NewReader(body),
   830  	)
   831  	for k, v := range headers {
   832  		res.Header[k] = v
   833  	}
   834  	return res
   835  }
   836  
   837  func mustMakeNonce() (ret []byte) {
   838  	ret = make([]byte, nonceSize)
   839  	initNonce(ret)
   840  	return
   841  }