github.com/insionng/yougam@v0.0.0-20170714101924-2bc18d833463/libraries/gorilla/websocket/conn_test.go (about)

     1  // Copyright 2013 The Gorilla WebSocket Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package websocket
     6  
     7  import (
     8  	"bufio"
     9  	"bytes"
    10  	"errors"
    11  	"fmt"
    12  	"io"
    13  	"io/ioutil"
    14  	"net"
    15  	"reflect"
    16  	"testing"
    17  	"testing/iotest"
    18  	"time"
    19  )
    20  
    21  var _ net.Error = errWriteTimeout
    22  
    23  type fakeNetConn struct {
    24  	io.Reader
    25  	io.Writer
    26  }
    27  
    28  func (c fakeNetConn) Close() error                       { return nil }
    29  func (c fakeNetConn) LocalAddr() net.Addr                { return nil }
    30  func (c fakeNetConn) RemoteAddr() net.Addr               { return nil }
    31  func (c fakeNetConn) SetDeadline(t time.Time) error      { return nil }
    32  func (c fakeNetConn) SetReadDeadline(t time.Time) error  { return nil }
    33  func (c fakeNetConn) SetWriteDeadline(t time.Time) error { return nil }
    34  
    35  func TestFraming(t *testing.T) {
    36  	frameSizes := []int{0, 1, 2, 124, 125, 126, 127, 128, 129, 65534, 65535, 65536, 65537}
    37  	var readChunkers = []struct {
    38  		name string
    39  		f    func(io.Reader) io.Reader
    40  	}{
    41  		{"half", iotest.HalfReader},
    42  		{"one", iotest.OneByteReader},
    43  		{"asis", func(r io.Reader) io.Reader { return r }},
    44  	}
    45  
    46  	writeBuf := make([]byte, 65537)
    47  	for i := range writeBuf {
    48  		writeBuf[i] = byte(i)
    49  	}
    50  
    51  	for _, isServer := range []bool{true, false} {
    52  		for _, chunker := range readChunkers {
    53  
    54  			var connBuf bytes.Buffer
    55  			wc := newConn(fakeNetConn{Reader: nil, Writer: &connBuf}, isServer, 1024, 1024)
    56  			rc := newConn(fakeNetConn{Reader: chunker.f(&connBuf), Writer: nil}, !isServer, 1024, 1024)
    57  
    58  			for _, n := range frameSizes {
    59  				for _, iocopy := range []bool{true, false} {
    60  					name := fmt.Sprintf("s:%v, r:%s, n:%d c:%v", isServer, chunker.name, n, iocopy)
    61  
    62  					w, err := wc.NextWriter(TextMessage)
    63  					if err != nil {
    64  						t.Errorf("%s: wc.NextWriter() returned %v", name, err)
    65  						continue
    66  					}
    67  					var nn int
    68  					if iocopy {
    69  						var n64 int64
    70  						n64, err = io.Copy(w, bytes.NewReader(writeBuf[:n]))
    71  						nn = int(n64)
    72  					} else {
    73  						nn, err = w.Write(writeBuf[:n])
    74  					}
    75  					if err != nil || nn != n {
    76  						t.Errorf("%s: w.Write(writeBuf[:n]) returned %d, %v", name, nn, err)
    77  						continue
    78  					}
    79  					err = w.Close()
    80  					if err != nil {
    81  						t.Errorf("%s: w.Close() returned %v", name, err)
    82  						continue
    83  					}
    84  
    85  					opCode, r, err := rc.NextReader()
    86  					if err != nil || opCode != TextMessage {
    87  						t.Errorf("%s: NextReader() returned %d, r, %v", name, opCode, err)
    88  						continue
    89  					}
    90  					rbuf, err := ioutil.ReadAll(r)
    91  					if err != nil {
    92  						t.Errorf("%s: ReadFull() returned rbuf, %v", name, err)
    93  						continue
    94  					}
    95  
    96  					if len(rbuf) != n {
    97  						t.Errorf("%s: len(rbuf) is %d, want %d", name, len(rbuf), n)
    98  						continue
    99  					}
   100  
   101  					for i, b := range rbuf {
   102  						if byte(i) != b {
   103  							t.Errorf("%s: bad byte at offset %d", name, i)
   104  							break
   105  						}
   106  					}
   107  				}
   108  			}
   109  		}
   110  	}
   111  }
   112  
   113  func TestControl(t *testing.T) {
   114  	const message = "this is a ping/pong messsage"
   115  	for _, isServer := range []bool{true, false} {
   116  		for _, isWriteControl := range []bool{true, false} {
   117  			name := fmt.Sprintf("s:%v, wc:%v", isServer, isWriteControl)
   118  			var connBuf bytes.Buffer
   119  			wc := newConn(fakeNetConn{Reader: nil, Writer: &connBuf}, isServer, 1024, 1024)
   120  			rc := newConn(fakeNetConn{Reader: &connBuf, Writer: nil}, !isServer, 1024, 1024)
   121  			if isWriteControl {
   122  				wc.WriteControl(PongMessage, []byte(message), time.Now().Add(time.Second))
   123  			} else {
   124  				w, err := wc.NextWriter(PongMessage)
   125  				if err != nil {
   126  					t.Errorf("%s: wc.NextWriter() returned %v", name, err)
   127  					continue
   128  				}
   129  				if _, err := w.Write([]byte(message)); err != nil {
   130  					t.Errorf("%s: w.Write() returned %v", name, err)
   131  					continue
   132  				}
   133  				if err := w.Close(); err != nil {
   134  					t.Errorf("%s: w.Close() returned %v", name, err)
   135  					continue
   136  				}
   137  				var actualMessage string
   138  				rc.SetPongHandler(func(s string) error { actualMessage = s; return nil })
   139  				rc.NextReader()
   140  				if actualMessage != message {
   141  					t.Errorf("%s: pong=%q, want %q", name, actualMessage, message)
   142  					continue
   143  				}
   144  			}
   145  		}
   146  	}
   147  }
   148  
   149  func TestCloseBeforeFinalFrame(t *testing.T) {
   150  	const bufSize = 512
   151  
   152  	expectedErr := &CloseError{Code: CloseNormalClosure, Text: "hello"}
   153  
   154  	var b1, b2 bytes.Buffer
   155  	wc := newConn(fakeNetConn{Reader: nil, Writer: &b1}, false, 1024, bufSize)
   156  	rc := newConn(fakeNetConn{Reader: &b1, Writer: &b2}, true, 1024, 1024)
   157  
   158  	w, _ := wc.NextWriter(BinaryMessage)
   159  	w.Write(make([]byte, bufSize+bufSize/2))
   160  	wc.WriteControl(CloseMessage, FormatCloseMessage(expectedErr.Code, expectedErr.Text), time.Now().Add(10*time.Second))
   161  	w.Close()
   162  
   163  	op, r, err := rc.NextReader()
   164  	if op != BinaryMessage || err != nil {
   165  		t.Fatalf("NextReader() returned %d, %v", op, err)
   166  	}
   167  	_, err = io.Copy(ioutil.Discard, r)
   168  	if !reflect.DeepEqual(err, expectedErr) {
   169  		t.Fatalf("io.Copy() returned %v, want %v", err, expectedErr)
   170  	}
   171  	_, _, err = rc.NextReader()
   172  	if !reflect.DeepEqual(err, expectedErr) {
   173  		t.Fatalf("NextReader() returned %v, want %v", err, expectedErr)
   174  	}
   175  }
   176  
   177  func TestEOFBeforeFinalFrame(t *testing.T) {
   178  	const bufSize = 512
   179  
   180  	var b1, b2 bytes.Buffer
   181  	wc := newConn(fakeNetConn{Reader: nil, Writer: &b1}, false, 1024, bufSize)
   182  	rc := newConn(fakeNetConn{Reader: &b1, Writer: &b2}, true, 1024, 1024)
   183  
   184  	w, _ := wc.NextWriter(BinaryMessage)
   185  	w.Write(make([]byte, bufSize+bufSize/2))
   186  
   187  	op, r, err := rc.NextReader()
   188  	if op != BinaryMessage || err != nil {
   189  		t.Fatalf("NextReader() returned %d, %v", op, err)
   190  	}
   191  	_, err = io.Copy(ioutil.Discard, r)
   192  	if err != errUnexpectedEOF {
   193  		t.Fatalf("io.Copy() returned %v, want %v", err, errUnexpectedEOF)
   194  	}
   195  	_, _, err = rc.NextReader()
   196  	if err != errUnexpectedEOF {
   197  		t.Fatalf("NextReader() returned %v, want %v", err, errUnexpectedEOF)
   198  	}
   199  }
   200  
   201  func TestReadLimit(t *testing.T) {
   202  
   203  	const readLimit = 512
   204  	message := make([]byte, readLimit+1)
   205  
   206  	var b1, b2 bytes.Buffer
   207  	wc := newConn(fakeNetConn{Reader: nil, Writer: &b1}, false, 1024, readLimit-2)
   208  	rc := newConn(fakeNetConn{Reader: &b1, Writer: &b2}, true, 1024, 1024)
   209  	rc.SetReadLimit(readLimit)
   210  
   211  	// Send message at the limit with interleaved pong.
   212  	w, _ := wc.NextWriter(BinaryMessage)
   213  	w.Write(message[:readLimit-1])
   214  	wc.WriteControl(PongMessage, []byte("this is a pong"), time.Now().Add(10*time.Second))
   215  	w.Write(message[:1])
   216  	w.Close()
   217  
   218  	// Send message larger than the limit.
   219  	wc.WriteMessage(BinaryMessage, message[:readLimit+1])
   220  
   221  	op, _, err := rc.NextReader()
   222  	if op != BinaryMessage || err != nil {
   223  		t.Fatalf("1: NextReader() returned %d, %v", op, err)
   224  	}
   225  	op, r, err := rc.NextReader()
   226  	if op != BinaryMessage || err != nil {
   227  		t.Fatalf("2: NextReader() returned %d, %v", op, err)
   228  	}
   229  	_, err = io.Copy(ioutil.Discard, r)
   230  	if err != ErrReadLimit {
   231  		t.Fatalf("io.Copy() returned %v", err)
   232  	}
   233  }
   234  
   235  func TestUnderlyingConn(t *testing.T) {
   236  	var b1, b2 bytes.Buffer
   237  	fc := fakeNetConn{Reader: &b1, Writer: &b2}
   238  	c := newConn(fc, true, 1024, 1024)
   239  	ul := c.UnderlyingConn()
   240  	if ul != fc {
   241  		t.Fatalf("Underlying conn is not what it should be.")
   242  	}
   243  }
   244  
   245  func TestBufioReadBytes(t *testing.T) {
   246  
   247  	// Test calling bufio.ReadBytes for value longer than read buffer size.
   248  
   249  	m := make([]byte, 512)
   250  	m[len(m)-1] = '\n'
   251  
   252  	var b1, b2 bytes.Buffer
   253  	wc := newConn(fakeNetConn{Reader: nil, Writer: &b1}, false, len(m)+64, len(m)+64)
   254  	rc := newConn(fakeNetConn{Reader: &b1, Writer: &b2}, true, len(m)-64, len(m)-64)
   255  
   256  	w, _ := wc.NextWriter(BinaryMessage)
   257  	w.Write(m)
   258  	w.Close()
   259  
   260  	op, r, err := rc.NextReader()
   261  	if op != BinaryMessage || err != nil {
   262  		t.Fatalf("NextReader() returned %d, %v", op, err)
   263  	}
   264  
   265  	br := bufio.NewReader(r)
   266  	p, err := br.ReadBytes('\n')
   267  	if err != nil {
   268  		t.Fatalf("ReadBytes() returned %v", err)
   269  	}
   270  	if len(p) != len(m) {
   271  		t.Fatalf("read returnd %d bytes, want %d bytes", len(p), len(m))
   272  	}
   273  }
   274  
   275  var closeErrorTests = []struct {
   276  	err   error
   277  	codes []int
   278  	ok    bool
   279  }{
   280  	{&CloseError{Code: CloseNormalClosure}, []int{CloseNormalClosure}, true},
   281  	{&CloseError{Code: CloseNormalClosure}, []int{CloseNoStatusReceived}, false},
   282  	{&CloseError{Code: CloseNormalClosure}, []int{CloseNoStatusReceived, CloseNormalClosure}, true},
   283  	{errors.New("hello"), []int{CloseNormalClosure}, false},
   284  }
   285  
   286  func TestCloseError(t *testing.T) {
   287  	for _, tt := range closeErrorTests {
   288  		ok := IsCloseError(tt.err, tt.codes...)
   289  		if ok != tt.ok {
   290  			t.Errorf("IsCloseError(%#v, %#v) returned %v, want %v", tt.err, tt.codes, ok, tt.ok)
   291  		}
   292  	}
   293  }
   294  
   295  var unexpectedCloseErrorTests = []struct {
   296  	err   error
   297  	codes []int
   298  	ok    bool
   299  }{
   300  	{&CloseError{Code: CloseNormalClosure}, []int{CloseNormalClosure}, false},
   301  	{&CloseError{Code: CloseNormalClosure}, []int{CloseNoStatusReceived}, true},
   302  	{&CloseError{Code: CloseNormalClosure}, []int{CloseNoStatusReceived, CloseNormalClosure}, false},
   303  	{errors.New("hello"), []int{CloseNormalClosure}, false},
   304  }
   305  
   306  func TestUnexpectedCloseErrors(t *testing.T) {
   307  	for _, tt := range unexpectedCloseErrorTests {
   308  		ok := IsUnexpectedCloseError(tt.err, tt.codes...)
   309  		if ok != tt.ok {
   310  			t.Errorf("IsUnexpectedCloseError(%#v, %#v) returned %v, want %v", tt.err, tt.codes, ok, tt.ok)
   311  		}
   312  	}
   313  }
   314  
   315  type blockingWriter struct {
   316  	c1, c2 chan struct{}
   317  }
   318  
   319  func (w blockingWriter) Write(p []byte) (int, error) {
   320  	// Allow main to continue
   321  	close(w.c1)
   322  	// Wait for panic in main
   323  	<-w.c2
   324  	return len(p), nil
   325  }
   326  
   327  func TestConcurrentWritePanic(t *testing.T) {
   328  	w := blockingWriter{make(chan struct{}), make(chan struct{})}
   329  	c := newConn(fakeNetConn{Reader: nil, Writer: w}, false, 1024, 1024)
   330  	go func() {
   331  		c.WriteMessage(TextMessage, []byte{})
   332  	}()
   333  
   334  	// wait for goroutine to block in write.
   335  	<-w.c1
   336  
   337  	defer func() {
   338  		close(w.c2)
   339  		if v := recover(); v != nil {
   340  			return
   341  		}
   342  	}()
   343  
   344  	c.WriteMessage(TextMessage, []byte{})
   345  	t.Fatal("should not get here")
   346  }
   347  
   348  type failingReader struct{}
   349  
   350  func (r failingReader) Read(p []byte) (int, error) {
   351  	return 0, io.EOF
   352  }
   353  
   354  func TestFailedConnectionReadPanic(t *testing.T) {
   355  	c := newConn(fakeNetConn{Reader: failingReader{}, Writer: nil}, false, 1024, 1024)
   356  
   357  	defer func() {
   358  		if v := recover(); v != nil {
   359  			return
   360  		}
   361  	}()
   362  
   363  	for i := 0; i < 20000; i++ {
   364  		c.ReadMessage()
   365  	}
   366  	t.Fatal("should not get here")
   367  }