github.com/filecoin-project/bacalhau@v0.3.23-0.20230228154132-45c989550ace/pkg/transport/bprotocol/compute_handler.go (about)

     1  package bprotocol
     2  
     3  import (
     4  	"context"
     5  	"encoding/json"
     6  	"reflect"
     7  
     8  	"github.com/filecoin-project/bacalhau/pkg/compute"
     9  	"github.com/filecoin-project/bacalhau/pkg/logger"
    10  	"github.com/libp2p/go-libp2p/core/host"
    11  	"github.com/libp2p/go-libp2p/core/network"
    12  	"github.com/rs/zerolog/log"
    13  )
    14  
    15  type ComputeHandlerParams struct {
    16  	Host            host.Host
    17  	ComputeEndpoint compute.Endpoint
    18  }
    19  
    20  // ComputeHandler is a handler for compute requests that registers for incoming libp2p requests to Bacalhau compute
    21  // protocol, and delegates the requests to the compute endpoint.
    22  type ComputeHandler struct {
    23  	host            host.Host
    24  	computeEndpoint compute.Endpoint
    25  }
    26  
    27  func NewComputeHandler(params ComputeHandlerParams) *ComputeHandler {
    28  	handler := &ComputeHandler{
    29  		host:            params.Host,
    30  		computeEndpoint: params.ComputeEndpoint,
    31  	}
    32  
    33  	handler.host.SetStreamHandler(AskForBidProtocolID, handler.onAskForBid)
    34  	handler.host.SetStreamHandler(BidAcceptedProtocolID, handler.onBidAccepted)
    35  	handler.host.SetStreamHandler(BidRejectedProtocolID, handler.onBidRejected)
    36  	handler.host.SetStreamHandler(ResultAcceptedProtocolID, handler.onResultAccepted)
    37  	handler.host.SetStreamHandler(ResultRejectedProtocolID, handler.onResultRejected)
    38  	handler.host.SetStreamHandler(CancelProtocolID, handler.onCancelJob)
    39  	log.Debug().Msgf("ComputeHandler started on host %s", handler.host.ID().String())
    40  	return handler
    41  }
    42  
    43  func (h *ComputeHandler) onAskForBid(stream network.Stream) {
    44  	ctx := logger.ContextWithNodeIDLogger(context.Background(), h.host.ID().String())
    45  	handleStream[compute.AskForBidRequest, compute.AskForBidResponse](ctx, stream, h.computeEndpoint.AskForBid)
    46  }
    47  
    48  func (h *ComputeHandler) onBidAccepted(stream network.Stream) {
    49  	ctx := logger.ContextWithNodeIDLogger(context.Background(), h.host.ID().String())
    50  	handleStream[compute.BidAcceptedRequest, compute.BidAcceptedResponse](ctx, stream, h.computeEndpoint.BidAccepted)
    51  }
    52  
    53  func (h *ComputeHandler) onBidRejected(stream network.Stream) {
    54  	ctx := logger.ContextWithNodeIDLogger(context.Background(), h.host.ID().String())
    55  	handleStream[compute.BidRejectedRequest, compute.BidRejectedResponse](ctx, stream, h.computeEndpoint.BidRejected)
    56  }
    57  
    58  func (h *ComputeHandler) onResultAccepted(stream network.Stream) {
    59  	ctx := logger.ContextWithNodeIDLogger(context.Background(), h.host.ID().String())
    60  	handleStream[compute.ResultAcceptedRequest, compute.ResultAcceptedResponse](ctx, stream, h.computeEndpoint.ResultAccepted)
    61  }
    62  
    63  func (h *ComputeHandler) onResultRejected(stream network.Stream) {
    64  	ctx := logger.ContextWithNodeIDLogger(context.Background(), h.host.ID().String())
    65  	handleStream[compute.ResultRejectedRequest, compute.ResultRejectedResponse](ctx, stream, h.computeEndpoint.ResultRejected)
    66  }
    67  
    68  func (h *ComputeHandler) onCancelJob(stream network.Stream) {
    69  	ctx := logger.ContextWithNodeIDLogger(context.Background(), h.host.ID().String())
    70  	handleStream[compute.CancelExecutionRequest, compute.CancelExecutionResponse](ctx, stream, h.computeEndpoint.CancelExecution)
    71  }
    72  
    73  //nolint:errcheck
    74  func handleStream[Request any, Response any](
    75  	ctx context.Context,
    76  	stream network.Stream,
    77  	f func(ctx context.Context, r Request) (Response, error)) {
    78  	if err := stream.Scope().SetService(ComputeServiceName); err != nil {
    79  		log.Ctx(ctx).Error().Err(err).Msg("error attaching stream to compute service")
    80  		stream.Reset()
    81  		return
    82  	}
    83  
    84  	request := new(Request)
    85  	err := json.NewDecoder(stream).Decode(request)
    86  	if err != nil {
    87  		log.Ctx(ctx).Error().Msgf("error decoding %s: %s", reflect.TypeOf(request), err)
    88  		stream.Reset()
    89  		return
    90  	}
    91  	defer stream.Close() //nolint:errcheck
    92  
    93  	response, err := f(ctx, *request)
    94  	if err != nil {
    95  		log.Ctx(ctx).Error().Msgf("error delegating %s: %s", reflect.TypeOf(request), err)
    96  		stream.Reset()
    97  		return
    98  	}
    99  
   100  	err = json.NewEncoder(stream).Encode(response)
   101  	if err != nil {
   102  		log.Ctx(ctx).Error().Msgf("error encoding %s: %s", reflect.TypeOf(response), err)
   103  		stream.Reset()
   104  		return
   105  	}
   106  }