github.com/filecoin-project/bacalhau@v0.3.23-0.20230228154132-45c989550ace/pkg/transport/bprotocol/callback_handler.go (about)

     1  package bprotocol
     2  
     3  import (
     4  	"context"
     5  	"encoding/json"
     6  	"reflect"
     7  
     8  	"github.com/filecoin-project/bacalhau/pkg/compute"
     9  	"github.com/filecoin-project/bacalhau/pkg/logger"
    10  	"github.com/libp2p/go-libp2p/core/host"
    11  	"github.com/libp2p/go-libp2p/core/network"
    12  	"github.com/rs/zerolog/log"
    13  )
    14  
    15  type CallbackHandlerParams struct {
    16  	Host     host.Host
    17  	Callback compute.Callback
    18  }
    19  
    20  // CallbackHandler is a handler for callback events that registers for incoming libp2p requests to Bacalhau callback
    21  // protocol, and delegates the handling of the request to the provided callback.
    22  type CallbackHandler struct {
    23  	host     host.Host
    24  	callback compute.Callback
    25  }
    26  
    27  func NewCallbackHandler(params CallbackHandlerParams) *CallbackHandler {
    28  	handler := &CallbackHandler{
    29  		host:     params.Host,
    30  		callback: params.Callback,
    31  	}
    32  
    33  	handler.host.SetStreamHandler(OnRunComplete, handler.onRunSuccess)
    34  	handler.host.SetStreamHandler(OnPublishComplete, handler.onPublishSuccess)
    35  	handler.host.SetStreamHandler(OnCancelComplete, handler.onCancelSuccess)
    36  	handler.host.SetStreamHandler(OnComputeFailure, handler.onComputeFailure)
    37  	return handler
    38  }
    39  
    40  func (h *CallbackHandler) onRunSuccess(stream network.Stream) {
    41  	ctx := logger.ContextWithNodeIDLogger(context.Background(), h.host.ID().String())
    42  	handleCallbackStream[compute.RunResult](ctx, stream, h.callback.OnRunComplete)
    43  }
    44  func (h *CallbackHandler) onPublishSuccess(stream network.Stream) {
    45  	ctx := logger.ContextWithNodeIDLogger(context.Background(), h.host.ID().String())
    46  	handleCallbackStream[compute.PublishResult](ctx, stream, h.callback.OnPublishComplete)
    47  }
    48  
    49  func (h *CallbackHandler) onCancelSuccess(stream network.Stream) {
    50  	ctx := logger.ContextWithNodeIDLogger(context.Background(), h.host.ID().String())
    51  	handleCallbackStream[compute.CancelResult](ctx, stream, h.callback.OnCancelComplete)
    52  }
    53  
    54  func (h *CallbackHandler) onComputeFailure(stream network.Stream) {
    55  	ctx := logger.ContextWithNodeIDLogger(context.Background(), h.host.ID().String())
    56  	handleCallbackStream[compute.ComputeError](ctx, stream, h.callback.OnComputeFailure)
    57  }
    58  
    59  //nolint:errcheck
    60  func handleCallbackStream[Request any](
    61  	ctx context.Context,
    62  	stream network.Stream,
    63  	f func(ctx context.Context, r Request)) {
    64  	ctx = logger.ContextWithNodeIDLogger(ctx, stream.Conn().LocalPeer().String())
    65  	if err := stream.Scope().SetService(CallbackServiceName); err != nil {
    66  		log.Ctx(ctx).Error().Err(err).Msg("error attaching stream to requester service")
    67  		stream.Reset()
    68  		return
    69  	}
    70  
    71  	request := new(Request)
    72  	err := json.NewDecoder(stream).Decode(request)
    73  	if err != nil {
    74  		log.Ctx(ctx).Error().Msgf("error decoding %s: %s", reflect.TypeOf(request), err)
    75  		stream.Reset()
    76  		return
    77  	}
    78  	defer stream.Close() //nolint:errcheck
    79  
    80  	// TODO: validate which context to user here, and whether running in a goroutine is ok
    81  	newCtx := logger.ContextWithNodeIDLogger(context.Background(), stream.Conn().LocalPeer().String())
    82  	go f(newCtx, *request)
    83  }