github.com/philippseith/signalr@v0.6.3/loop.go (about)

     1  package signalr
     2  
     3  import (
     4  	"encoding/json"
     5  	"errors"
     6  	"fmt"
     7  	"reflect"
     8  	"runtime/debug"
     9  	"strings"
    10  	"sync/atomic"
    11  	"time"
    12  )
    13  
    14  type loop struct {
    15  	lastID       uint64 // Used with atomic: Must be first in struct to ensure 64bit alignment on 32bit architectures
    16  	party        Party
    17  	info         StructuredLogger
    18  	dbg          StructuredLogger
    19  	protocol     hubProtocol
    20  	hubConn      hubConnection
    21  	invokeClient *invokeClient
    22  	streamer     *streamer
    23  	streamClient *streamClient
    24  	closeMessage *closeMessage
    25  }
    26  
    27  func newLoop(p Party, conn Connection, protocol hubProtocol) *loop {
    28  	protocol = reflect.New(reflect.ValueOf(protocol).Elem().Type()).Interface().(hubProtocol)
    29  	_, dbg := p.loggers()
    30  	protocol.setDebugLogger(dbg)
    31  	pInfo, pDbg := p.prefixLoggers(conn.ConnectionID())
    32  	hubConn := newHubConnection(conn, protocol, p.maximumReceiveMessageSize(), pInfo)
    33  	return &loop{
    34  		party:        p,
    35  		protocol:     protocol,
    36  		hubConn:      hubConn,
    37  		invokeClient: newInvokeClient(protocol, p.chanReceiveTimeout()),
    38  		streamer:     &streamer{conn: hubConn},
    39  		streamClient: newStreamClient(protocol, p.chanReceiveTimeout(), p.streamBufferCapacity()),
    40  		info:         pInfo,
    41  		dbg:          pDbg,
    42  	}
    43  }
    44  
    45  // Run runs the loop. After the startup sequence is done, this is signaled over the started channel.
    46  // Callers should pass a channel with buffer size 1 to allow the loop to run without waiting for the caller.
    47  func (l *loop) Run(connected chan struct{}) (err error) {
    48  	l.party.onConnected(l.hubConn)
    49  	connected <- struct{}{}
    50  	close(connected)
    51  	// Process messages
    52  	ch := make(chan receiveResult, 1)
    53  	go func() {
    54  		recv := l.hubConn.Receive()
    55  	loop:
    56  		for {
    57  			select {
    58  			case result, ok := <-recv:
    59  				if !ok {
    60  					break loop
    61  				}
    62  				select {
    63  				case ch <- result:
    64  				case <-l.hubConn.Context().Done():
    65  					break loop
    66  				}
    67  			case <-l.hubConn.Context().Done():
    68  				break loop
    69  			}
    70  		}
    71  	}()
    72  	timeoutTicker := time.NewTicker(l.party.timeout())
    73  msgLoop:
    74  	for {
    75  	pingLoop:
    76  		for {
    77  			select {
    78  			case evt := <-ch:
    79  				err = evt.err
    80  				timeoutTicker.Reset(l.party.timeout())
    81  				if err == nil {
    82  					switch message := evt.message.(type) {
    83  					case invocationMessage:
    84  						l.handleInvocationMessage(message)
    85  					case cancelInvocationMessage:
    86  						_ = l.dbg.Log(evt, msgRecv, msg, fmtMsg(message))
    87  						l.streamer.Stop(message.InvocationID)
    88  					case streamItemMessage:
    89  						err = l.handleStreamItemMessage(message)
    90  					case completionMessage:
    91  						err = l.handleCompletionMessage(message)
    92  					case closeMessage:
    93  						_ = l.dbg.Log(evt, msgRecv, msg, fmtMsg(message))
    94  						l.closeMessage = &message
    95  						if message.Error != "" {
    96  							err = errors.New(message.Error)
    97  						}
    98  					case hubMessage:
    99  						// Mostly ping
   100  						err = l.handleOtherMessage(message)
   101  						// No default case necessary, because the protocol would return either a hubMessage or an error
   102  					}
   103  				} else {
   104  					_ = l.info.Log(evt, msgRecv, "error", err, msg, fmtMsg(evt.message), react, "close connection")
   105  				}
   106  				break pingLoop
   107  			case <-time.After(l.party.keepAliveInterval()):
   108  				// Send ping only when there was no write in the keepAliveInterval before
   109  				if time.Since(l.hubConn.LastWriteStamp()) > l.party.keepAliveInterval() {
   110  					_ = l.hubConn.Ping()
   111  				}
   112  				// Don't break the pingLoop when keepAlive is over, it exists for this case
   113  			case <-timeoutTicker.C:
   114  				err = fmt.Errorf("timeout interval elapsed (%v)", l.party.timeout())
   115  				break pingLoop
   116  			case <-l.hubConn.Context().Done():
   117  				err = fmt.Errorf("breaking loop. hubConnection canceled: %w", l.hubConn.Context().Err())
   118  				break pingLoop
   119  			case <-l.party.context().Done():
   120  				err = fmt.Errorf("breaking loop. Party canceled: %w", l.party.context().Err())
   121  				break pingLoop
   122  			}
   123  		}
   124  		if err != nil || l.closeMessage != nil {
   125  			break msgLoop
   126  		}
   127  	}
   128  	l.party.onDisconnected(l.hubConn)
   129  	if err != nil {
   130  		_ = l.hubConn.Close(fmt.Sprintf("%v", err), l.party.allowReconnect())
   131  	}
   132  	_ = l.dbg.Log(evt, "message loop ended")
   133  	l.invokeClient.cancelAllInvokes()
   134  	l.hubConn.Abort()
   135  	return err
   136  }
   137  
   138  func (l *loop) PullStream(method, id string, arguments ...interface{}) <-chan InvokeResult {
   139  	_, errChan := l.invokeClient.newInvocation(id)
   140  	upChan := l.streamClient.newUpstreamChannel(id)
   141  	ch := newInvokeResultChan(l.party.context(), upChan, errChan)
   142  	if err := l.hubConn.SendStreamInvocation(id, method, arguments); err != nil {
   143  		// When we get an error here, the loop is closed and the errChan might be already closed
   144  		// We create a new one to deliver our error
   145  		ch, _ = createResultChansWithError(l.party.context(), err)
   146  		l.streamClient.deleteUpstreamChannel(id)
   147  		l.invokeClient.deleteInvocation(id)
   148  	}
   149  	return ch
   150  }
   151  
   152  func (l *loop) PushStreams(method, id string, arguments ...interface{}) (<-chan InvokeResult, error) {
   153  	resultCh, errCh := l.invokeClient.newInvocation(id)
   154  	irCh := newInvokeResultChan(l.party.context(), resultCh, errCh)
   155  	invokeArgs := make([]interface{}, 0)
   156  	reflectedChannels := make([]reflect.Value, 0)
   157  	streamIds := make([]string, 0)
   158  	// Parse arguments for channels and other kind of arguments
   159  	for _, arg := range arguments {
   160  		if reflect.TypeOf(arg).Kind() == reflect.Chan {
   161  			reflectedChannels = append(reflectedChannels, reflect.ValueOf(arg))
   162  			streamIds = append(streamIds, l.GetNewID())
   163  		} else {
   164  			invokeArgs = append(invokeArgs, arg)
   165  		}
   166  	}
   167  	// Tell the server we are streaming now
   168  	if err := l.hubConn.SendInvocationWithStreamIds(id, method, invokeArgs, streamIds); err != nil {
   169  		l.invokeClient.deleteInvocation(id)
   170  		return nil, err
   171  	}
   172  	// Start streaming on all channels
   173  	for i, reflectedChannel := range reflectedChannels {
   174  		l.streamer.Start(streamIds[i], reflectedChannel)
   175  	}
   176  	return irCh, nil
   177  }
   178  
   179  // GetNewID returns a new, connection-unique id for invocations and streams
   180  func (l *loop) GetNewID() string {
   181  	atomic.AddUint64(&l.lastID, 1)
   182  	return fmt.Sprint(atomic.LoadUint64(&l.lastID))
   183  }
   184  
   185  func (l *loop) handleInvocationMessage(invocation invocationMessage) {
   186  	_ = l.dbg.Log(evt, msgRecv, msg, fmtMsg(invocation))
   187  	// Transient hub, dispatch invocation here
   188  	if method, ok := getMethod(l.party.invocationTarget(l.hubConn), invocation.Target); !ok {
   189  		// Unable to find the method
   190  		_ = l.info.Log(evt, "getMethod", "error", "missing method", "name", invocation.Target, react, "send completion with error")
   191  		_ = l.hubConn.Completion(invocation.InvocationID, nil, fmt.Sprintf("Unknown method %s", invocation.Target))
   192  	} else if in, err := buildMethodArguments(method, invocation, l.streamClient, l.protocol); err != nil {
   193  		// argument build failed
   194  		_ = l.info.Log(evt, "buildMethodArguments", "error", err, "name", invocation.Target, react, "send completion with error")
   195  		_ = l.hubConn.Completion(invocation.InvocationID, nil, err.Error())
   196  	} else {
   197  		// Stream invocation is only allowed when the method has only one return value
   198  		// We allow no channel return values, because a client can receive as stream with only one item
   199  		if invocation.Type == 4 && method.Type().NumOut() != 1 {
   200  			_ = l.hubConn.Completion(invocation.InvocationID, nil,
   201  				fmt.Sprintf("Stream invocation of method %s which has not return value kind channel", invocation.Target))
   202  		} else {
   203  			// hub method might take a long time
   204  			go func() {
   205  				result := func() []reflect.Value {
   206  					defer l.recoverInvocationPanic(invocation)
   207  					return method.Call(in)
   208  				}()
   209  				l.returnInvocationResult(invocation, result)
   210  			}()
   211  		}
   212  	}
   213  }
   214  
   215  func (l *loop) returnInvocationResult(invocation invocationMessage, result []reflect.Value) {
   216  	// No invocation id, no completion
   217  	if invocation.InvocationID != "" {
   218  		// if the hub method returns a chan, it should be considered asynchronous or source for a stream
   219  		if len(result) == 1 && result[0].Kind() == reflect.Chan {
   220  			switch invocation.Type {
   221  			// Simple invocation
   222  			case 1:
   223  				go func() {
   224  					// Recv might block, so run continue in a goroutine
   225  					if chanResult, ok := result[0].Recv(); ok {
   226  						l.sendResult(invocation, completion, []reflect.Value{chanResult})
   227  					} else {
   228  
   229  						_ = l.hubConn.Completion(invocation.InvocationID, nil, "hub func returned closed chan")
   230  					}
   231  				}()
   232  			// StreamInvocation
   233  			case 4:
   234  				l.streamer.Start(invocation.InvocationID, result[0])
   235  			}
   236  		} else {
   237  			switch invocation.Type {
   238  			// Simple invocation
   239  			case 1:
   240  				l.sendResult(invocation, completion, result)
   241  			case 4:
   242  				// Stream invocation of method with no stream result.
   243  				// Return a single StreamItem and an empty Completion
   244  				l.sendResult(invocation, streamItem, result)
   245  				_ = l.hubConn.Completion(invocation.InvocationID, nil, "")
   246  			}
   247  		}
   248  	}
   249  }
   250  
   251  func (l *loop) handleStreamItemMessage(streamItemMessage streamItemMessage) error {
   252  	_ = l.dbg.Log(evt, msgRecv, msg, fmtMsg(streamItemMessage))
   253  	if err := l.streamClient.receiveStreamItem(streamItemMessage); err != nil {
   254  		switch t := err.(type) {
   255  		case *hubChanTimeoutError:
   256  			_ = l.hubConn.Completion(streamItemMessage.InvocationID, nil, t.Error())
   257  		default:
   258  			_ = l.info.Log(evt, msgRecv, "error", err, msg, fmtMsg(streamItemMessage), react, "close connection")
   259  			return err
   260  		}
   261  	}
   262  	return nil
   263  }
   264  
   265  func (l *loop) handleCompletionMessage(message completionMessage) error {
   266  	_ = l.dbg.Log(evt, msgRecv, msg, fmtMsg(message))
   267  	var err error
   268  	if l.streamClient.handlesInvocationID(message.InvocationID) {
   269  		err = l.streamClient.receiveCompletionItem(message, l.invokeClient)
   270  	} else if l.invokeClient.handlesInvocationID(message.InvocationID) {
   271  		err = l.invokeClient.receiveCompletionItem(message)
   272  	} else {
   273  		err = fmt.Errorf("unknown invocationID %v", message.InvocationID)
   274  	}
   275  	if err != nil {
   276  		_ = l.info.Log(evt, msgRecv, "error", err, msg, fmtMsg(message), react, "close connection")
   277  	}
   278  	return err
   279  }
   280  
   281  func (l *loop) handleOtherMessage(hubMessage hubMessage) error {
   282  	_ = l.dbg.Log(evt, msgRecv, msg, fmtMsg(hubMessage))
   283  	// Not Ping
   284  	if hubMessage.Type != 6 {
   285  		err := fmt.Errorf("invalid message type %v", hubMessage)
   286  		_ = l.info.Log(evt, msgRecv, "error", err, msg, fmtMsg(hubMessage), react, "close connection")
   287  		return err
   288  	}
   289  	return nil
   290  }
   291  
   292  func (l *loop) sendResult(invocation invocationMessage, connFunc connFunc, result []reflect.Value) {
   293  	values := make([]interface{}, len(result))
   294  	for i, rv := range result {
   295  		values[i] = rv.Interface()
   296  	}
   297  	switch len(result) {
   298  	case 0:
   299  		_ = l.hubConn.Completion(invocation.InvocationID, nil, "")
   300  	case 1:
   301  		connFunc(l, invocation, values[0])
   302  	default:
   303  		connFunc(l, invocation, values)
   304  	}
   305  }
   306  
   307  type connFunc func(sl *loop, invocation invocationMessage, value interface{})
   308  
   309  func completion(sl *loop, invocation invocationMessage, value interface{}) {
   310  	_ = sl.hubConn.Completion(invocation.InvocationID, value, "")
   311  }
   312  
   313  func streamItem(sl *loop, invocation invocationMessage, value interface{}) {
   314  
   315  	_ = sl.hubConn.StreamItem(invocation.InvocationID, value)
   316  }
   317  
   318  func (l *loop) recoverInvocationPanic(invocation invocationMessage) {
   319  	if err := recover(); err != nil {
   320  		_ = l.info.Log(evt, "panic in target method", "error", err, "name", invocation.Target, react, "send completion with error")
   321  		stack := string(debug.Stack())
   322  		_ = l.dbg.Log(evt, "panic in target method", "error", err, "name", invocation.Target, react, "send completion with error", "stack", stack)
   323  		if invocation.InvocationID != "" {
   324  			if !l.party.enableDetailedErrors() {
   325  				stack = ""
   326  			}
   327  			_ = l.hubConn.Completion(invocation.InvocationID, nil, fmt.Sprintf("%v\n%v", err, stack))
   328  		}
   329  	}
   330  }
   331  
   332  func buildMethodArguments(method reflect.Value, invocation invocationMessage,
   333  	streamClient *streamClient, protocol hubProtocol) (arguments []reflect.Value, err error) {
   334  	if len(invocation.StreamIds)+len(invocation.Arguments) != method.Type().NumIn() {
   335  		return nil, fmt.Errorf("parameter mismatch calling method %v", invocation.Target)
   336  	}
   337  	arguments = make([]reflect.Value, method.Type().NumIn())
   338  	chanCount := 0
   339  	for i := 0; i < method.Type().NumIn(); i++ {
   340  		t := method.Type().In(i)
   341  		// Is it a channel for client streaming?
   342  		if arg, clientStreaming, err := streamClient.buildChannelArgument(invocation, t, chanCount); err != nil {
   343  			// it is, but channel count in invocation and method mismatch
   344  			return nil, err
   345  		} else if clientStreaming {
   346  			// it is
   347  			chanCount++
   348  			arguments[i] = arg
   349  		} else {
   350  			// it is not, so do the normal thing
   351  			arg := reflect.New(t)
   352  			if err := protocol.UnmarshalArgument(invocation.Arguments[i-chanCount], arg.Interface()); err != nil {
   353  				return arguments, err
   354  			}
   355  			arguments[i] = arg.Elem()
   356  		}
   357  	}
   358  	if len(invocation.StreamIds) != chanCount {
   359  		return arguments, fmt.Errorf("to many StreamIds for channel parameters of method %v", invocation.Target)
   360  	}
   361  	return arguments, nil
   362  }
   363  
   364  func getMethod(target interface{}, name string) (reflect.Value, bool) {
   365  	hubType := reflect.TypeOf(target)
   366  	if hubType != nil {
   367  		hubValue := reflect.ValueOf(target)
   368  		name = strings.ToLower(name)
   369  		for i := 0; i < hubType.NumMethod(); i++ {
   370  			// Search in public methods
   371  			if m := hubType.Method(i); strings.ToLower(m.Name) == name {
   372  				return hubValue.Method(i), true
   373  			}
   374  		}
   375  	}
   376  	return reflect.Value{}, false
   377  }
   378  
   379  func fmtMsg(message interface{}) string {
   380  	switch msg := message.(type) {
   381  	case invocationMessage:
   382  		fmtArgs := make([]interface{}, 0)
   383  		for _, arg := range msg.Arguments {
   384  			if rawArg, ok := arg.(json.RawMessage); ok {
   385  				fmtArgs = append(fmtArgs, string(rawArg))
   386  			} else {
   387  				fmtArgs = append(fmtArgs, arg)
   388  			}
   389  		}
   390  		msg.Arguments = fmtArgs
   391  		message = msg
   392  	}
   393  	return fmt.Sprintf("%#v", message)
   394  }