github.com/ezoic/ws@v1.0.4-0.20220713205711-5c1d69e074c5/util_test.go (about)

     1  package ws
     2  
     3  import (
     4  	"bufio"
     5  	"bytes"
     6  	"context"
     7  	"fmt"
     8  	"io"
     9  	"math/rand"
    10  	"net"
    11  	"net/http"
    12  	"net/textproto"
    13  	"reflect"
    14  	"strings"
    15  	"sync"
    16  	"testing"
    17  	"time"
    18  )
    19  
    20  var readLineCases = []struct {
    21  	label   string
    22  	in      string
    23  	line    []byte
    24  	err     error
    25  	bufSize int
    26  }{
    27  	{
    28  		label:   "simple",
    29  		in:      "hello, world!",
    30  		line:    []byte("hello, world!"),
    31  		err:     io.EOF,
    32  		bufSize: 1024,
    33  	},
    34  	{
    35  		label:   "simple",
    36  		in:      "hello, world!\r\n",
    37  		line:    []byte("hello, world!"),
    38  		bufSize: 1024,
    39  	},
    40  	{
    41  		label:   "simple",
    42  		in:      "hello, world!\n",
    43  		line:    []byte("hello, world!"),
    44  		bufSize: 1024,
    45  	},
    46  	{
    47  		// The case where "\r\n" straddles the buffer.
    48  		label:   "straddle",
    49  		in:      "hello, world!!!\r\n...",
    50  		line:    []byte("hello, world!!!"),
    51  		bufSize: 16,
    52  	},
    53  	{
    54  		label:   "chunked",
    55  		in:      "hello, world! this is a long long line!",
    56  		line:    []byte("hello, world! this is a long long line!"),
    57  		err:     io.EOF,
    58  		bufSize: 16,
    59  	},
    60  	{
    61  		label:   "chunked",
    62  		in:      "hello, world! this is a long long line!\r\n",
    63  		line:    []byte("hello, world! this is a long long line!"),
    64  		bufSize: 16,
    65  	},
    66  }
    67  
    68  func TestReadLine(t *testing.T) {
    69  	for _, test := range readLineCases {
    70  		t.Run(test.label, func(t *testing.T) {
    71  			br := bufio.NewReaderSize(strings.NewReader(test.in), test.bufSize)
    72  			bts, err := readLine(br)
    73  			if err != test.err {
    74  				t.Errorf("unexpected error: %v; want %v", err, test.err)
    75  			}
    76  			if act, exp := bts, test.line; !bytes.Equal(act, exp) {
    77  				t.Errorf("readLine() result is %#q; want %#q", act, exp)
    78  			}
    79  		})
    80  	}
    81  }
    82  
    83  func BenchmarkReadLine(b *testing.B) {
    84  	for _, test := range readLineCases {
    85  		sr := strings.NewReader(test.in)
    86  		br := bufio.NewReaderSize(sr, test.bufSize)
    87  		b.Run(test.label, func(b *testing.B) {
    88  			for i := 0; i < b.N; i++ {
    89  				_, _ = readLine(br)
    90  				sr.Reset(test.in)
    91  				br.Reset(sr)
    92  			}
    93  		})
    94  	}
    95  }
    96  
    97  func TestUpgradeSlowClient(t *testing.T) {
    98  	for _, test := range []struct {
    99  		lim *limitWriter
   100  	}{
   101  		{
   102  			lim: &limitWriter{
   103  				Bandwidth: 100,
   104  				Period:    time.Second,
   105  				Burst:     10,
   106  			},
   107  		},
   108  		{
   109  			lim: &limitWriter{
   110  				Bandwidth: 100,
   111  				Period:    time.Second,
   112  				Burst:     100,
   113  			},
   114  		},
   115  	} {
   116  		t.Run("", func(t *testing.T) {
   117  			client, server, err := socketPair()
   118  			if err != nil {
   119  				t.Fatal(err)
   120  			}
   121  			test.lim.Dest = server
   122  
   123  			header := http.Header{
   124  				"X-Websocket-Test-1": []string{"Yes"},
   125  				"X-Websocket-Test-2": []string{"Yes"},
   126  				"X-Websocket-Test-3": []string{"Yes"},
   127  				"X-Websocket-Test-4": []string{"Yes"},
   128  			}
   129  			d := Dialer{
   130  				NetDial: func(ctx context.Context, network, addr string) (net.Conn, error) {
   131  					return connWithWriter{server, test.lim}, nil
   132  				},
   133  				Header: HandshakeHeaderHTTP(header),
   134  			}
   135  			var (
   136  				expHost = "example.org"
   137  				expURI  = "/path/to/ws"
   138  			)
   139  			receivedHeader := http.Header{}
   140  			u := Upgrader{
   141  				OnRequest: func(uri []byte) error {
   142  					if u := string(uri); u != expURI {
   143  						t.Errorf(
   144  							"unexpected URI in OnRequest() callback: %q; want %q",
   145  							u, expURI,
   146  						)
   147  					}
   148  					return nil
   149  				},
   150  				OnHost: func(host []byte) error {
   151  					if h := string(host); h != expHost {
   152  						t.Errorf(
   153  							"unexpected host in OnRequest() callback: %q; want %q",
   154  							h, expHost,
   155  						)
   156  					}
   157  					return nil
   158  				},
   159  				OnHeader: func(key, value []byte) error {
   160  					receivedHeader.Add(string(key), string(value))
   161  					return nil
   162  				},
   163  			}
   164  			upgrade := make(chan error, 1)
   165  			go func() {
   166  				_, err := u.Upgrade(client)
   167  				upgrade <- err
   168  			}()
   169  
   170  			_, _, _, err = d.Dial(context.Background(), "ws://"+expHost+expURI)
   171  			if err != nil {
   172  				t.Errorf("Dial() error: %v", err)
   173  			}
   174  
   175  			if err := <-upgrade; err != nil {
   176  				t.Errorf("Upgrade() error: %v", err)
   177  			}
   178  			for key, values := range header {
   179  				act, has := receivedHeader[key]
   180  				if !has {
   181  					t.Errorf("OnHeader() was not called with %q header key", key)
   182  				}
   183  				if !reflect.DeepEqual(act, values) {
   184  					t.Errorf("OnHeader(%q) different values: %v; want %v", key, act, values)
   185  				}
   186  			}
   187  		})
   188  	}
   189  }
   190  
   191  type connWithWriter struct {
   192  	net.Conn
   193  	w io.Writer
   194  }
   195  
   196  func (w connWithWriter) Write(p []byte) (int, error) {
   197  	return w.w.Write(p)
   198  }
   199  
   200  type limitWriter struct {
   201  	Dest      io.Writer
   202  	Bandwidth int
   203  	Burst     int
   204  	Period    time.Duration
   205  
   206  	mu      sync.Mutex
   207  	cond    sync.Cond
   208  	once    sync.Once
   209  	done    chan struct{}
   210  	tickets int
   211  }
   212  
   213  func (w *limitWriter) init() {
   214  	w.once.Do(func() {
   215  		w.cond.L = &w.mu
   216  		w.done = make(chan struct{})
   217  
   218  		tick := w.Period / time.Duration(w.Bandwidth)
   219  		go func() {
   220  			t := time.NewTicker(tick)
   221  			for {
   222  				select {
   223  				case <-t.C:
   224  					w.mu.Lock()
   225  					w.tickets = w.Burst
   226  					w.mu.Unlock()
   227  					w.cond.Signal()
   228  				case <-w.done:
   229  					t.Stop()
   230  					return
   231  				}
   232  			}
   233  		}()
   234  	})
   235  }
   236  
   237  func (w *limitWriter) allow(n int) (allowed int) {
   238  	w.init()
   239  	w.mu.Lock()
   240  	defer w.mu.Unlock()
   241  	for w.tickets == 0 {
   242  		w.cond.Wait()
   243  	}
   244  	if w.tickets < 0 {
   245  		return -1
   246  	}
   247  	allowed = min(w.tickets, n)
   248  	w.tickets -= allowed
   249  	return allowed
   250  }
   251  
   252  func (w *limitWriter) Close() error {
   253  	w.init()
   254  	w.mu.Lock()
   255  	defer w.mu.Unlock()
   256  	if w.tickets < 0 {
   257  		return nil
   258  	}
   259  	w.tickets = -1
   260  	close(w.done)
   261  	w.cond.Broadcast()
   262  	return nil
   263  }
   264  
   265  func (w *limitWriter) Write(p []byte) (n int, err error) {
   266  	w.init()
   267  	for n < len(p) {
   268  		m := w.allow(len(p))
   269  		if m < 0 {
   270  			return 0, io.ErrClosedPipe
   271  		}
   272  		if _, err := w.Dest.Write(p[n : n+m]); err != nil {
   273  			return n, err
   274  		}
   275  		n += m
   276  	}
   277  	return n, nil
   278  }
   279  
   280  func socketPair() (client, server net.Conn, err error) {
   281  	ln, err := net.Listen("tcp", "localhost:")
   282  	if err != nil {
   283  		return nil, nil, err
   284  	}
   285  	type connAndError struct {
   286  		conn net.Conn
   287  		err  error
   288  	}
   289  	dial := make(chan connAndError, 1)
   290  	go func() {
   291  		conn, err := net.Dial("tcp", ln.Addr().String())
   292  		dial <- connAndError{conn, err}
   293  	}()
   294  	server, err = ln.Accept()
   295  	if err != nil {
   296  		return nil, nil, err
   297  	}
   298  	ce := <-dial
   299  	if err := ce.err; err != nil {
   300  		return nil, nil, err
   301  	}
   302  	return ce.conn, server, nil
   303  }
   304  
   305  func TestHasToken(t *testing.T) {
   306  	for i, test := range []struct {
   307  		header string
   308  		token  string
   309  		exp    bool
   310  	}{
   311  		{"Keep-Alive, Close, Upgrade", "upgrade", true},
   312  		{"Keep-Alive, Close, upgrade, hello", "upgrade", true},
   313  		{"Keep-Alive, Close,  hello", "upgrade", false},
   314  	} {
   315  		t.Run(fmt.Sprintf("#%d", i), func(t *testing.T) {
   316  			if has := strHasToken(test.header, test.token); has != test.exp {
   317  				t.Errorf("hasToken(%q, %q) = %v; want %v", test.header, test.token, has, test.exp)
   318  			}
   319  		})
   320  	}
   321  }
   322  
   323  func BenchmarkHasToken(b *testing.B) {
   324  	for i, bench := range []struct {
   325  		header string
   326  		token  string
   327  	}{
   328  		{"Keep-Alive, Close, Upgrade", "upgrade"},
   329  		{"Keep-Alive, Close, upgrade, hello", "upgrade"},
   330  		{"Keep-Alive, Close,  hello", "upgrade"},
   331  	} {
   332  		b.Run(fmt.Sprintf("#%d", i), func(b *testing.B) {
   333  			for i := 0; i < b.N; i++ {
   334  				_ = strHasToken(bench.header, bench.token)
   335  			}
   336  		})
   337  	}
   338  }
   339  
   340  type equalFoldCase struct {
   341  	label string
   342  	a, b  string
   343  }
   344  
   345  var equalFoldCases = []equalFoldCase{
   346  	{"websocket", "WebSocket", "websocket"},
   347  	{"upgrade", "Upgrade", "upgrade"},
   348  	randomEqualLetters(512),
   349  	inequalAt(randomEqualLetters(512), 256),
   350  }
   351  
   352  func TestAsciiToInt(t *testing.T) {
   353  	for _, test := range []struct {
   354  		bts []byte
   355  		exp int
   356  		err bool
   357  	}{
   358  		{[]byte{'0'}, 0, false},
   359  		{[]byte{'1'}, 1, false},
   360  		{[]byte("42"), 42, false},
   361  		{[]byte("420"), 420, false},
   362  		{[]byte("010050042"), 10050042, false},
   363  	} {
   364  		t.Run(fmt.Sprintf("%s", string(test.bts)), func(t *testing.T) {
   365  			act, err := asciiToInt(test.bts)
   366  			if (test.err && err == nil) || (!test.err && err != nil) {
   367  				t.Errorf("unexpected error: %v", err)
   368  			}
   369  			if act != test.exp {
   370  				t.Errorf("asciiToInt(%v) = %v; want %v", test.bts, act, test.exp)
   371  			}
   372  		})
   373  	}
   374  }
   375  
   376  func TestBtrim(t *testing.T) {
   377  	for _, test := range []struct {
   378  		bts []byte
   379  		exp []byte
   380  	}{
   381  		{[]byte("abc"), []byte("abc")},
   382  		{[]byte(" abc"), []byte("abc")},
   383  		{[]byte("abc "), []byte("abc")},
   384  		{[]byte(" abc "), []byte("abc")},
   385  	} {
   386  		t.Run(fmt.Sprintf("%s", string(test.bts)), func(t *testing.T) {
   387  			if act := btrim(test.bts); !bytes.Equal(act, test.exp) {
   388  				t.Errorf("btrim(%v) = %v; want %v", test.bts, act, test.exp)
   389  			}
   390  		})
   391  	}
   392  }
   393  
   394  func TestBSplit3(t *testing.T) {
   395  	for _, test := range []struct {
   396  		bts  []byte
   397  		sep  byte
   398  		exp1 []byte
   399  		exp2 []byte
   400  		exp3 []byte
   401  	}{
   402  		{[]byte(""), ' ', []byte{}, nil, nil},
   403  		{[]byte("GET / HTTP/1.1"), ' ', []byte("GET"), []byte("/"), []byte("HTTP/1.1")},
   404  	} {
   405  		t.Run(fmt.Sprintf("%s", string(test.bts)), func(t *testing.T) {
   406  			b1, b2, b3 := bsplit3(test.bts, test.sep)
   407  			if !bytes.Equal(b1, test.exp1) || !bytes.Equal(b2, test.exp2) || !bytes.Equal(b3, test.exp3) {
   408  				t.Errorf(
   409  					"bsplit3(%q) = %q, %q, %q; want %q, %q, %q",
   410  					string(test.bts), string(b1), string(b2), string(b3),
   411  					string(test.exp1), string(test.exp2), string(test.exp3),
   412  				)
   413  			}
   414  		})
   415  	}
   416  }
   417  
   418  var canonicalHeaderCases = [][]byte{
   419  	[]byte("foo-"),
   420  	[]byte("-foo"),
   421  	[]byte("-"),
   422  	[]byte("foo----bar"),
   423  	[]byte("foo-bar"),
   424  	[]byte("FoO-BaR"),
   425  	[]byte("Foo-Bar"),
   426  	[]byte("sec-websocket-extensions"),
   427  }
   428  
   429  func TestCanonicalizeHeaderKey(t *testing.T) {
   430  	for _, bts := range canonicalHeaderCases {
   431  		t.Run(fmt.Sprintf("%s", string(bts)), func(t *testing.T) {
   432  			act := append([]byte(nil), bts...)
   433  			canonicalizeHeaderKey(act)
   434  
   435  			exp := strToBytes(textproto.CanonicalMIMEHeaderKey(string(bts)))
   436  
   437  			if !bytes.Equal(act, exp) {
   438  				t.Errorf(
   439  					"canonicalizeHeaderKey(%v) = %v; want %v",
   440  					string(bts), string(act), string(exp),
   441  				)
   442  			}
   443  		})
   444  	}
   445  }
   446  
   447  func BenchmarkCanonicalizeHeaderKey(b *testing.B) {
   448  	for _, bts := range canonicalHeaderCases {
   449  		b.Run(fmt.Sprintf("%s", string(bts)), func(b *testing.B) {
   450  			for i := 0; i < b.N; i++ {
   451  				canonicalizeHeaderKey(bts)
   452  			}
   453  		})
   454  	}
   455  }
   456  
   457  func randomEqualLetters(n int) (c equalFoldCase) {
   458  	c.label = fmt.Sprintf("rnd_eq_%d", n)
   459  
   460  	a, b := make([]byte, n), make([]byte, n)
   461  
   462  	for i := 0; i < n; i++ {
   463  		c := byte(rand.Intn('Z'-'A'+1) + 'A') // Random character from 'A' to 'Z'.
   464  		a[i] = c
   465  		b[i] = c | ('a' - 'A') // Swap fold.
   466  	}
   467  
   468  	c.a = string(a)
   469  	c.b = string(b)
   470  
   471  	return
   472  }
   473  
   474  func inequalAt(c equalFoldCase, i int) equalFoldCase {
   475  	bts := make([]byte, len(c.a))
   476  	copy(bts, c.a)
   477  	for {
   478  		b := byte(rand.Intn('z'-'a'+1) + 'a')
   479  		if bts[i] != b {
   480  			bts[i] = b
   481  			c.a = string(bts)
   482  			c.label = fmt.Sprintf("rnd_ineq_%d_%d", len(c.a), i)
   483  			return c
   484  		}
   485  	}
   486  }