github.com/filecoin-project/bacalhau@v0.3.23-0.20230228154132-45c989550ace/pkg/transport/simulator/callback_proxy.go (about)

     1  package simulator
     2  
     3  import (
     4  	"context"
     5  	"encoding/json"
     6  	"reflect"
     7  
     8  	"github.com/filecoin-project/bacalhau/pkg/compute"
     9  	"github.com/filecoin-project/bacalhau/pkg/logger"
    10  	"github.com/filecoin-project/bacalhau/pkg/transport/bprotocol"
    11  	"github.com/libp2p/go-libp2p/core/host"
    12  	"github.com/libp2p/go-libp2p/core/peer"
    13  	"github.com/libp2p/go-libp2p/core/protocol"
    14  	"github.com/rs/zerolog/log"
    15  )
    16  
// CallbackProxyParams holds the inputs to NewCallbackProxy.
//
// SimulatorNodeID is the peer ID string of the simulator node that every callback
// is routed to. Host is the local libp2p host: its ID is compared against
// SimulatorNodeID to detect when this node is itself the simulator, and it is used
// to open streams to the simulator otherwise. LocalCallback, if non-nil, receives
// callbacks directly when this node is the simulator node (it can also be set later
// with CallbackProxy.RegisterLocalComputeCallback).
type CallbackProxyParams struct {
	SimulatorNodeID string
	Host            host.Host
	LocalCallback   compute.Callback
}
    22  
    23  // CallbackProxy is a proxy for a compute.Callback that can be used to send compute callbacks to the requester node,
    24  // such as when the execution is completed or when a failure occurs.
    25  // The proxy can forward callbacks to a remote requester node, or locally if the node is the requester and a
    26  // LocalCallback is provided.
    27  type CallbackProxy struct {
    28  	simulatorNodeID string
    29  	host            host.Host
    30  	localCallback   compute.Callback
    31  }
    32  
    33  func NewCallbackProxy(params CallbackProxyParams) *CallbackProxy {
    34  	proxy := &CallbackProxy{
    35  		simulatorNodeID: params.SimulatorNodeID,
    36  		host:            params.Host,
    37  		localCallback:   params.LocalCallback,
    38  	}
    39  	log.Info().Msgf("CallbackProxy created with simulator node %s", params.SimulatorNodeID)
    40  	return proxy
    41  }
    42  
    43  func (p *CallbackProxy) RegisterLocalComputeCallback(callback compute.Callback) {
    44  	p.localCallback = callback
    45  }
    46  
    47  func (p *CallbackProxy) OnRunComplete(ctx context.Context, result compute.RunResult) {
    48  	proxyCallbackRequest(ctx, p, result.RoutingMetadata, bprotocol.OnRunComplete, result, func(ctx2 context.Context) {
    49  		p.localCallback.OnRunComplete(ctx2, result)
    50  	})
    51  }
    52  
    53  func (p *CallbackProxy) OnPublishComplete(ctx context.Context, result compute.PublishResult) {
    54  	proxyCallbackRequest(ctx, p, result.RoutingMetadata, bprotocol.OnPublishComplete, result, func(ctx2 context.Context) {
    55  		p.localCallback.OnPublishComplete(ctx2, result)
    56  	})
    57  }
    58  
    59  func (p *CallbackProxy) OnCancelComplete(ctx context.Context, result compute.CancelResult) {
    60  	proxyCallbackRequest(ctx, p, result.RoutingMetadata, bprotocol.OnCancelComplete, result, func(ctx2 context.Context) {
    61  		p.localCallback.OnCancelComplete(ctx2, result)
    62  	})
    63  }
    64  
    65  func (p *CallbackProxy) OnComputeFailure(ctx context.Context, result compute.ComputeError) {
    66  	proxyCallbackRequest(ctx, p, result.RoutingMetadata, bprotocol.OnComputeFailure, result, func(ctx2 context.Context) {
    67  		p.localCallback.OnComputeFailure(ctx2, result)
    68  	})
    69  }
    70  
    71  func proxyCallbackRequest(
    72  	ctx context.Context,
    73  	p *CallbackProxy,
    74  	resultInfo compute.RoutingMetadata,
    75  	protocolID protocol.ID,
    76  	request interface{},
    77  	selfDialFunc func(ctx2 context.Context)) {
    78  	if p.simulatorNodeID == p.host.ID().String() {
    79  		if p.localCallback == nil {
    80  			log.Ctx(ctx).Error().Msgf("unable to dial to self, unless a local compute callback is provided")
    81  		} else {
    82  			// TODO: validate which context to user here, and whether running in a goroutine is ok
    83  			ctx2 := logger.ContextWithNodeIDLogger(context.Background(), p.host.ID().String())
    84  			go selfDialFunc(ctx2)
    85  		}
    86  	} else {
    87  		// decode the destination peer ID string value
    88  		targetPeerID := p.simulatorNodeID
    89  		log.Ctx(ctx).Info().Msgf("Forwarding callback %+v to %s", request, targetPeerID)
    90  		peerID, err := peer.Decode(targetPeerID)
    91  		if err != nil {
    92  			log.Ctx(ctx).Error().Err(err).Msgf("%s: failed to decode peer ID %s", reflect.TypeOf(request), targetPeerID)
    93  			return
    94  		}
    95  
    96  		// deserialize the request object
    97  		data, err := json.Marshal(request)
    98  		if err != nil {
    99  			log.Ctx(ctx).Error().Err(err).Msgf("%s: failed to marshal request", reflect.TypeOf(request))
   100  			return
   101  		}
   102  
   103  		// opening a stream to the destination peer
   104  		stream, err := p.host.NewStream(ctx, peerID, protocolID)
   105  		if err != nil {
   106  			log.Ctx(ctx).Error().Err(err).Msgf("%s: failed to open stream to peer %s", reflect.TypeOf(request), targetPeerID)
   107  			return
   108  		}
   109  		defer stream.Close() //nolint:errcheck
   110  
   111  		// write the request to the stream
   112  		_, err = stream.Write(data)
   113  		if err != nil {
   114  			stream.Reset() //nolint:errcheck
   115  			log.Ctx(ctx).Error().Err(err).Msgf("%s: failed to write request to peer %s", reflect.TypeOf(request), targetPeerID)
   116  			return
   117  		}
   118  	}
   119  }
   120  
// Compile-time assertion that *CallbackProxy implements compute.Callback: the build
// fails if any callback method is missing or has a mismatched signature.
var _ compute.Callback = (*CallbackProxy)(nil)