github.com/philippseith/signalr@v0.6.3/loop.go (about) 1 package signalr 2 3 import ( 4 "encoding/json" 5 "errors" 6 "fmt" 7 "reflect" 8 "runtime/debug" 9 "strings" 10 "sync/atomic" 11 "time" 12 ) 13 14 type loop struct { 15 lastID uint64 // Used with atomic: Must be first in struct to ensure 64bit alignment on 32bit architectures 16 party Party 17 info StructuredLogger 18 dbg StructuredLogger 19 protocol hubProtocol 20 hubConn hubConnection 21 invokeClient *invokeClient 22 streamer *streamer 23 streamClient *streamClient 24 closeMessage *closeMessage 25 } 26 27 func newLoop(p Party, conn Connection, protocol hubProtocol) *loop { 28 protocol = reflect.New(reflect.ValueOf(protocol).Elem().Type()).Interface().(hubProtocol) 29 _, dbg := p.loggers() 30 protocol.setDebugLogger(dbg) 31 pInfo, pDbg := p.prefixLoggers(conn.ConnectionID()) 32 hubConn := newHubConnection(conn, protocol, p.maximumReceiveMessageSize(), pInfo) 33 return &loop{ 34 party: p, 35 protocol: protocol, 36 hubConn: hubConn, 37 invokeClient: newInvokeClient(protocol, p.chanReceiveTimeout()), 38 streamer: &streamer{conn: hubConn}, 39 streamClient: newStreamClient(protocol, p.chanReceiveTimeout(), p.streamBufferCapacity()), 40 info: pInfo, 41 dbg: pDbg, 42 } 43 } 44 45 // Run runs the loop. After the startup sequence is done, this is signaled over the started channel. 46 // Callers should pass a channel with buffer size 1 to allow the loop to run without waiting for the caller. 47 func (l *loop) Run(connected chan struct{}) (err error) { 48 l.party.onConnected(l.hubConn) 49 connected <- struct{}{} 50 close(connected) 51 // Process messages 52 ch := make(chan receiveResult, 1) 53 go func() { 54 recv := l.hubConn.Receive() 55 loop: 56 for { 57 select { 58 case result, ok := <-recv: 59 if !ok { 60 break loop 61 } 62 select { 63 case ch <- result: 64 case <-l.hubConn.Context().Done(): 65 break loop 66 } 67 case <-l.hubConn.Context().Done(): 68 break loop 69 } 70 } 71 }() 72 timeoutTicker := time.NewTicker(l.party.timeout()) 73 msgLoop: 74 for { 75 pingLoop: 76 for { 77 select { 78 case evt := <-ch: 79 err = evt.err 80 timeoutTicker.Reset(l.party.timeout()) 81 if err == nil { 82 switch message := evt.message.(type) { 83 case invocationMessage: 84 l.handleInvocationMessage(message) 85 case cancelInvocationMessage: 86 _ = l.dbg.Log(evt, msgRecv, msg, fmtMsg(message)) 87 l.streamer.Stop(message.InvocationID) 88 case streamItemMessage: 89 err = l.handleStreamItemMessage(message) 90 case completionMessage: 91 err = l.handleCompletionMessage(message) 92 case closeMessage: 93 _ = l.dbg.Log(evt, msgRecv, msg, fmtMsg(message)) 94 l.closeMessage = &message 95 if message.Error != "" { 96 err = errors.New(message.Error) 97 } 98 case hubMessage: 99 // Mostly ping 100 err = l.handleOtherMessage(message) 101 // No default case necessary, because the protocol would return either a hubMessage or an error 102 } 103 } else { 104 _ = l.info.Log(evt, msgRecv, "error", err, msg, fmtMsg(evt.message), react, "close connection") 105 } 106 break pingLoop 107 case <-time.After(l.party.keepAliveInterval()): 108 // Send ping only when there was no write in the keepAliveInterval before 109 if time.Since(l.hubConn.LastWriteStamp()) > l.party.keepAliveInterval() { 110 _ = l.hubConn.Ping() 111 } 112 // Don't break the pingLoop when keepAlive is over, it exists for this case 113 case <-timeoutTicker.C: 114 err = fmt.Errorf("timeout interval elapsed (%v)", l.party.timeout()) 115 break pingLoop 116 case <-l.hubConn.Context().Done(): 117 err = fmt.Errorf("breaking loop. hubConnection canceled: %w", l.hubConn.Context().Err()) 118 break pingLoop 119 case <-l.party.context().Done(): 120 err = fmt.Errorf("breaking loop. Party canceled: %w", l.party.context().Err()) 121 break pingLoop 122 } 123 } 124 if err != nil || l.closeMessage != nil { 125 break msgLoop 126 } 127 } 128 l.party.onDisconnected(l.hubConn) 129 if err != nil { 130 _ = l.hubConn.Close(fmt.Sprintf("%v", err), l.party.allowReconnect()) 131 } 132 _ = l.dbg.Log(evt, "message loop ended") 133 l.invokeClient.cancelAllInvokes() 134 l.hubConn.Abort() 135 return err 136 } 137 138 func (l *loop) PullStream(method, id string, arguments ...interface{}) <-chan InvokeResult { 139 _, errChan := l.invokeClient.newInvocation(id) 140 upChan := l.streamClient.newUpstreamChannel(id) 141 ch := newInvokeResultChan(l.party.context(), upChan, errChan) 142 if err := l.hubConn.SendStreamInvocation(id, method, arguments); err != nil { 143 // When we get an error here, the loop is closed and the errChan might be already closed 144 // We create a new one to deliver our error 145 ch, _ = createResultChansWithError(l.party.context(), err) 146 l.streamClient.deleteUpstreamChannel(id) 147 l.invokeClient.deleteInvocation(id) 148 } 149 return ch 150 } 151 152 func (l *loop) PushStreams(method, id string, arguments ...interface{}) (<-chan InvokeResult, error) { 153 resultCh, errCh := l.invokeClient.newInvocation(id) 154 irCh := newInvokeResultChan(l.party.context(), resultCh, errCh) 155 invokeArgs := make([]interface{}, 0) 156 reflectedChannels := make([]reflect.Value, 0) 157 streamIds := make([]string, 0) 158 // Parse arguments for channels and other kind of arguments 159 for _, arg := range arguments { 160 if reflect.TypeOf(arg).Kind() == reflect.Chan { 161 reflectedChannels = append(reflectedChannels, reflect.ValueOf(arg)) 162 streamIds = append(streamIds, l.GetNewID()) 163 } else { 164 invokeArgs = append(invokeArgs, arg) 165 } 166 } 167 // Tell the server we are streaming now 168 if err := l.hubConn.SendInvocationWithStreamIds(id, method, invokeArgs, streamIds); err != nil { 169 l.invokeClient.deleteInvocation(id) 170 return nil, err 171 } 172 // Start streaming on all channels 173 for i, reflectedChannel := range reflectedChannels { 174 l.streamer.Start(streamIds[i], reflectedChannel) 175 } 176 return irCh, nil 177 } 178 179 // GetNewID returns a new, connection-unique id for invocations and streams 180 func (l *loop) GetNewID() string { 181 atomic.AddUint64(&l.lastID, 1) 182 return fmt.Sprint(atomic.LoadUint64(&l.lastID)) 183 } 184 185 func (l *loop) handleInvocationMessage(invocation invocationMessage) { 186 _ = l.dbg.Log(evt, msgRecv, msg, fmtMsg(invocation)) 187 // Transient hub, dispatch invocation here 188 if method, ok := getMethod(l.party.invocationTarget(l.hubConn), invocation.Target); !ok { 189 // Unable to find the method 190 _ = l.info.Log(evt, "getMethod", "error", "missing method", "name", invocation.Target, react, "send completion with error") 191 _ = l.hubConn.Completion(invocation.InvocationID, nil, fmt.Sprintf("Unknown method %s", invocation.Target)) 192 } else if in, err := buildMethodArguments(method, invocation, l.streamClient, l.protocol); err != nil { 193 // argument build failed 194 _ = l.info.Log(evt, "buildMethodArguments", "error", err, "name", invocation.Target, react, "send completion with error") 195 _ = l.hubConn.Completion(invocation.InvocationID, nil, err.Error()) 196 } else { 197 // Stream invocation is only allowed when the method has only one return value 198 // We allow no channel return values, because a client can receive as stream with only one item 199 if invocation.Type == 4 && method.Type().NumOut() != 1 { 200 _ = l.hubConn.Completion(invocation.InvocationID, nil, 201 fmt.Sprintf("Stream invocation of method %s which has not return value kind channel", invocation.Target)) 202 } else { 203 // hub method might take a long time 204 go func() { 205 result := func() []reflect.Value { 206 defer l.recoverInvocationPanic(invocation) 207 return method.Call(in) 208 }() 209 l.returnInvocationResult(invocation, result) 210 }() 211 } 212 } 213 } 214 215 func (l *loop) returnInvocationResult(invocation invocationMessage, result []reflect.Value) { 216 // No invocation id, no completion 217 if invocation.InvocationID != "" { 218 // if the hub method returns a chan, it should be considered asynchronous or source for a stream 219 if len(result) == 1 && result[0].Kind() == reflect.Chan { 220 switch invocation.Type { 221 // Simple invocation 222 case 1: 223 go func() { 224 // Recv might block, so run continue in a goroutine 225 if chanResult, ok := result[0].Recv(); ok { 226 l.sendResult(invocation, completion, []reflect.Value{chanResult}) 227 } else { 228 229 _ = l.hubConn.Completion(invocation.InvocationID, nil, "hub func returned closed chan") 230 } 231 }() 232 // StreamInvocation 233 case 4: 234 l.streamer.Start(invocation.InvocationID, result[0]) 235 } 236 } else { 237 switch invocation.Type { 238 // Simple invocation 239 case 1: 240 l.sendResult(invocation, completion, result) 241 case 4: 242 // Stream invocation of method with no stream result. 243 // Return a single StreamItem and an empty Completion 244 l.sendResult(invocation, streamItem, result) 245 _ = l.hubConn.Completion(invocation.InvocationID, nil, "") 246 } 247 } 248 } 249 } 250 251 func (l *loop) handleStreamItemMessage(streamItemMessage streamItemMessage) error { 252 _ = l.dbg.Log(evt, msgRecv, msg, fmtMsg(streamItemMessage)) 253 if err := l.streamClient.receiveStreamItem(streamItemMessage); err != nil { 254 switch t := err.(type) { 255 case *hubChanTimeoutError: 256 _ = l.hubConn.Completion(streamItemMessage.InvocationID, nil, t.Error()) 257 default: 258 _ = l.info.Log(evt, msgRecv, "error", err, msg, fmtMsg(streamItemMessage), react, "close connection") 259 return err 260 } 261 } 262 return nil 263 } 264 265 func (l *loop) handleCompletionMessage(message completionMessage) error { 266 _ = l.dbg.Log(evt, msgRecv, msg, fmtMsg(message)) 267 var err error 268 if l.streamClient.handlesInvocationID(message.InvocationID) { 269 err = l.streamClient.receiveCompletionItem(message, l.invokeClient) 270 } else if l.invokeClient.handlesInvocationID(message.InvocationID) { 271 err = l.invokeClient.receiveCompletionItem(message) 272 } else { 273 err = fmt.Errorf("unknown invocationID %v", message.InvocationID) 274 } 275 if err != nil { 276 _ = l.info.Log(evt, msgRecv, "error", err, msg, fmtMsg(message), react, "close connection") 277 } 278 return err 279 } 280 281 func (l *loop) handleOtherMessage(hubMessage hubMessage) error { 282 _ = l.dbg.Log(evt, msgRecv, msg, fmtMsg(hubMessage)) 283 // Not Ping 284 if hubMessage.Type != 6 { 285 err := fmt.Errorf("invalid message type %v", hubMessage) 286 _ = l.info.Log(evt, msgRecv, "error", err, msg, fmtMsg(hubMessage), react, "close connection") 287 return err 288 } 289 return nil 290 } 291 292 func (l *loop) sendResult(invocation invocationMessage, connFunc connFunc, result []reflect.Value) { 293 values := make([]interface{}, len(result)) 294 for i, rv := range result { 295 values[i] = rv.Interface() 296 } 297 switch len(result) { 298 case 0: 299 _ = l.hubConn.Completion(invocation.InvocationID, nil, "") 300 case 1: 301 connFunc(l, invocation, values[0]) 302 default: 303 connFunc(l, invocation, values) 304 } 305 } 306 307 type connFunc func(sl *loop, invocation invocationMessage, value interface{}) 308 309 func completion(sl *loop, invocation invocationMessage, value interface{}) { 310 _ = sl.hubConn.Completion(invocation.InvocationID, value, "") 311 } 312 313 func streamItem(sl *loop, invocation invocationMessage, value interface{}) { 314 315 _ = sl.hubConn.StreamItem(invocation.InvocationID, value) 316 } 317 318 func (l *loop) recoverInvocationPanic(invocation invocationMessage) { 319 if err := recover(); err != nil { 320 _ = l.info.Log(evt, "panic in target method", "error", err, "name", invocation.Target, react, "send completion with error") 321 stack := string(debug.Stack()) 322 _ = l.dbg.Log(evt, "panic in target method", "error", err, "name", invocation.Target, react, "send completion with error", "stack", stack) 323 if invocation.InvocationID != "" { 324 if !l.party.enableDetailedErrors() { 325 stack = "" 326 } 327 _ = l.hubConn.Completion(invocation.InvocationID, nil, fmt.Sprintf("%v\n%v", err, stack)) 328 } 329 } 330 } 331 332 func buildMethodArguments(method reflect.Value, invocation invocationMessage, 333 streamClient *streamClient, protocol hubProtocol) (arguments []reflect.Value, err error) { 334 if len(invocation.StreamIds)+len(invocation.Arguments) != method.Type().NumIn() { 335 return nil, fmt.Errorf("parameter mismatch calling method %v", invocation.Target) 336 } 337 arguments = make([]reflect.Value, method.Type().NumIn()) 338 chanCount := 0 339 for i := 0; i < method.Type().NumIn(); i++ { 340 t := method.Type().In(i) 341 // Is it a channel for client streaming? 342 if arg, clientStreaming, err := streamClient.buildChannelArgument(invocation, t, chanCount); err != nil { 343 // it is, but channel count in invocation and method mismatch 344 return nil, err 345 } else if clientStreaming { 346 // it is 347 chanCount++ 348 arguments[i] = arg 349 } else { 350 // it is not, so do the normal thing 351 arg := reflect.New(t) 352 if err := protocol.UnmarshalArgument(invocation.Arguments[i-chanCount], arg.Interface()); err != nil { 353 return arguments, err 354 } 355 arguments[i] = arg.Elem() 356 } 357 } 358 if len(invocation.StreamIds) != chanCount { 359 return arguments, fmt.Errorf("to many StreamIds for channel parameters of method %v", invocation.Target) 360 } 361 return arguments, nil 362 } 363 364 func getMethod(target interface{}, name string) (reflect.Value, bool) { 365 hubType := reflect.TypeOf(target) 366 if hubType != nil { 367 hubValue := reflect.ValueOf(target) 368 name = strings.ToLower(name) 369 for i := 0; i < hubType.NumMethod(); i++ { 370 // Search in public methods 371 if m := hubType.Method(i); strings.ToLower(m.Name) == name { 372 return hubValue.Method(i), true 373 } 374 } 375 } 376 return reflect.Value{}, false 377 } 378 379 func fmtMsg(message interface{}) string { 380 switch msg := message.(type) { 381 case invocationMessage: 382 fmtArgs := make([]interface{}, 0) 383 for _, arg := range msg.Arguments { 384 if rawArg, ok := arg.(json.RawMessage); ok { 385 fmtArgs = append(fmtArgs, string(rawArg)) 386 } else { 387 fmtArgs = append(fmtArgs, arg) 388 } 389 } 390 msg.Arguments = fmtArgs 391 message = msg 392 } 393 return fmt.Sprintf("%#v", message) 394 }