github.com/filecoin-project/bacalhau@v0.3.23-0.20230228154132-45c989550ace/pkg/transport/bprotocol/callback_handler.go (about) 1 package bprotocol 2 3 import ( 4 "context" 5 "encoding/json" 6 "reflect" 7 8 "github.com/filecoin-project/bacalhau/pkg/compute" 9 "github.com/filecoin-project/bacalhau/pkg/logger" 10 "github.com/libp2p/go-libp2p/core/host" 11 "github.com/libp2p/go-libp2p/core/network" 12 "github.com/rs/zerolog/log" 13 ) 14 15 type CallbackHandlerParams struct { 16 Host host.Host 17 Callback compute.Callback 18 } 19 20 // CallbackHandler is a handler for callback events that registers for incoming libp2p requests to Bacalhau callback 21 // protocol, and delegates the handling of the request to the provided callback. 22 type CallbackHandler struct { 23 host host.Host 24 callback compute.Callback 25 } 26 27 func NewCallbackHandler(params CallbackHandlerParams) *CallbackHandler { 28 handler := &CallbackHandler{ 29 host: params.Host, 30 callback: params.Callback, 31 } 32 33 handler.host.SetStreamHandler(OnRunComplete, handler.onRunSuccess) 34 handler.host.SetStreamHandler(OnPublishComplete, handler.onPublishSuccess) 35 handler.host.SetStreamHandler(OnCancelComplete, handler.onCancelSuccess) 36 handler.host.SetStreamHandler(OnComputeFailure, handler.onComputeFailure) 37 return handler 38 } 39 40 func (h *CallbackHandler) onRunSuccess(stream network.Stream) { 41 ctx := logger.ContextWithNodeIDLogger(context.Background(), h.host.ID().String()) 42 handleCallbackStream[compute.RunResult](ctx, stream, h.callback.OnRunComplete) 43 } 44 func (h *CallbackHandler) onPublishSuccess(stream network.Stream) { 45 ctx := logger.ContextWithNodeIDLogger(context.Background(), h.host.ID().String()) 46 handleCallbackStream[compute.PublishResult](ctx, stream, h.callback.OnPublishComplete) 47 } 48 49 func (h *CallbackHandler) onCancelSuccess(stream network.Stream) { 50 ctx := logger.ContextWithNodeIDLogger(context.Background(), h.host.ID().String()) 51 handleCallbackStream[compute.CancelResult](ctx, stream, h.callback.OnCancelComplete) 52 } 53 54 func (h *CallbackHandler) onComputeFailure(stream network.Stream) { 55 ctx := logger.ContextWithNodeIDLogger(context.Background(), h.host.ID().String()) 56 handleCallbackStream[compute.ComputeError](ctx, stream, h.callback.OnComputeFailure) 57 } 58 59 //nolint:errcheck 60 func handleCallbackStream[Request any]( 61 ctx context.Context, 62 stream network.Stream, 63 f func(ctx context.Context, r Request)) { 64 ctx = logger.ContextWithNodeIDLogger(ctx, stream.Conn().LocalPeer().String()) 65 if err := stream.Scope().SetService(CallbackServiceName); err != nil { 66 log.Ctx(ctx).Error().Err(err).Msg("error attaching stream to requester service") 67 stream.Reset() 68 return 69 } 70 71 request := new(Request) 72 err := json.NewDecoder(stream).Decode(request) 73 if err != nil { 74 log.Ctx(ctx).Error().Msgf("error decoding %s: %s", reflect.TypeOf(request), err) 75 stream.Reset() 76 return 77 } 78 defer stream.Close() //nolint:errcheck 79 80 // TODO: validate which context to user here, and whether running in a goroutine is ok 81 newCtx := logger.ContextWithNodeIDLogger(context.Background(), stream.Conn().LocalPeer().String()) 82 go f(newCtx, *request) 83 }