github.com/filecoin-project/bacalhau@v0.3.23-0.20230228154132-45c989550ace/pkg/transport/simulator/callback_proxy.go (about) 1 package simulator 2 3 import ( 4 "context" 5 "encoding/json" 6 "reflect" 7 8 "github.com/filecoin-project/bacalhau/pkg/compute" 9 "github.com/filecoin-project/bacalhau/pkg/logger" 10 "github.com/filecoin-project/bacalhau/pkg/transport/bprotocol" 11 "github.com/libp2p/go-libp2p/core/host" 12 "github.com/libp2p/go-libp2p/core/peer" 13 "github.com/libp2p/go-libp2p/core/protocol" 14 "github.com/rs/zerolog/log" 15 ) 16 17 type CallbackProxyParams struct { 18 SimulatorNodeID string 19 Host host.Host 20 LocalCallback compute.Callback 21 } 22 23 // CallbackProxy is a proxy for a compute.Callback that can be used to send compute callbacks to the requester node, 24 // such as when the execution is completed or when a failure occurs. 25 // The proxy can forward callbacks to a remote requester node, or locally if the node is the requester and a 26 // LocalCallback is provided. 27 type CallbackProxy struct { 28 simulatorNodeID string 29 host host.Host 30 localCallback compute.Callback 31 } 32 33 func NewCallbackProxy(params CallbackProxyParams) *CallbackProxy { 34 proxy := &CallbackProxy{ 35 simulatorNodeID: params.SimulatorNodeID, 36 host: params.Host, 37 localCallback: params.LocalCallback, 38 } 39 log.Info().Msgf("CallbackProxy created with simulator node %s", params.SimulatorNodeID) 40 return proxy 41 } 42 43 func (p *CallbackProxy) RegisterLocalComputeCallback(callback compute.Callback) { 44 p.localCallback = callback 45 } 46 47 func (p *CallbackProxy) OnRunComplete(ctx context.Context, result compute.RunResult) { 48 proxyCallbackRequest(ctx, p, result.RoutingMetadata, bprotocol.OnRunComplete, result, func(ctx2 context.Context) { 49 p.localCallback.OnRunComplete(ctx2, result) 50 }) 51 } 52 53 func (p *CallbackProxy) OnPublishComplete(ctx context.Context, result compute.PublishResult) { 54 proxyCallbackRequest(ctx, p, result.RoutingMetadata, bprotocol.OnPublishComplete, result, func(ctx2 context.Context) { 55 p.localCallback.OnPublishComplete(ctx2, result) 56 }) 57 } 58 59 func (p *CallbackProxy) OnCancelComplete(ctx context.Context, result compute.CancelResult) { 60 proxyCallbackRequest(ctx, p, result.RoutingMetadata, bprotocol.OnCancelComplete, result, func(ctx2 context.Context) { 61 p.localCallback.OnCancelComplete(ctx2, result) 62 }) 63 } 64 65 func (p *CallbackProxy) OnComputeFailure(ctx context.Context, result compute.ComputeError) { 66 proxyCallbackRequest(ctx, p, result.RoutingMetadata, bprotocol.OnComputeFailure, result, func(ctx2 context.Context) { 67 p.localCallback.OnComputeFailure(ctx2, result) 68 }) 69 } 70 71 func proxyCallbackRequest( 72 ctx context.Context, 73 p *CallbackProxy, 74 resultInfo compute.RoutingMetadata, 75 protocolID protocol.ID, 76 request interface{}, 77 selfDialFunc func(ctx2 context.Context)) { 78 if p.simulatorNodeID == p.host.ID().String() { 79 if p.localCallback == nil { 80 log.Ctx(ctx).Error().Msgf("unable to dial to self, unless a local compute callback is provided") 81 } else { 82 // TODO: validate which context to user here, and whether running in a goroutine is ok 83 ctx2 := logger.ContextWithNodeIDLogger(context.Background(), p.host.ID().String()) 84 go selfDialFunc(ctx2) 85 } 86 } else { 87 // decode the destination peer ID string value 88 targetPeerID := p.simulatorNodeID 89 log.Ctx(ctx).Info().Msgf("Forwarding callback %+v to %s", request, targetPeerID) 90 peerID, err := peer.Decode(targetPeerID) 91 if err != nil { 92 log.Ctx(ctx).Error().Err(err).Msgf("%s: failed to decode peer ID %s", reflect.TypeOf(request), targetPeerID) 93 return 94 } 95 96 // deserialize the request object 97 data, err := json.Marshal(request) 98 if err != nil { 99 log.Ctx(ctx).Error().Err(err).Msgf("%s: failed to marshal request", reflect.TypeOf(request)) 100 return 101 } 102 103 // opening a stream to the destination peer 104 stream, err := p.host.NewStream(ctx, peerID, protocolID) 105 if err != nil { 106 log.Ctx(ctx).Error().Err(err).Msgf("%s: failed to open stream to peer %s", reflect.TypeOf(request), targetPeerID) 107 return 108 } 109 defer stream.Close() //nolint:errcheck 110 111 // write the request to the stream 112 _, err = stream.Write(data) 113 if err != nil { 114 stream.Reset() //nolint:errcheck 115 log.Ctx(ctx).Error().Err(err).Msgf("%s: failed to write request to peer %s", reflect.TypeOf(request), targetPeerID) 116 return 117 } 118 } 119 } 120 121 // Compile-time interface check: 122 var _ compute.Callback = (*CallbackProxy)(nil)