// Source: github.com/philippseith/signalr@v0.6.3/client.go

     1  package signalr
     2  
     3  import (
     4  	"bytes"
     5  	"context"
     6  	"encoding/json"
     7  	"errors"
     8  	"fmt"
     9  	"os"
    10  	"reflect"
    11  	"sync"
    12  	"sync/atomic"
    13  	"time"
    14  
    15  	"github.com/cenkalti/backoff/v4"
    16  
    17  	"github.com/go-kit/log"
    18  )
    19  
    20  // ClientState is the state of the client.
    21  type ClientState int
    22  
    23  // Client states
    24  //
    25  //	ClientCreated
    26  //
    27  // The Client has been created and is not started yet.
    28  //
    29  //	ClientConnecting
    30  //
    31  // The Client has been started and is negotiating the connection.
    32  //
    33  //	ClientConnected
    34  //
    35  // The Client has successfully negotiated the connection and can send and receive messages.
    36  //
    37  //	ClientClosed
    38  //
    39  // The Client is not able to send and receive messages anymore and has to be started again to be able to.
    40  const (
    41  	ClientCreated ClientState = iota
    42  	ClientConnecting
    43  	ClientConnected
    44  	ClientClosed
    45  )
    46  
    47  // Client is the signalR connection used on the client side.
    48  //
    49  //	Start()
    50  //
    51  // Start starts the client loop. After starting the client, the interaction with a server can be started.
    52  // The client loop will run until the server closes the connection. If WithConnector is used, Start will
    53  // start a new loop. To end the loop from the client side, the context passed to NewClient has to be canceled
    54  // or the Stop function has to be called.
    55  //
    56  //	Stop()
    57  //
    58  // Stop stops the client loop. This is an alternative to using a cancelable context on NewClient.
    59  //
    60  //	State() ClientState
    61  //
    62  // State returns the current client state.
    63  // When WithConnector is set and the server allows reconnection, the client switches to ClientConnecting
    64  // and tries to reach ClientConnected after the last connection has ended.
    65  //
    66  //	ObserveStateChanged(chan ClientState) context.CancelFunc
    67  //
    68  // ObserveStateChanged pushes a new item != nil to the channel when State has changed.
    69  // The returned CancelFunc ends the observation and closes the channel.
    70  //
    71  //	Err() error
    72  //
    73  // Err returns the last error occurred while running the client.
    74  // When the client goes to ClientConnecting, Err is set to nil.
    75  //
    76  //	WaitForState(ctx context.Context, waitFor ClientState) <-chan error
    77  //
    78  // WaitForState returns a channel for waiting on the Client to reach a specific ClientState.
    79  // The channel either returns an error if ctx or the client has been canceled.
    80  // or nil if the ClientState waitFor was reached.
    81  //
    82  //	Invoke(method string, arguments ...interface{}) <-chan InvokeResult
    83  //
    84  // Invoke invokes a method on the server and returns a channel wich will return the InvokeResult.
    85  // When failing, InvokeResult.Error contains the client side error.
    86  //
    87  //	Send(method string, arguments ...interface{}) <-chan error
    88  //
    89  // Send invokes a method on the server but does not return a result from the server but only a channel,
    90  // which might contain a client side error occurred while sending.
    91  //
    92  //	PullStream(method string, arguments ...interface{}) <-chan InvokeResult
    93  //
    94  // PullStream invokes a streaming method on the server and returns a channel which delivers the stream items.
    95  // For more info about Streaming see https://github.com/dotnet/aspnetcore/blob/main/src/SignalR/docs/specs/HubProtocol.md#streaming
    96  //
    97  //	PushStreams(method string, arguments ...interface{}) <-chan error
    98  //
    99  // PushStreams pushes all items received from its arguments of type channel to the server (Upload Streaming).
   100  // PushStreams does not support server methods that return a channel.
   101  // For more info about Upload Streaming see https://github.com/dotnet/aspnetcore/blob/main/src/SignalR/docs/specs/HubProtocol.md#upload-streaming
   102  type Client interface {
   103  	Party
   104  	Start()
   105  	Stop()
   106  	State() ClientState
   107  	ObserveStateChanged(chan ClientState) context.CancelFunc
   108  	Err() error
   109  	WaitForState(ctx context.Context, waitFor ClientState) <-chan error
   110  	Invoke(method string, arguments ...interface{}) <-chan InvokeResult
   111  	Send(method string, arguments ...interface{}) <-chan error
   112  	PullStream(method string, arguments ...interface{}) <-chan InvokeResult
   113  	PushStreams(method string, arguments ...interface{}) <-chan InvokeResult
   114  }
   115  
   116  var ErrUnableToConnect = errors.New("neither WithConnection nor WithConnector option was given")
   117  
   118  // NewClient builds a new Client.
   119  // When ctx is canceled, the client loop and a possible auto reconnect loop are ended.
   120  func NewClient(ctx context.Context, options ...func(Party) error) (Client, error) {
   121  	var cancelFunc context.CancelFunc
   122  	ctx, cancelFunc = context.WithCancel(ctx)
   123  	info, dbg := buildInfoDebugLogger(log.NewLogfmtLogger(os.Stderr), true)
   124  	c := &client{
   125  		state:            ClientCreated,
   126  		stateChangeChans: make([]chan ClientState, 0),
   127  		format:           "json",
   128  		partyBase:        newPartyBase(ctx, info, dbg),
   129  		lastID:           -1,
   130  		backoffFactory:   func() backoff.BackOff { return backoff.NewExponentialBackOff() },
   131  		cancelFunc:       cancelFunc,
   132  	}
   133  	for _, option := range options {
   134  		if option != nil {
   135  			if err := option(c); err != nil {
   136  				return nil, err
   137  			}
   138  		}
   139  	}
   140  	// Wrap logging with timestamps
   141  	info, dbg = c.loggers()
   142  	c.setLoggers(
   143  		log.WithPrefix(info, "ts", log.DefaultTimestampUTC),
   144  		log.WithPrefix(dbg, "ts", log.DefaultTimestampUTC),
   145  	)
   146  	if c.conn == nil && c.connectionFactory == nil {
   147  		return nil, ErrUnableToConnect
   148  	}
   149  	return c, nil
   150  }
   151  
// client is the implementation of Client returned by NewClient.
// The embedded partyBase presumably supplies the shared party behavior (context, loggers,
// timeouts) used through c.ctx, c.info, c.dbg, c.context() and c.HandshakeTimeout() —
// confirm against partyBase.
type client struct {
	partyBase
	// mx is taken when writing conn, loop, state, stateChangeChans and err.
	mx                sync.RWMutex
	// conn is the connection of the current run; it is set to nil after each run.
	conn              Connection
	// connectionFactory creates a new conn when conn is nil (reconnecting requires it).
	connectionFactory func() (Connection, error)
	// state is the current ClientState.
	state             ClientState
	// stateChangeChans are the channels registered through ObserveStateChanged.
	stateChangeChans  []chan ClientState
	// err is the last error, see Err.
	err               error
	// format is the protocol name sent in the handshake ("json" or "messagepack").
	format            string
	// loop is the message loop of the most recent run.
	loop              *loop
	// receiver is returned by invocationTarget and typed into the "hub" log field.
	receiver          interface{}
	// lastID is initialized to -1 by NewClient; presumably the last invocation ID handed out — confirm.
	lastID            int64
	// backoffFactory creates the BackOff used by Start between reconnect attempts.
	backoffFactory    func() backoff.BackOff
	// cancelFunc cancels the client context; it is called by Stop.
	cancelFunc        context.CancelFunc
}
   167  
   168  func (c *client) Start() {
   169  	c.setState(ClientConnecting)
   170  	boff := c.backoffFactory()
   171  	go func() {
   172  		for {
   173  			c.setErr(nil)
   174  			// Listen for state change to ClientConnected and signal backoff Reset then.
   175  			stateChangeChan := make(chan ClientState, 1)
   176  			var connected atomic.Value
   177  			connected.Store(false)
   178  			cancelObserve := c.ObserveStateChanged(stateChangeChan)
   179  			go func() {
   180  				for range stateChangeChan {
   181  					if c.State() == ClientConnected {
   182  						connected.Store(true)
   183  						return
   184  					}
   185  				}
   186  			}()
   187  			// RUN!
   188  			err := c.run()
   189  			if err != nil {
   190  				_ = c.info.Log("connect", fmt.Sprintf("%v", err))
   191  				c.setErr(err)
   192  			}
   193  			shouldEnd := c.shouldClientEnd()
   194  			cancelObserve()
   195  			if shouldEnd {
   196  				return
   197  			}
   198  
   199  			// When the client has connected, BackOff should be reset
   200  			if connected.Load().(bool) {
   201  				boff.Reset()
   202  			}
   203  			// Reconnect after BackOff
   204  			nextBackoff := boff.NextBackOff()
   205  			// Check for exceeded backoff
   206  			if nextBackoff == backoff.Stop {
   207  				c.setErr(errors.New("backoff exceeded"))
   208  				return
   209  			}
   210  			select {
   211  			case <-time.After(nextBackoff):
   212  			case <-c.ctx.Done():
   213  				return
   214  			}
   215  			c.setState(ClientConnecting)
   216  		}
   217  	}()
   218  }
   219  
   220  func (c *client) Stop() {
   221  	if c.cancelFunc != nil {
   222  		c.cancelFunc()
   223  	}
   224  	c.setState(ClientClosed)
   225  }
   226  
   227  func (c *client) run() error {
   228  	// negotiate and so on
   229  	protocol, err := c.setupConnectionAndProtocol()
   230  	if err != nil {
   231  		return err
   232  	}
   233  
   234  	loop := newLoop(c, c.conn, protocol)
   235  	c.mx.Lock()
   236  	c.loop = loop
   237  	c.mx.Unlock()
   238  	// Broadcast when loop is connected
   239  	isLoopConnected := make(chan struct{}, 1)
   240  	go func() {
   241  		<-isLoopConnected
   242  		c.setState(ClientConnected)
   243  	}()
   244  	// Run the loop
   245  	err = loop.Run(isLoopConnected)
   246  
   247  	if err == nil {
   248  		err = loop.hubConn.Close("", false) // allowReconnect value is ignored as servers never initiate a connection
   249  	}
   250  
   251  	// Reset conn to allow reconnecting
   252  	c.mx.Lock()
   253  	c.conn = nil
   254  	c.mx.Unlock()
   255  
   256  	return err
   257  }
   258  
   259  func (c *client) shouldClientEnd() bool {
   260  	// Canceled?
   261  	if c.ctx.Err() != nil {
   262  		c.setErr(c.ctx.Err())
   263  		c.setState(ClientClosed)
   264  		return true
   265  	}
   266  	// Reconnecting not possible
   267  	if c.connectionFactory == nil {
   268  		c.setState(ClientClosed)
   269  		return true
   270  	}
   271  	// Reconnecting not allowed
   272  	if c.loop != nil && c.loop.closeMessage != nil && !c.loop.closeMessage.AllowReconnect {
   273  		c.setState(ClientClosed)
   274  		return true
   275  	}
   276  	return false
   277  }
   278  
   279  func (c *client) setupConnectionAndProtocol() (hubProtocol, error) {
   280  	return func() (hubProtocol, error) {
   281  		c.mx.Lock()
   282  		defer c.mx.Unlock()
   283  
   284  		if c.conn == nil {
   285  			if c.connectionFactory == nil {
   286  				return nil, ErrUnableToConnect
   287  			}
   288  			var err error
   289  			c.conn, err = c.connectionFactory()
   290  			if err != nil {
   291  				return nil, err
   292  			}
   293  		}
   294  		// Pass maximum receive message size to a potential websocket connection
   295  		if wsConn, ok := c.conn.(*webSocketConnection); ok {
   296  			wsConn.conn.SetReadLimit(int64(c.maximumReceiveMessageSize()))
   297  		}
   298  		protocol, err := c.processHandshake()
   299  		if err != nil {
   300  			return nil, err
   301  		}
   302  
   303  		return protocol, nil
   304  	}()
   305  }
   306  
   307  func (c *client) State() ClientState {
   308  	c.mx.RLock()
   309  	defer c.mx.RUnlock()
   310  	return c.state
   311  }
   312  
   313  func (c *client) setState(state ClientState) {
   314  	c.mx.Lock()
   315  	defer c.mx.Unlock()
   316  
   317  	c.state = state
   318  	_ = c.dbg.Log("state", state)
   319  
   320  	for _, ch := range c.stateChangeChans {
   321  		go func(ch chan ClientState, state ClientState) {
   322  			c.mx.Lock()
   323  			defer c.mx.Unlock()
   324  
   325  			for _, cch := range c.stateChangeChans {
   326  				if cch == ch {
   327  					select {
   328  					case ch <- state:
   329  					case <-c.ctx.Done():
   330  					}
   331  				}
   332  			}
   333  		}(ch, state)
   334  	}
   335  }
   336  
   337  func (c *client) ObserveStateChanged(ch chan ClientState) context.CancelFunc {
   338  	c.mx.Lock()
   339  	defer c.mx.Unlock()
   340  
   341  	c.stateChangeChans = append(c.stateChangeChans, ch)
   342  
   343  	return func() {
   344  		c.cancelObserveStateChanged(ch)
   345  	}
   346  }
   347  
   348  func (c *client) cancelObserveStateChanged(ch chan ClientState) {
   349  	c.mx.Lock()
   350  	defer c.mx.Unlock()
   351  	for i, cch := range c.stateChangeChans {
   352  		if cch == ch {
   353  			c.stateChangeChans = append(c.stateChangeChans[:i], c.stateChangeChans[i+1:]...)
   354  			close(ch)
   355  			break
   356  		}
   357  	}
   358  }
   359  
   360  func (c *client) Err() error {
   361  	c.mx.RLock()
   362  	defer c.mx.RUnlock()
   363  	return c.err
   364  }
   365  
   366  func (c *client) setErr(err error) {
   367  	c.mx.Lock()
   368  	defer c.mx.Unlock()
   369  	c.err = err
   370  }
   371  
   372  func (c *client) WaitForState(ctx context.Context, waitFor ClientState) <-chan error {
   373  	ch := make(chan error, 1)
   374  	if c.waitingIsOver(waitFor, ch) {
   375  		close(ch)
   376  		return ch
   377  	}
   378  	stateCh := make(chan ClientState, 1)
   379  	cancel := c.ObserveStateChanged(stateCh)
   380  	go func(waitFor ClientState) {
   381  		defer close(ch)
   382  		defer cancel()
   383  		if c.waitingIsOver(waitFor, ch) {
   384  			return
   385  		}
   386  		for {
   387  			select {
   388  			case <-stateCh:
   389  				if c.waitingIsOver(waitFor, ch) {
   390  					return
   391  				}
   392  			case <-ctx.Done():
   393  				ch <- ctx.Err()
   394  				return
   395  			case <-c.context().Done():
   396  				ch <- fmt.Errorf("client canceled: %w", c.context().Err())
   397  				return
   398  			}
   399  		}
   400  	}(waitFor)
   401  	return ch
   402  }
   403  
   404  func (c *client) waitingIsOver(waitFor ClientState, ch chan<- error) bool {
   405  	switch c.State() {
   406  	case waitFor:
   407  		return true
   408  	case ClientCreated:
   409  		ch <- errors.New("client not started. Call client.Start() before using it")
   410  		return true
   411  	case ClientClosed:
   412  		ch <- fmt.Errorf("client closed. %w", c.Err())
   413  		return true
   414  	}
   415  	return false
   416  }
   417  
   418  func (c *client) Invoke(method string, arguments ...interface{}) <-chan InvokeResult {
   419  	ch := make(chan InvokeResult, 1)
   420  	go func() {
   421  
   422  		if err := <-c.waitForConnected(); err != nil {
   423  			ch <- InvokeResult{Error: err}
   424  			close(ch)
   425  			return
   426  		}
   427  		id := c.loop.GetNewID()
   428  		resultCh, errCh := c.loop.invokeClient.newInvocation(id)
   429  		irCh := newInvokeResultChan(c.context(), resultCh, errCh)
   430  		if err := c.loop.hubConn.SendInvocation(id, method, arguments); err != nil {
   431  			c.loop.invokeClient.deleteInvocation(id)
   432  			ch <- InvokeResult{Error: err}
   433  			close(ch)
   434  			return
   435  		}
   436  		go func() {
   437  			for ir := range irCh {
   438  				ch <- ir
   439  			}
   440  			close(ch)
   441  		}()
   442  	}()
   443  	return ch
   444  }
   445  
   446  func (c *client) Send(method string, arguments ...interface{}) <-chan error {
   447  	errCh := make(chan error, 1)
   448  	go func() {
   449  		if err := <-c.waitForConnected(); err != nil {
   450  			errCh <- err
   451  			close(errCh)
   452  			return
   453  		}
   454  		id := c.loop.GetNewID()
   455  		_, sendErrCh := c.loop.invokeClient.newInvocation(id)
   456  		if err := c.loop.hubConn.SendInvocation(id, method, arguments); err != nil {
   457  			c.loop.invokeClient.deleteInvocation(id)
   458  			errCh <- err
   459  			close(errCh)
   460  			return
   461  		}
   462  		go func() {
   463  			for ir := range sendErrCh {
   464  				errCh <- ir
   465  			}
   466  			close(errCh)
   467  		}()
   468  	}()
   469  	return errCh
   470  }
   471  
   472  func (c *client) PullStream(method string, arguments ...interface{}) <-chan InvokeResult {
   473  	irCh := make(chan InvokeResult, 1)
   474  	go func() {
   475  		if err := <-c.waitForConnected(); err != nil {
   476  			irCh <- InvokeResult{Error: err}
   477  			close(irCh)
   478  			return
   479  		}
   480  		pullCh := c.loop.PullStream(method, c.loop.GetNewID(), arguments...)
   481  		go func() {
   482  			for ir := range pullCh {
   483  				irCh <- ir
   484  				if ir.Error != nil {
   485  					break
   486  				}
   487  			}
   488  			close(irCh)
   489  		}()
   490  	}()
   491  	return irCh
   492  }
   493  
   494  func (c *client) PushStreams(method string, arguments ...interface{}) <-chan InvokeResult {
   495  	irCh := make(chan InvokeResult, 1)
   496  	go func() {
   497  		if err := <-c.waitForConnected(); err != nil {
   498  			irCh <- InvokeResult{Error: err}
   499  			close(irCh)
   500  			return
   501  		}
   502  		pushCh, err := c.loop.PushStreams(method, c.loop.GetNewID(), arguments...)
   503  		if err != nil {
   504  			irCh <- InvokeResult{Error: err}
   505  			close(irCh)
   506  			return
   507  		}
   508  		go func() {
   509  			for ir := range pushCh {
   510  				irCh <- ir
   511  			}
   512  			close(irCh)
   513  		}()
   514  	}()
   515  	return irCh
   516  }
   517  
   518  func (c *client) waitForConnected() <-chan error {
   519  	return c.WaitForState(context.Background(), ClientConnected)
   520  }
   521  
   522  func createResultChansWithError(ctx context.Context, err error) (<-chan InvokeResult, chan error) {
   523  	resultCh := make(chan interface{}, 1)
   524  	errCh := make(chan error, 1)
   525  	errCh <- err
   526  	invokeResultChan := newInvokeResultChan(ctx, resultCh, errCh)
   527  	close(errCh)
   528  	close(resultCh)
   529  	return invokeResultChan, errCh
   530  }
   531  
   532  func (c *client) onConnected(hubConnection) {}
   533  
   534  func (c *client) onDisconnected(hubConnection) {}
   535  
   536  func (c *client) invocationTarget(hubConnection) interface{} {
   537  	return c.receiver
   538  }
   539  
   540  func (c *client) allowReconnect() bool {
   541  	return false // Servers don't care?
   542  }
   543  
   544  func (c *client) prefixLoggers(connectionID string) (info StructuredLogger, dbg StructuredLogger) {
   545  	if c.receiver == nil {
   546  		return log.WithPrefix(c.info, "ts", log.DefaultTimestampUTC, "class", "Client", "connection", connectionID),
   547  			log.WithPrefix(c.dbg, "ts", log.DefaultTimestampUTC, "class", "Client", "connection", connectionID)
   548  	}
   549  	var t reflect.Type = nil
   550  	switch reflect.ValueOf(c.receiver).Kind() {
   551  	case reflect.Ptr:
   552  		t = reflect.ValueOf(c.receiver).Elem().Type()
   553  	case reflect.Struct:
   554  		t = reflect.ValueOf(c.receiver).Type()
   555  	}
   556  	return log.WithPrefix(c.info, "ts", log.DefaultTimestampUTC,
   557  			"class", "Client",
   558  			"connection", connectionID,
   559  			"hub", t),
   560  		log.WithPrefix(c.dbg, "ts", log.DefaultTimestampUTC,
   561  			"class", "Client",
   562  			"connection", connectionID,
   563  			"hub", t)
   564  }
   565  
   566  func (c *client) processHandshake() (hubProtocol, error) {
   567  	if err := c.sendHandshakeRequest(); err != nil {
   568  		return nil, err
   569  	}
   570  	return c.receiveHandshakeResponse()
   571  }
   572  
   573  func (c *client) sendHandshakeRequest() error {
   574  	info, dbg := c.prefixLoggers(c.conn.ConnectionID())
   575  	request := fmt.Sprintf("{\"protocol\":\"%v\",\"version\":1}\u001e", c.format)
   576  	ctx, cancelWrite := context.WithTimeout(c.context(), c.HandshakeTimeout())
   577  	defer cancelWrite()
   578  	_, err := ReadWriteWithContext(ctx,
   579  		func() (int, error) {
   580  			return c.conn.Write([]byte(request))
   581  		}, func() {})
   582  	if err != nil {
   583  		_ = info.Log(evt, "handshake sent", "msg", request, "error", err)
   584  		return err
   585  	}
   586  	_ = dbg.Log(evt, "handshake sent", "msg", request)
   587  	return nil
   588  }
   589  
   590  func (c *client) receiveHandshakeResponse() (hubProtocol, error) {
   591  	info, dbg := c.prefixLoggers(c.conn.ConnectionID())
   592  	ctx, cancelRead := context.WithTimeout(c.context(), c.HandshakeTimeout())
   593  	defer cancelRead()
   594  	readJSONFramesChan := make(chan []interface{}, 1)
   595  	go func() {
   596  		var remainBuf bytes.Buffer
   597  		rawHandshake, err := readJSONFrames(c.conn, &remainBuf)
   598  		readJSONFramesChan <- []interface{}{rawHandshake, err}
   599  	}()
   600  	select {
   601  	case result := <-readJSONFramesChan:
   602  		if result[1] != nil {
   603  			return nil, result[1].(error)
   604  		}
   605  		rawHandshake := result[0].([][]byte)
   606  		response := handshakeResponse{}
   607  		if err := json.Unmarshal(rawHandshake[0], &response); err != nil {
   608  			// Malformed handshake
   609  			_ = info.Log(evt, "handshake received", "msg", string(rawHandshake[0]), "error", err)
   610  			return nil, err
   611  		} else {
   612  			if response.Error != "" {
   613  				_ = info.Log(evt, "handshake received", "error", response.Error)
   614  				return nil, errors.New(response.Error)
   615  			}
   616  			_ = dbg.Log(evt, "handshake received", "msg", fmtMsg(response))
   617  			var protocol hubProtocol
   618  			switch c.format {
   619  			case "json":
   620  				protocol = &jsonHubProtocol{}
   621  			case "messagepack":
   622  				protocol = &messagePackHubProtocol{}
   623  			}
   624  			if protocol != nil {
   625  				_, pDbg := c.loggers()
   626  				protocol.setDebugLogger(pDbg)
   627  			}
   628  			return protocol, nil
   629  		}
   630  	case <-ctx.Done():
   631  		return nil, ctx.Err()
   632  	}
   633  }