github.com/filecoin-project/bacalhau@v0.3.23-0.20230228154132-45c989550ace/pkg/transport/bprotocol/compute_handler.go (about) 1 package bprotocol 2 3 import ( 4 "context" 5 "encoding/json" 6 "reflect" 7 8 "github.com/filecoin-project/bacalhau/pkg/compute" 9 "github.com/filecoin-project/bacalhau/pkg/logger" 10 "github.com/libp2p/go-libp2p/core/host" 11 "github.com/libp2p/go-libp2p/core/network" 12 "github.com/rs/zerolog/log" 13 ) 14 15 type ComputeHandlerParams struct { 16 Host host.Host 17 ComputeEndpoint compute.Endpoint 18 } 19 20 // ComputeHandler is a handler for compute requests that registers for incoming libp2p requests to Bacalhau compute 21 // protocol, and delegates the requests to the compute endpoint. 22 type ComputeHandler struct { 23 host host.Host 24 computeEndpoint compute.Endpoint 25 } 26 27 func NewComputeHandler(params ComputeHandlerParams) *ComputeHandler { 28 handler := &ComputeHandler{ 29 host: params.Host, 30 computeEndpoint: params.ComputeEndpoint, 31 } 32 33 handler.host.SetStreamHandler(AskForBidProtocolID, handler.onAskForBid) 34 handler.host.SetStreamHandler(BidAcceptedProtocolID, handler.onBidAccepted) 35 handler.host.SetStreamHandler(BidRejectedProtocolID, handler.onBidRejected) 36 handler.host.SetStreamHandler(ResultAcceptedProtocolID, handler.onResultAccepted) 37 handler.host.SetStreamHandler(ResultRejectedProtocolID, handler.onResultRejected) 38 handler.host.SetStreamHandler(CancelProtocolID, handler.onCancelJob) 39 log.Debug().Msgf("ComputeHandler started on host %s", handler.host.ID().String()) 40 return handler 41 } 42 43 func (h *ComputeHandler) onAskForBid(stream network.Stream) { 44 ctx := logger.ContextWithNodeIDLogger(context.Background(), h.host.ID().String()) 45 handleStream[compute.AskForBidRequest, compute.AskForBidResponse](ctx, stream, h.computeEndpoint.AskForBid) 46 } 47 48 func (h *ComputeHandler) onBidAccepted(stream network.Stream) { 49 ctx := logger.ContextWithNodeIDLogger(context.Background(), h.host.ID().String()) 50 handleStream[compute.BidAcceptedRequest, compute.BidAcceptedResponse](ctx, stream, h.computeEndpoint.BidAccepted) 51 } 52 53 func (h *ComputeHandler) onBidRejected(stream network.Stream) { 54 ctx := logger.ContextWithNodeIDLogger(context.Background(), h.host.ID().String()) 55 handleStream[compute.BidRejectedRequest, compute.BidRejectedResponse](ctx, stream, h.computeEndpoint.BidRejected) 56 } 57 58 func (h *ComputeHandler) onResultAccepted(stream network.Stream) { 59 ctx := logger.ContextWithNodeIDLogger(context.Background(), h.host.ID().String()) 60 handleStream[compute.ResultAcceptedRequest, compute.ResultAcceptedResponse](ctx, stream, h.computeEndpoint.ResultAccepted) 61 } 62 63 func (h *ComputeHandler) onResultRejected(stream network.Stream) { 64 ctx := logger.ContextWithNodeIDLogger(context.Background(), h.host.ID().String()) 65 handleStream[compute.ResultRejectedRequest, compute.ResultRejectedResponse](ctx, stream, h.computeEndpoint.ResultRejected) 66 } 67 68 func (h *ComputeHandler) onCancelJob(stream network.Stream) { 69 ctx := logger.ContextWithNodeIDLogger(context.Background(), h.host.ID().String()) 70 handleStream[compute.CancelExecutionRequest, compute.CancelExecutionResponse](ctx, stream, h.computeEndpoint.CancelExecution) 71 } 72 73 //nolint:errcheck 74 func handleStream[Request any, Response any]( 75 ctx context.Context, 76 stream network.Stream, 77 f func(ctx context.Context, r Request) (Response, error)) { 78 if err := stream.Scope().SetService(ComputeServiceName); err != nil { 79 log.Ctx(ctx).Error().Err(err).Msg("error attaching stream to compute service") 80 stream.Reset() 81 return 82 } 83 84 request := new(Request) 85 err := json.NewDecoder(stream).Decode(request) 86 if err != nil { 87 log.Ctx(ctx).Error().Msgf("error decoding %s: %s", reflect.TypeOf(request), err) 88 stream.Reset() 89 return 90 } 91 defer stream.Close() //nolint:errcheck 92 93 response, err := f(ctx, *request) 94 if err != nil { 95 log.Ctx(ctx).Error().Msgf("error delegating %s: %s", reflect.TypeOf(request), err) 96 stream.Reset() 97 return 98 } 99 100 err = json.NewEncoder(stream).Encode(response) 101 if err != nil { 102 log.Ctx(ctx).Error().Msgf("error encoding %s: %s", reflect.TypeOf(response), err) 103 stream.Reset() 104 return 105 } 106 }