github.com/ezoic/ws@v1.0.4-0.20220713205711-5c1d69e074c5/server_test.go (about)

     1  package ws
     2  
     3  import (
     4  	"bufio"
     5  	"bytes"
     6  	"fmt"
     7  	"io"
     8  	"io/ioutil"
     9  	"math/rand"
    10  	"net"
    11  	"net/http"
    12  	"net/http/httptest"
    13  	"net/http/httputil"
    14  	"reflect"
    15  	"sort"
    16  	"strconv"
    17  	"strings"
    18  	"sync/atomic"
    19  	"testing"
    20  	_ "unsafe" // for go:linkname
    21  
    22  	"github.com/ezoic/httphead"
    23  )
    24  
// upgradeCase describes a single handshake scenario: the request to send,
// the upgrader callbacks to install, and the response, handshake and error
// that the upgrader is expected to produce.
//
// TODO(ezoic): upgradeGenericCase with methods like configureUpgrader,
// configureHTTPUpgrader.
type upgradeCase struct {
	// label is used as the subtest / sub-benchmark name.
	label string

	// Callbacks handed to the upgrader under test. protocol and extension
	// are used for both HTTPUpgrader and Upgrader; onRequest and onHeader
	// are passed to Upgrader only.
	// NOTE(review): onHost is set by one case but is not read anywhere in
	// the visible part of this file.
	protocol  func(string) bool
	extension func(httphead.Option) bool
	onRequest func(u []byte) error
	onHost    func(h []byte) error
	onHeader  func(k, v []byte) error

	// nonce is sent as the value of the key header, unless removeSecKey is
	// set (no key header is added) or badSecKey is set (the nonce is sent
	// one byte short). secKeyHeader overrides the name of the key header in
	// TestHTTPUpgrader; when empty, headerSecKey is used.
	nonce        []byte
	removeSecKey bool
	badSecKey    bool
	secKeyHeader string

	// req is the request fed to the upgrader; res, hs and err are the
	// expected HTTP response, handshake result and returned error.
	req *http.Request
	res *http.Response
	hs  Handshake
	err error
}
    46  
    47  var upgradeCases = []upgradeCase{
    48  	{
    49  		label: "base",
    50  		nonce: mustMakeNonce(),
    51  		req: mustMakeRequest("GET", "ws://example.org", http.Header{
    52  			headerUpgrade:    []string{"websocket"},
    53  			headerConnection: []string{"Upgrade"},
    54  			headerSecVersion: []string{"13"},
    55  		}),
    56  		res: mustMakeResponse(101, http.Header{
    57  			headerUpgrade:    []string{"websocket"},
    58  			headerConnection: []string{"Upgrade"},
    59  		}),
    60  	},
    61  	{
    62  		label:        "base_canonical",
    63  		nonce:        mustMakeNonce(),
    64  		secKeyHeader: headerSecKeyCanonical,
    65  		req: mustMakeRequest("GET", "ws://example.org", http.Header{
    66  			headerUpgrade:             []string{"websocket"},
    67  			headerConnection:          []string{"Upgrade"},
    68  			headerSecVersionCanonical: []string{"13"},
    69  		}),
    70  		res: mustMakeResponse(101, http.Header{
    71  			headerUpgrade:    []string{"websocket"},
    72  			headerConnection: []string{"Upgrade"},
    73  		}),
    74  	},
    75  	{
    76  		label:        "lowercase_headers",
    77  		nonce:        mustMakeNonce(),
    78  		secKeyHeader: strings.ToLower(headerSecKey),
    79  		req: mustMakeRequest("GET", "ws://example.org", http.Header{
    80  			strings.ToLower(headerUpgrade):    []string{"websocket"},
    81  			strings.ToLower(headerConnection): []string{"Upgrade"},
    82  			strings.ToLower(headerSecVersion): []string{"13"},
    83  		}),
    84  		res: mustMakeResponse(101, http.Header{
    85  			headerUpgrade:    []string{"websocket"},
    86  			headerConnection: []string{"Upgrade"},
    87  		}),
    88  	},
    89  	{
    90  		label:    "uppercase",
    91  		protocol: func(sub string) bool { return true },
    92  		nonce:    mustMakeNonce(),
    93  		req: mustMakeRequest("GET", "ws://example.org", http.Header{
    94  			headerUpgrade:    []string{"WEBSOCKET"},
    95  			headerConnection: []string{"UPGRADE"},
    96  			headerSecVersion: []string{"13"},
    97  		}),
    98  		res: mustMakeResponse(101, http.Header{
    99  			headerUpgrade:    []string{"websocket"},
   100  			headerConnection: []string{"Upgrade"},
   101  		}),
   102  	},
   103  	{
   104  		label:    "subproto",
   105  		protocol: SelectFromSlice([]string{"b", "d"}),
   106  		nonce:    mustMakeNonce(),
   107  		req: mustMakeRequest("GET", "ws://example.org", http.Header{
   108  			headerUpgrade:     []string{"websocket"},
   109  			headerConnection:  []string{"Upgrade"},
   110  			headerSecVersion:  []string{"13"},
   111  			headerSecProtocol: []string{"a", "b", "c", "d"},
   112  		}),
   113  		res: mustMakeResponse(101, http.Header{
   114  			headerUpgrade:     []string{"websocket"},
   115  			headerConnection:  []string{"Upgrade"},
   116  			headerSecProtocol: []string{"b"},
   117  		}),
   118  		hs: Handshake{Protocol: "b"},
   119  	},
   120  	{
   121  		label:        "subproto_lowercase_headers",
   122  		protocol:     SelectFromSlice([]string{"b", "d"}),
   123  		nonce:        mustMakeNonce(),
   124  		secKeyHeader: strings.ToLower(headerSecKey),
   125  		req: mustMakeRequest("GET", "ws://example.org", http.Header{
   126  			strings.ToLower(headerUpgrade):     []string{"websocket"},
   127  			strings.ToLower(headerConnection):  []string{"Upgrade"},
   128  			strings.ToLower(headerSecVersion):  []string{"13"},
   129  			strings.ToLower(headerSecProtocol): []string{"a", "b", "c", "d"},
   130  		}),
   131  		res: mustMakeResponse(101, http.Header{
   132  			headerUpgrade:     []string{"websocket"},
   133  			headerConnection:  []string{"Upgrade"},
   134  			headerSecProtocol: []string{"b"},
   135  		}),
   136  		hs: Handshake{Protocol: "b"},
   137  	},
   138  	{
   139  		label:    "subproto_comma",
   140  		protocol: SelectFromSlice([]string{"b", "d"}),
   141  		nonce:    mustMakeNonce(),
   142  		req: mustMakeRequest("GET", "ws://example.org", http.Header{
   143  			headerUpgrade:     []string{"websocket"},
   144  			headerConnection:  []string{"Upgrade"},
   145  			headerSecVersion:  []string{"13"},
   146  			headerSecProtocol: []string{"a, b, c, d"},
   147  		}),
   148  		res: mustMakeResponse(101, http.Header{
   149  			headerUpgrade:     []string{"websocket"},
   150  			headerConnection:  []string{"Upgrade"},
   151  			headerSecProtocol: []string{"b"},
   152  		}),
   153  		hs: Handshake{Protocol: "b"},
   154  	},
   155  	{
   156  		extension: func(opt httphead.Option) bool {
   157  			switch string(opt.Name) {
   158  			case "b", "d":
   159  				return true
   160  			default:
   161  				return false
   162  			}
   163  		},
   164  		nonce: mustMakeNonce(),
   165  		req: mustMakeRequest("GET", "ws://example.org", http.Header{
   166  			headerUpgrade:       []string{"websocket"},
   167  			headerConnection:    []string{"Upgrade"},
   168  			headerSecVersion:    []string{"13"},
   169  			headerSecExtensions: []string{"a;foo=1", "b;bar=2", "c", "d;baz=3"},
   170  		}),
   171  		res: mustMakeResponse(101, http.Header{
   172  			headerUpgrade:       []string{"websocket"},
   173  			headerConnection:    []string{"Upgrade"},
   174  			headerSecExtensions: []string{"b;bar=2,d;baz=3"},
   175  		}),
   176  		hs: Handshake{
   177  			Extensions: []httphead.Option{
   178  				httphead.NewOption("b", map[string]string{
   179  					"bar": "2",
   180  				}),
   181  				httphead.NewOption("d", map[string]string{
   182  					"baz": "3",
   183  				}),
   184  			},
   185  		},
   186  	},
   187  
   188  	// Error cases.
   189  	// ------------
   190  
   191  	{
   192  		label: "bad_http_method",
   193  		nonce: mustMakeNonce(),
   194  		req: mustMakeRequest("POST", "ws://example.org", http.Header{
   195  			headerUpgrade:    []string{"websocket"},
   196  			headerConnection: []string{"Upgrade"},
   197  			headerSecVersion: []string{"13"},
   198  		}),
   199  		res: mustMakeErrResponse(405, ErrHandshakeBadMethod, nil),
   200  		err: ErrHandshakeBadMethod,
   201  	},
   202  	{
   203  		label: "bad_http_proto",
   204  		nonce: mustMakeNonce(),
   205  		req: setProto(1, 0, mustMakeRequest("GET", "ws://example.org", http.Header{
   206  			headerUpgrade:    []string{"websocket"},
   207  			headerConnection: []string{"Upgrade"},
   208  			headerSecVersion: []string{"13"},
   209  		})),
   210  		res: mustMakeErrResponse(505, ErrHandshakeBadProtocol, nil),
   211  		err: ErrHandshakeBadProtocol,
   212  	},
   213  	{
   214  		label: "bad_host",
   215  		nonce: mustMakeNonce(),
   216  		req: withoutHeader("Host", mustMakeRequest("GET", "ws://example.org", http.Header{
   217  			headerUpgrade:    []string{"websocket"},
   218  			headerConnection: []string{"Upgrade"},
   219  			headerSecVersion: []string{"13"},
   220  		})),
   221  		res: mustMakeErrResponse(400, ErrHandshakeBadHost, nil),
   222  		err: ErrHandshakeBadHost,
   223  	},
   224  	{
   225  		label: "bad_upgrade",
   226  		nonce: mustMakeNonce(),
   227  		req: mustMakeRequest("GET", "ws://example.org", http.Header{
   228  			headerConnection: []string{"Upgrade"},
   229  			headerSecVersion: []string{"13"},
   230  		}),
   231  		res: mustMakeErrResponse(400, ErrHandshakeBadUpgrade, nil),
   232  		err: ErrHandshakeBadUpgrade,
   233  	},
   234  	{
   235  		label: "bad_upgrade",
   236  		nonce: mustMakeNonce(),
   237  		req: mustMakeRequest("GET", "ws://example.org", http.Header{
   238  			"X-Custom-Header": []string{"value"},
   239  			headerConnection:  []string{"Upgrade"},
   240  			headerSecVersion:  []string{"13"},
   241  		}),
   242  
   243  		onRequest: func([]byte) error { return nil },
   244  		onHost:    func([]byte) error { return nil },
   245  		onHeader:  func(k, v []byte) error { return nil },
   246  
   247  		res: mustMakeErrResponse(400, ErrHandshakeBadUpgrade, nil),
   248  		err: ErrHandshakeBadUpgrade,
   249  	},
   250  	{
   251  		label: "bad_upgrade",
   252  		nonce: mustMakeNonce(),
   253  		req: mustMakeRequest("GET", "ws://example.org", http.Header{
   254  			headerUpgrade:    []string{"not-websocket"},
   255  			headerConnection: []string{"Upgrade"},
   256  			headerSecVersion: []string{"13"},
   257  		}),
   258  		res: mustMakeErrResponse(400, ErrHandshakeBadUpgrade, nil),
   259  		err: ErrHandshakeBadUpgrade,
   260  	},
   261  	{
   262  		label: "bad_connection",
   263  		nonce: mustMakeNonce(),
   264  		req: mustMakeRequest("GET", "ws://example.org", http.Header{
   265  			headerUpgrade:    []string{"websocket"},
   266  			headerSecVersion: []string{"13"},
   267  		}),
   268  		res: mustMakeErrResponse(400, ErrHandshakeBadConnection, nil),
   269  		err: ErrHandshakeBadConnection,
   270  	},
   271  	{
   272  		label: "bad_connection",
   273  		nonce: mustMakeNonce(),
   274  		req: mustMakeRequest("GET", "ws://example.org", http.Header{
   275  			headerUpgrade:    []string{"websocket"},
   276  			headerConnection: []string{"not-upgrade"},
   277  			headerSecVersion: []string{"13"},
   278  		}),
   279  		res: mustMakeErrResponse(400, ErrHandshakeBadConnection, nil),
   280  		err: ErrHandshakeBadConnection,
   281  	},
   282  	{
   283  		label: "bad_sec_version_x",
   284  		nonce: mustMakeNonce(),
   285  		req: mustMakeRequest("GET", "ws://example.org", http.Header{
   286  			headerUpgrade:    []string{"websocket"},
   287  			headerConnection: []string{"Upgrade"},
   288  		}),
   289  		res: mustMakeErrResponse(400, ErrHandshakeBadSecVersion, nil),
   290  		err: ErrHandshakeBadSecVersion,
   291  	},
   292  	{
   293  		label: "bad_sec_version",
   294  		nonce: mustMakeNonce(),
   295  		req: mustMakeRequest("GET", "ws://example.org", http.Header{
   296  			headerUpgrade:    []string{"websocket"},
   297  			headerConnection: []string{"upgrade"},
   298  			headerSecVersion: []string{"15"},
   299  		}),
   300  		res: mustMakeErrResponse(426, ErrHandshakeBadSecVersion, http.Header{
   301  			headerSecVersion: []string{"13"},
   302  		}),
   303  		err: ErrHandshakeUpgradeRequired,
   304  	},
   305  	{
   306  		label:        "bad_sec_key",
   307  		nonce:        mustMakeNonce(),
   308  		removeSecKey: true,
   309  		req: mustMakeRequest("GET", "ws://example.org", http.Header{
   310  			headerUpgrade:    []string{"websocket"},
   311  			headerConnection: []string{"Upgrade"},
   312  			headerSecVersion: []string{"13"},
   313  		}),
   314  		res: mustMakeErrResponse(400, ErrHandshakeBadSecKey, nil),
   315  		err: ErrHandshakeBadSecKey,
   316  	},
   317  	{
   318  		label:     "bad_sec_key",
   319  		nonce:     mustMakeNonce(),
   320  		badSecKey: true,
   321  		req: mustMakeRequest("GET", "ws://example.org", http.Header{
   322  			headerUpgrade:    []string{"websocket"},
   323  			headerConnection: []string{"Upgrade"},
   324  			headerSecVersion: []string{"13"},
   325  		}),
   326  		res: mustMakeErrResponse(400, ErrHandshakeBadSecKey, nil),
   327  		err: ErrHandshakeBadSecKey,
   328  	},
   329  	{
   330  		label: "bad_ws_extension",
   331  		nonce: mustMakeNonce(),
   332  		req: mustMakeRequest("GET", "ws://example.org", http.Header{
   333  			headerUpgrade:       []string{"websocket"},
   334  			headerConnection:    []string{"Upgrade"},
   335  			headerSecVersion:    []string{"13"},
   336  			headerSecExtensions: []string{"=["},
   337  		}),
   338  		extension: func(opt httphead.Option) bool {
   339  			return false
   340  		},
   341  		res: mustMakeErrResponse(400, ErrMalformedRequest, nil),
   342  		err: ErrMalformedRequest,
   343  	},
   344  	{
   345  		label: "bad_subprotocol",
   346  		nonce: mustMakeNonce(),
   347  		req: mustMakeRequest("GET", "ws://example.org", http.Header{
   348  			headerUpgrade:     []string{"websocket"},
   349  			headerConnection:  []string{"Upgrade"},
   350  			headerSecVersion:  []string{"13"},
   351  			headerSecProtocol: []string{"=["},
   352  		}),
   353  		protocol: func(string) bool {
   354  			return false
   355  		},
   356  		res: mustMakeErrResponse(400, ErrMalformedRequest, nil),
   357  		err: ErrMalformedRequest,
   358  	},
   359  }
   360  
   361  func TestHTTPUpgrader(t *testing.T) {
   362  	for _, test := range upgradeCases {
   363  		t.Run(test.label, func(t *testing.T) {
   364  			if !test.removeSecKey {
   365  				nonce := test.nonce
   366  				if test.badSecKey {
   367  					nonce = nonce[:nonceSize-1]
   368  				}
   369  				if test.secKeyHeader == "" {
   370  					test.secKeyHeader = headerSecKey
   371  				}
   372  				test.req.Header[test.secKeyHeader] = []string{string(nonce)}
   373  			}
   374  			if test.err == nil {
   375  				test.res.Header[headerSecAccept] = []string{string(makeAccept(test.nonce))}
   376  			}
   377  
   378  			// Need to emulate http server read request for truth test.
   379  			//
   380  			// We use dumpRequest here because test.req.Write is always send
   381  			// http/1.1 proto version, that does not fits all our testing
   382  			// cases.
   383  			reqBytes := dumpRequest(test.req)
   384  			req, err := http.ReadRequest(bufio.NewReader(bytes.NewReader(reqBytes)))
   385  			if err != nil {
   386  				t.Fatal(err)
   387  			}
   388  
   389  			res := newRecorder()
   390  
   391  			u := HTTPUpgrader{
   392  				Protocol:  test.protocol,
   393  				Extension: test.extension,
   394  			}
   395  			_, _, hs, err := u.Upgrade(req, res)
   396  			if test.err != err {
   397  				t.Errorf(
   398  					"expected error to be '%v', got '%v';\non request:\n====\n%s\n====",
   399  					test.err, err, dumpRequest(req),
   400  				)
   401  				return
   402  			}
   403  
   404  			actRespBts := sortHeaders(res.Bytes())
   405  			expRespBts := sortHeaders(dumpResponse(test.res))
   406  			if !bytes.Equal(actRespBts, expRespBts) {
   407  				t.Errorf(
   408  					"unexpected http response:\n---- act:\n%s\n---- want:\n%s\n==== on request:\n%s\n====",
   409  					actRespBts, expRespBts, dumpRequest(test.req),
   410  				)
   411  				return
   412  			}
   413  
   414  			if act, exp := hs.Protocol, test.hs.Protocol; act != exp {
   415  				t.Errorf("handshake protocol is %q want %q", act, exp)
   416  			}
   417  			if act, exp := len(hs.Extensions), len(test.hs.Extensions); act != exp {
   418  				t.Errorf("handshake got %d extensions; want %d", act, exp)
   419  			} else {
   420  				for i := 0; i < act; i++ {
   421  					if act, exp := hs.Extensions[i], test.hs.Extensions[i]; !act.Equal(exp) {
   422  						t.Errorf("handshake %d-th extension is %s; want %s", i, act, exp)
   423  					}
   424  				}
   425  			}
   426  		})
   427  	}
   428  }
   429  
   430  func TestUpgrader(t *testing.T) {
   431  	for _, test := range upgradeCases {
   432  		t.Run(test.label, func(t *testing.T) {
   433  			if !test.removeSecKey {
   434  				nonce := test.nonce[:]
   435  				if test.badSecKey {
   436  					nonce = nonce[:nonceSize-1]
   437  				}
   438  				test.req.Header[headerSecKey] = []string{string(nonce)}
   439  			}
   440  			if test.err == nil {
   441  				test.res.Header[headerSecAccept] = []string{string(makeAccept(test.nonce))}
   442  			}
   443  
   444  			u := Upgrader{
   445  				Protocol: func(p []byte) bool {
   446  					return test.protocol(string(p))
   447  				},
   448  				Extension: func(e httphead.Option) bool {
   449  					return test.extension(e)
   450  				},
   451  				OnHeader:  test.onHeader,
   452  				OnRequest: test.onRequest,
   453  			}
   454  
   455  			// We use dumpRequest here because test.req.Write is always send
   456  			// http/1.1 proto version, that does not fits all our testing
   457  			// cases.
   458  			reqBytes := dumpRequest(test.req)
   459  			conn := bytes.NewBuffer(reqBytes)
   460  
   461  			hs, err := u.Upgrade(conn)
   462  			if test.err != err {
   463  
   464  				t.Errorf("expected error to be '%v', got '%v'", test.err, err)
   465  				return
   466  			}
   467  
   468  			actRespBts := sortHeaders(conn.Bytes())
   469  			expRespBts := sortHeaders(dumpResponse(test.res))
   470  			if !bytes.Equal(actRespBts, expRespBts) {
   471  				t.Errorf(
   472  					"unexpected http response:\n---- act:\n%s\n---- want:\n%s\n==== on request:\n%s\n====",
   473  					actRespBts, expRespBts, dumpRequest(test.req),
   474  				)
   475  				return
   476  			}
   477  
   478  			if act, exp := hs.Protocol, test.hs.Protocol; act != exp {
   479  				t.Errorf("handshake protocol is %q want %q", act, exp)
   480  			}
   481  			if act, exp := len(hs.Extensions), len(test.hs.Extensions); act != exp {
   482  				t.Errorf("handshake got %d extensions; want %d", act, exp)
   483  			} else {
   484  				for i := 0; i < act; i++ {
   485  					if act, exp := hs.Extensions[i], test.hs.Extensions[i]; !act.Equal(exp) {
   486  						t.Errorf("handshake %d-th extension is %s; want %s", i, act, exp)
   487  					}
   488  				}
   489  			}
   490  		})
   491  	}
   492  }
   493  
   494  func BenchmarkHTTPUpgrader(b *testing.B) {
   495  	for _, bench := range upgradeCases {
   496  		bench.req.Header.Set(headerSecKey, string(bench.nonce[:]))
   497  
   498  		u := HTTPUpgrader{
   499  			Protocol:  bench.protocol,
   500  			Extension: bench.extension,
   501  		}
   502  
   503  		b.Run(bench.label, func(b *testing.B) {
   504  			res := make([]http.ResponseWriter, b.N)
   505  			for i := 0; i < b.N; i++ {
   506  				res[i] = newRecorder()
   507  			}
   508  
   509  			i := new(int64)
   510  
   511  			b.ResetTimer()
   512  			b.ReportAllocs()
   513  			b.RunParallel(func(pb *testing.PB) {
   514  				for pb.Next() {
   515  					w := res[atomic.AddInt64(i, 1)-1]
   516  					u.Upgrade(bench.req, w)
   517  				}
   518  			})
   519  		})
   520  	}
   521  }
   522  
   523  func BenchmarkUpgrader(b *testing.B) {
   524  	for _, bench := range upgradeCases {
   525  		bench.req.Header.Set(headerSecKey, string(bench.nonce[:]))
   526  
   527  		u := Upgrader{
   528  			Protocol: func(p []byte) bool {
   529  				return bench.protocol(btsToString(p))
   530  			},
   531  			Extension: func(e httphead.Option) bool {
   532  				return bench.extension(e)
   533  			},
   534  		}
   535  
   536  		reqBytes := dumpRequest(bench.req)
   537  
   538  		type benchReadWriter struct {
   539  			io.Reader
   540  			io.Writer
   541  		}
   542  
   543  		b.Run(bench.label, func(b *testing.B) {
   544  			conn := make([]io.ReadWriter, b.N)
   545  			for i := 0; i < b.N; i++ {
   546  				conn[i] = benchReadWriter{bytes.NewReader(reqBytes), ioutil.Discard}
   547  			}
   548  
   549  			i := new(int64)
   550  
   551  			b.ResetTimer()
   552  			b.ReportAllocs()
   553  			b.RunParallel(func(pb *testing.PB) {
   554  				for pb.Next() {
   555  					c := conn[atomic.AddInt64(i, 1)-1]
   556  					u.Upgrade(c)
   557  				}
   558  			})
   559  		})
   560  	}
   561  }
   562  
   563  func TestHttpStrSelectProtocol(t *testing.T) {
   564  	for i, test := range []struct {
   565  		header string
   566  	}{
   567  		{"jsonrpc, soap, grpc"},
   568  	} {
   569  		t.Run(fmt.Sprintf("#%d", i), func(t *testing.T) {
   570  			exp := strings.Split(test.header, ",")
   571  			for i, p := range exp {
   572  				exp[i] = strings.TrimSpace(p)
   573  			}
   574  
   575  			var calls []string
   576  			strSelectProtocol(test.header, func(s string) bool {
   577  				calls = append(calls, s)
   578  				return false
   579  			})
   580  
   581  			if !reflect.DeepEqual(calls, exp) {
   582  				t.Errorf("selectProtocol(%q, fn); called fn with %v; want %v", test.header, calls, exp)
   583  			}
   584  		})
   585  	}
   586  }
   587  
   588  func BenchmarkSelectProtocol(b *testing.B) {
   589  	for _, bench := range []struct {
   590  		label     string
   591  		header    string
   592  		acceptStr func(string) bool
   593  		acceptBts func([]byte) bool
   594  	}{
   595  		{
   596  			label:  "never accept",
   597  			header: "jsonrpc, soap, grpc",
   598  			acceptStr: func(s string) bool {
   599  				return len(s)%2 == 2 // never ok
   600  			},
   601  			acceptBts: func(v []byte) bool {
   602  				return len(v)%2 == 2 // never ok
   603  			},
   604  		},
   605  		{
   606  			label:     "from slice",
   607  			header:    "a, b, c, d, e, f, g",
   608  			acceptStr: SelectFromSlice([]string{"g", "f", "e", "d"}),
   609  		},
   610  		{
   611  			label:     "uniq 1024 from slise",
   612  			header:    strings.Join(randProtocols(1024, 16), ", "),
   613  			acceptStr: SelectFromSlice(randProtocols(1024, 17)),
   614  		},
   615  	} {
   616  		b.Run(fmt.Sprintf("String/%s", bench.label), func(b *testing.B) {
   617  			for i := 0; i < b.N; i++ {
   618  				strSelectProtocol(bench.header, bench.acceptStr)
   619  			}
   620  		})
   621  		if bench.acceptBts != nil {
   622  			b.Run(fmt.Sprintf("Bytes/%s", bench.label), func(b *testing.B) {
   623  				h := []byte(bench.header)
   624  				b.StartTimer()
   625  
   626  				for i := 0; i < b.N; i++ {
   627  					btsSelectProtocol(h, bench.acceptBts)
   628  				}
   629  			})
   630  		}
   631  	}
   632  }
   633  
   634  func randProtocols(n, m int) []string {
   635  	ret := make([]string, n)
   636  	bts := make([]byte, m)
   637  	uniq := map[string]bool{}
   638  	for i := 0; i < n; i++ {
   639  		for {
   640  			for j := 0; j < m; j++ {
   641  				bts[j] = byte(rand.Intn('x'-'a') + 'a')
   642  			}
   643  			str := string(bts)
   644  			if _, has := uniq[str]; !has {
   645  				ret[i] = str
   646  				break
   647  			}
   648  		}
   649  	}
   650  	return ret
   651  }
   652  
   653  func dumpRequest(req *http.Request) []byte {
   654  	bts, err := httputil.DumpRequest(req, true)
   655  	if err != nil {
   656  		panic(err)
   657  	}
   658  	return bts
   659  }
   660  
   661  func dumpResponse(res *http.Response) []byte {
   662  	if !res.Close {
   663  		for _, v := range res.Header[headerConnection] {
   664  			if v == "close" {
   665  				res.Close = true
   666  				break
   667  			}
   668  		}
   669  	}
   670  	bts, err := httputil.DumpResponse(res, true)
   671  	if err != nil {
   672  		panic(err)
   673  	}
   674  	if !res.Close {
   675  		bts = bytes.Replace(bts, []byte("Connection: close\r\n"), nil, -1)
   676  	}
   677  
   678  	return bts
   679  }
   680  
   681  type headersBytes [][]byte
   682  
   683  func (h headersBytes) Len() int           { return len(h) }
   684  func (h headersBytes) Swap(i, j int)      { h[i], h[j] = h[j], h[i] }
   685  func (h headersBytes) Less(i, j int) bool { return bytes.Compare(h[i], h[j]) == -1 }
   686  
   687  func maskHeader(bts []byte, key, mask string) []byte {
   688  	lines := bytes.Split(bts, []byte("\r\n"))
   689  	for i, line := range lines {
   690  		pair := bytes.Split(line, []byte(": "))
   691  		if string(pair[0]) == key {
   692  			lines[i] = []byte(key + ": " + mask)
   693  		}
   694  	}
   695  	return bytes.Join(lines, []byte("\r\n"))
   696  }
   697  
   698  func sortHeaders(bts []byte) []byte {
   699  	lines := bytes.Split(bts, []byte("\r\n"))
   700  	if len(lines) <= 1 {
   701  		return bts
   702  	}
   703  	sort.Sort(headersBytes(lines[1 : len(lines)-2]))
   704  	return bytes.Join(lines, []byte("\r\n"))
   705  }
   706  
// The declarations below have no body; each one is bound through
// go:linkname to an unexported function of net/http (the directive requires
// the blank "unsafe" import in this file).

// httpPutBufioReader is linked to net/http.putBufioReader.
// NOTE(review): not called in the visible part of this file; presumably it
// hands a reader back to net/http's internal pool — confirm against net/http.
//
//go:linkname httpPutBufioReader net/http.putBufioReader
func httpPutBufioReader(*bufio.Reader)

// httpPutBufioWriter is linked to net/http.putBufioWriter.
// NOTE(review): not called in the visible part of this file; presumably it
// hands a writer back to net/http's internal pool — confirm against net/http.
//
//go:linkname httpPutBufioWriter net/http.putBufioWriter
func httpPutBufioWriter(*bufio.Writer)

// httpNewBufioReader is linked to net/http.newBufioReader. recorder.Hijack
// uses it to build the reader half of the hijacked connection.
//
//go:linkname httpNewBufioReader net/http.newBufioReader
func httpNewBufioReader(io.Reader) *bufio.Reader

// httpNewBufioWriterSize is linked to net/http.newBufioWriterSize.
// recorder.Hijack uses it to build the writer half of the hijacked
// connection.
//
//go:linkname httpNewBufioWriterSize net/http.newBufioWriterSize
func httpNewBufioWriterSize(io.Writer, int) *bufio.Writer
   718  
   719  type recorder struct {
   720  	*httptest.ResponseRecorder
   721  	hijacked bool
   722  	conn     func(*bytes.Buffer) net.Conn
   723  }
   724  
   725  func newRecorder() *recorder {
   726  	return &recorder{
   727  		ResponseRecorder: httptest.NewRecorder(),
   728  	}
   729  }
   730  
   731  func (r *recorder) Bytes() []byte {
   732  	if r.hijacked {
   733  		return r.ResponseRecorder.Body.Bytes()
   734  	}
   735  
   736  	// TODO(ezoic): remove this when support for go 1.7 will end.
   737  	resp := r.Result()
   738  	cs := strings.TrimSpace(resp.Header.Get("Content-Length"))
   739  	if n, err := strconv.ParseInt(cs, 10, 64); err == nil {
   740  		resp.ContentLength = n
   741  	} else {
   742  		resp.ContentLength = -1
   743  	}
   744  
   745  	return dumpResponse(resp)
   746  }
   747  
   748  func (r *recorder) Hijack() (conn net.Conn, brw *bufio.ReadWriter, err error) {
   749  	if r.hijacked {
   750  		err = fmt.Errorf("already hijacked")
   751  		return
   752  	}
   753  
   754  	r.hijacked = true
   755  
   756  	var buf *bytes.Buffer
   757  	if r.ResponseRecorder != nil {
   758  		buf = r.ResponseRecorder.Body
   759  	}
   760  
   761  	if r.conn != nil {
   762  		conn = r.conn(buf)
   763  	} else {
   764  		conn = stubConn{
   765  			read:  buf.Read,
   766  			write: buf.Write,
   767  			close: func() error { return nil },
   768  		}
   769  	}
   770  
   771  	// Use httpNewBufio* linked functions here to make
   772  	// benchmark more closer to real life usage.
   773  	br := httpNewBufioReader(conn)
   774  	bw := httpNewBufioWriterSize(conn, 4<<10)
   775  
   776  	brw = bufio.NewReadWriter(br, bw)
   777  
   778  	return
   779  }
   780  
   781  func mustMakeRequest(method, url string, headers http.Header) *http.Request {
   782  	req, err := http.NewRequest(method, url, nil)
   783  	if err != nil {
   784  		panic(err)
   785  	}
   786  	req.Header = headers
   787  	return req
   788  }
   789  
   790  func setProto(major, minor int, req *http.Request) *http.Request {
   791  	req.ProtoMajor = major
   792  	req.ProtoMinor = minor
   793  	return req
   794  }
   795  
   796  func withoutHeader(header string, req *http.Request) *http.Request {
   797  	if strings.EqualFold(header, "Host") {
   798  		req.URL.Host = ""
   799  		req.Host = ""
   800  	} else {
   801  		delete(req.Header, header)
   802  	}
   803  	return req
   804  }
   805  
   806  func mustMakeResponse(code int, headers http.Header) *http.Response {
   807  	res := &http.Response{
   808  		StatusCode:    code,
   809  		Status:        http.StatusText(code),
   810  		Header:        headers,
   811  		ProtoMajor:    1,
   812  		ProtoMinor:    1,
   813  		ContentLength: -1,
   814  	}
   815  	return res
   816  }
   817  
   818  func mustMakeErrResponse(code int, err error, headers http.Header) *http.Response {
   819  	// Body text.
   820  	body := err.Error()
   821  
   822  	res := &http.Response{
   823  		StatusCode: code,
   824  		Status:     http.StatusText(code),
   825  		Header: http.Header{
   826  			"Content-Type": []string{"text/plain; charset=utf-8"},
   827  		},
   828  		ProtoMajor:    1,
   829  		ProtoMinor:    1,
   830  		ContentLength: int64(len(body)),
   831  	}
   832  	res.Body = ioutil.NopCloser(
   833  		strings.NewReader(body),
   834  	)
   835  	for k, v := range headers {
   836  		res.Header[k] = v
   837  	}
   838  	return res
   839  }
   840  
   841  func mustMakeNonce() (ret []byte) {
   842  	ret = make([]byte, nonceSize)
   843  	initNonce(ret)
   844  	return
   845  }