github.com/philippseith/signalr@v0.6.3/testingconnection_test.go (about)

     1  package signalr
     2  
     3  import (
     4  	"bytes"
     5  	"context"
     6  	"encoding/json"
     7  	"errors"
     8  	"fmt"
     9  	"io"
    10  	"sync"
    11  	"time"
    12  
    13  	"github.com/onsi/ginkgo"
    14  )
    15  
    16  type testingConnection struct {
    17  	timeout      time.Duration
    18  	connectionID string
    19  	srvWriter    io.Writer
    20  	srvReader    io.Reader
    21  	cliWriter    io.Writer
    22  	cliReader    io.Reader
    23  	received     chan interface{}
    24  	cnMutex      sync.Mutex
    25  	connected    bool
    26  	cliSendChan  chan string
    27  	srvSendChan  chan []byte
    28  	failRead     string
    29  	failWrite    string
    30  	failMx       sync.Mutex
    31  }
    32  
    33  func (t *testingConnection) Context() context.Context {
    34  	return context.TODO()
    35  }
    36  
    37  var connNum = 0
    38  var connNumMx sync.Mutex
    39  
    40  func (t *testingConnection) SetTimeout(timeout time.Duration) {
    41  	t.timeout = timeout
    42  }
    43  
    44  func (t *testingConnection) Timeout() time.Duration {
    45  	return t.timeout
    46  }
    47  
    48  func (t *testingConnection) ConnectionID() string {
    49  	connNumMx.Lock()
    50  	defer connNumMx.Unlock()
    51  	if t.connectionID == "" {
    52  		connNum++
    53  		t.connectionID = fmt.Sprintf("test%v", connNum)
    54  	}
    55  	return t.connectionID
    56  }
    57  
    58  func (t *testingConnection) SetConnectionID(id string) {
    59  	t.connectionID = id
    60  }
    61  
    62  func (t *testingConnection) Read(b []byte) (n int, err error) {
    63  	if fr := t.FailRead(); fr != "" {
    64  		defer func() { t.SetFailRead("") }()
    65  		return 0, errors.New(fr)
    66  	}
    67  	timer := make(<-chan time.Time)
    68  	if t.Timeout() > 0 {
    69  		timer = time.After(t.Timeout())
    70  	}
    71  	nch := make(chan int)
    72  	go func() {
    73  		n, _ := t.srvReader.Read(b)
    74  		nch <- n
    75  	}()
    76  	select {
    77  	case n := <-nch:
    78  		return n, nil
    79  	case <-timer:
    80  		return 0, fmt.Errorf("timeout %v", t.Timeout())
    81  	}
    82  }
    83  
    84  func (t *testingConnection) Write(b []byte) (n int, err error) {
    85  	if fw := t.FailWrite(); fw != "" {
    86  		defer func() { t.SetFailWrite("") }()
    87  		return 0, errors.New(fw)
    88  	}
    89  	t.srvSendChan <- b
    90  	return len(b), nil
    91  }
    92  
    93  func (t *testingConnection) Connected() bool {
    94  	t.cnMutex.Lock()
    95  	defer t.cnMutex.Unlock()
    96  	return t.connected
    97  }
    98  
    99  func (t *testingConnection) SetConnected(connected bool) {
   100  	t.cnMutex.Lock()
   101  	defer t.cnMutex.Unlock()
   102  	t.connected = connected
   103  }
   104  
   105  func (t *testingConnection) FailRead() string {
   106  	defer t.failMx.Unlock()
   107  	t.failMx.Lock()
   108  	return t.failRead
   109  }
   110  
   111  func (t *testingConnection) FailWrite() string {
   112  	defer t.failMx.Unlock()
   113  	t.failMx.Lock()
   114  	return t.failWrite
   115  }
   116  
   117  func (t *testingConnection) SetFailRead(fail string) {
   118  	defer t.failMx.Unlock()
   119  	t.failMx.Lock()
   120  	t.failRead = fail
   121  }
   122  
   123  func (t *testingConnection) SetFailWrite(fail string) {
   124  	defer t.failMx.Unlock()
   125  	t.failMx.Lock()
   126  	t.failWrite = fail
   127  }
   128  
   129  // newTestingConnectionForServer builds a testingConnection with an sent (but not yet received) handshake for testing a server
   130  func newTestingConnectionForServer() *testingConnection {
   131  	conn := newTestingConnection()
   132  	// client receive loop
   133  	go receiveLoop(conn)()
   134  	// Send initial Handshake
   135  	conn.ClientSend(`{"protocol": "json","version": 1}`)
   136  	conn.SetConnected(true)
   137  	return conn
   138  }
   139  
   140  func newTestingConnection() *testingConnection {
   141  	cliReader, srvWriter := io.Pipe()
   142  	srvReader, cliWriter := io.Pipe()
   143  	conn := testingConnection{
   144  		srvWriter:   srvWriter,
   145  		srvReader:   srvReader,
   146  		cliWriter:   cliWriter,
   147  		cliReader:   cliReader,
   148  		received:    make(chan interface{}, 20),
   149  		cliSendChan: make(chan string, 20),
   150  		srvSendChan: make(chan []byte, 20),
   151  		timeout:     time.Second * 5,
   152  	}
   153  	// client send loop
   154  	go func() {
   155  		for {
   156  			_, _ = conn.cliWriter.Write(append([]byte(<-conn.cliSendChan), 30))
   157  		}
   158  	}()
   159  	// server send loop
   160  	go func() {
   161  		for {
   162  			_, _ = conn.srvWriter.Write(<-conn.srvSendChan)
   163  		}
   164  	}()
   165  	return &conn
   166  }
   167  
   168  func (t *testingConnection) ClientSend(message string) {
   169  	t.cliSendChan <- message
   170  }
   171  
   172  func (t *testingConnection) ClientReceive() (string, error) {
   173  	var buf bytes.Buffer
   174  	var data = make([]byte, 1<<15) // 32K
   175  	var nn int
   176  	for {
   177  		if message, err := buf.ReadString(30); err != nil {
   178  			buf.Write(data[:nn])
   179  			if n, err := t.cliReader.Read(data[nn:]); err == nil {
   180  				buf.Write(data[nn : nn+n])
   181  				nn = nn + n
   182  			} else {
   183  				return "", err
   184  			}
   185  		} else {
   186  			return message[:len(message)-1], nil
   187  		}
   188  	}
   189  }
   190  
   191  func (t *testingConnection) ReceiveChan() chan interface{} {
   192  	return t.received
   193  }
   194  
   195  type clientReceiver interface {
   196  	ClientReceive() (string, error)
   197  	ReceiveChan() chan interface{}
   198  	SetConnected(bool)
   199  }
   200  
   201  func receiveLoop(conn clientReceiver) func() {
   202  	return func() {
   203  		defer ginkgo.GinkgoRecover()
   204  		errorHandler := func(err error) { ginkgo.Fail(fmt.Sprintf("received invalid message from server %v", err.Error())) }
   205  		for {
   206  			if message, err := conn.ClientReceive(); err == nil {
   207  				var hubMessage hubMessage
   208  				if err = json.Unmarshal([]byte(message), &hubMessage); err == nil {
   209  					switch hubMessage.Type {
   210  					case 1, 4:
   211  						var invocationMessage invocationMessage
   212  						if err = json.Unmarshal([]byte(message), &invocationMessage); err == nil {
   213  							conn.ReceiveChan() <- invocationMessage
   214  						} else {
   215  							errorHandler(err)
   216  						}
   217  					case 2:
   218  						var jsonStreamItemMessage jsonStreamItemMessage
   219  						if err = json.Unmarshal([]byte(message), &jsonStreamItemMessage); err == nil {
   220  
   221  							conn.ReceiveChan() <- streamItemMessage{
   222  								Type:         jsonStreamItemMessage.Type,
   223  								InvocationID: jsonStreamItemMessage.InvocationID,
   224  								Item:         jsonStreamItemMessage.Item,
   225  							}
   226  						} else {
   227  							errorHandler(err)
   228  						}
   229  					case 3:
   230  						var completionMessage completionMessage
   231  						if err = json.Unmarshal([]byte(message), &completionMessage); err == nil {
   232  							conn.ReceiveChan() <- completionMessage
   233  						} else {
   234  							errorHandler(err)
   235  						}
   236  					case 7:
   237  						var closeMessage closeMessage
   238  						if err = json.Unmarshal([]byte(message), &closeMessage); err == nil {
   239  							conn.SetConnected(false)
   240  							conn.ReceiveChan() <- closeMessage
   241  						} else {
   242  							errorHandler(err)
   243  						}
   244  					}
   245  				} else {
   246  					errorHandler(err)
   247  				}
   248  			}
   249  		}
   250  	}
   251  }