github.com/cnotch/ipchub@v1.1.0/network/websocket/websocket_test.go (about)

     1  package websocket
     2  
     3  import (
     4  	"bytes"
     5  	"io"
     6  	"net"
     7  	"net/http/httptest"
     8  	"testing"
     9  	"time"
    10  
    11  	"github.com/gorilla/websocket"
    12  	"github.com/stretchr/testify/assert"
    13  )
    14  
    15  type writer bytes.Buffer
    16  
    17  func (w *writer) Close() error                         { return nil }
    18  func (w *writer) Write(data []byte) (n int, err error) { return ((*bytes.Buffer)(w)).Write(data) }
    19  
    20  type conn struct {
    21  	read  []byte
    22  	write *writer
    23  }
    24  
    25  func (c *conn) NextReader() (messageType int, r io.Reader, err error) {
    26  	messageType = websocket.BinaryMessage
    27  	r = bytes.NewBuffer(c.read)
    28  	if c.read == nil {
    29  		err = io.EOF
    30  	}
    31  	return
    32  }
    33  
    34  func (c *conn) NextWriter(messageType int) (w io.WriteCloser, err error) {
    35  	w = c.write
    36  	if c.write == nil {
    37  		err = io.EOF
    38  	}
    39  
    40  	return
    41  }
    42  func (c *conn) Close() error                       { return nil }
    43  func (c *conn) LocalAddr() net.Addr                { return &net.IPAddr{} }
    44  func (c *conn) RemoteAddr() net.Addr               { return &net.IPAddr{} }
    45  func (c *conn) SetReadDeadline(t time.Time) error  { return nil }
    46  func (c *conn) SetWriteDeadline(t time.Time) error { return nil }
    47  func (c *conn) Subprotocol() string                { return "" }
    48  func TestTryUpgradeNil(t *testing.T) {
    49  	_, ok := TryUpgrade(nil, nil, "", "")
    50  	assert.Equal(t, false, ok)
    51  }
    52  
    53  func TestTryUpgrade(t *testing.T) {
    54  	//httptest.NewServer(handler)
    55  	r := httptest.NewRequest("GET", "http://127.0.0.1/", bytes.NewBuffer([]byte{}))
    56  	r.Header.Set("Connection", "upgrade")
    57  	r.Header.Set("Upgrade", "websocket")
    58  	r.Header.Set("Sec-WebSocket-Extensions", "permessage-deflate; client_max_window_bits")
    59  	r.Header.Set("Sec-WebSocket-Key", "D1icfJz+khA9kj5/14dRXQ==")
    60  	r.Header.Set("Sec-WebSocket-Protocol", "mqttv3.1")
    61  	r.Header.Set("Sec-WebSocket-Version", "13")
    62  
    63  	w := httptest.NewRecorder()
    64  
    65  	assert.NotPanics(t, func() {
    66  		TryUpgrade(w, r, "", "")
    67  	})
    68  
    69  	// TODO: need to have a hijackable response writer to test properly
    70  	//ws, ok := TryUpgrade(w, r)
    71  	//assert.NotNil(t, ws)
    72  	//assert.True(t, ok)
    73  }
    74  
    75  func TestRead_EOF(t *testing.T) {
    76  	c := newConn(new(conn), "", "")
    77  
    78  	_, err := c.Read([]byte{})
    79  	assert.Error(t, io.EOF, err)
    80  }
    81  
    82  func TestRead(t *testing.T) {
    83  	message := []byte("hello world")
    84  	c := &websocketTransport{
    85  		socket: &conn{
    86  			read: message,
    87  		},
    88  		closing: make(chan bool),
    89  	}
    90  
    91  	buffer := make([]byte, 64)
    92  	n, err := c.Read(buffer)
    93  	assert.NoError(t, err)
    94  	assert.Equal(t, message, buffer[:n])
    95  }
    96  
    97  func TestWrite(t *testing.T) {
    98  	message := []byte("hello world")
    99  	buffer := new(bytes.Buffer)
   100  	c := &websocketTransport{
   101  		socket: &conn{
   102  			write: (*writer)(buffer),
   103  		},
   104  		closing: make(chan bool),
   105  	}
   106  
   107  	_, err := c.Write(message)
   108  	assert.NoError(t, err)
   109  	assert.Equal(t, message, buffer.Bytes())
   110  }
   111  
   112  func TestMisc(t *testing.T) {
   113  	c := &websocketTransport{
   114  		socket:  &conn{},
   115  		closing: make(chan bool),
   116  	}
   117  
   118  	err := c.Close()
   119  	assert.NoError(t, err)
   120  
   121  	err = c.SetDeadline(time.Now())
   122  	assert.NoError(t, err)
   123  
   124  	err = c.SetReadDeadline(time.Now())
   125  	assert.NoError(t, err)
   126  
   127  	err = c.SetWriteDeadline(time.Now())
   128  	assert.NoError(t, err)
   129  
   130  	addr1 := c.LocalAddr()
   131  	assert.Equal(t, "", addr1.String())
   132  
   133  	addr2 := c.RemoteAddr()
   134  	assert.Equal(t, "", addr2.String())
   135  }