github.com/filecoin-project/bacalhau@v0.3.23-0.20230228154132-45c989550ace/pkg/transport/simulator/compute_proxy.go (about) 1 package simulator 2 3 import ( 4 "context" 5 "encoding/json" 6 "fmt" 7 "reflect" 8 9 "github.com/filecoin-project/bacalhau/pkg/compute" 10 "github.com/filecoin-project/bacalhau/pkg/transport/bprotocol" 11 "github.com/libp2p/go-libp2p/core/host" 12 "github.com/libp2p/go-libp2p/core/peer" 13 "github.com/libp2p/go-libp2p/core/protocol" 14 "github.com/rs/zerolog/log" 15 ) 16 17 type ComputeProxyParams struct { 18 SimulatorNodeID string 19 Host host.Host 20 LocalEndpoint compute.Endpoint // optional in case this host is also a compute node and to allow local calls 21 22 } 23 24 // ComputeProxy is a proxy to a compute node endpoint that will forward requests to remote compute nodes, or 25 // to a local compute node if the target peer ID is the same as the local host, and a LocalEndpoint implementation 26 // is provided. 27 type ComputeProxy struct { 28 simulatorNodeID string 29 host host.Host 30 localEndpoint compute.Endpoint 31 } 32 33 func NewComputeProxy(params ComputeProxyParams) *ComputeProxy { 34 proxy := &ComputeProxy{ 35 simulatorNodeID: params.SimulatorNodeID, 36 host: params.Host, 37 localEndpoint: params.LocalEndpoint, 38 } 39 log.Info().Msgf("ComputeProxy created with simulator node %s", params.SimulatorNodeID) 40 return proxy 41 } 42 43 func (p *ComputeProxy) RegisterLocalComputeEndpoint(endpoint compute.Endpoint) { 44 p.localEndpoint = endpoint 45 } 46 47 func (p *ComputeProxy) AskForBid(ctx context.Context, request compute.AskForBidRequest) (compute.AskForBidResponse, error) { 48 if p.simulatorNodeID == p.host.ID().String() { 49 if p.localEndpoint == nil { 50 return compute.AskForBidResponse{}, fmt.Errorf("unable to dial to self, unless a local compute endpoint is provided") 51 } 52 return p.localEndpoint.AskForBid(ctx, request) 53 } 54 return proxyRequest[compute.AskForBidRequest, compute.AskForBidResponse]( 55 ctx, p.host, p.simulatorNodeID, bprotocol.AskForBidProtocolID, request) 56 } 57 58 func (p *ComputeProxy) BidAccepted(ctx context.Context, request compute.BidAcceptedRequest) (compute.BidAcceptedResponse, error) { 59 if p.simulatorNodeID == p.host.ID().String() { 60 if p.localEndpoint == nil { 61 return compute.BidAcceptedResponse{}, fmt.Errorf("unable to dial to self, unless a local compute endpoint is provided") 62 } 63 return p.localEndpoint.BidAccepted(ctx, request) 64 } 65 return proxyRequest[compute.BidAcceptedRequest, compute.BidAcceptedResponse]( 66 ctx, p.host, p.simulatorNodeID, bprotocol.BidAcceptedProtocolID, request) 67 } 68 69 func (p *ComputeProxy) BidRejected(ctx context.Context, request compute.BidRejectedRequest) (compute.BidRejectedResponse, error) { 70 if p.simulatorNodeID == p.host.ID().String() { 71 if p.localEndpoint == nil { 72 return compute.BidRejectedResponse{}, fmt.Errorf("unable to dial to self, unless a local compute endpoint is provided") 73 } 74 return p.localEndpoint.BidRejected(ctx, request) 75 } 76 return proxyRequest[compute.BidRejectedRequest, compute.BidRejectedResponse]( 77 ctx, p.host, p.simulatorNodeID, bprotocol.BidRejectedProtocolID, request) 78 } 79 80 func (p *ComputeProxy) ResultAccepted(ctx context.Context, request compute.ResultAcceptedRequest) (compute.ResultAcceptedResponse, error) { 81 if p.simulatorNodeID == p.host.ID().String() { 82 if p.localEndpoint == nil { 83 return compute.ResultAcceptedResponse{}, fmt.Errorf("unable to dial to self, unless a local compute endpoint is provided") 84 } 85 return p.localEndpoint.ResultAccepted(ctx, request) 86 } 87 return proxyRequest[compute.ResultAcceptedRequest, compute.ResultAcceptedResponse]( 88 ctx, p.host, p.simulatorNodeID, bprotocol.ResultAcceptedProtocolID, request) 89 } 90 91 func (p *ComputeProxy) ResultRejected(ctx context.Context, request compute.ResultRejectedRequest) (compute.ResultRejectedResponse, error) { 92 if p.simulatorNodeID == p.host.ID().String() { 93 if p.localEndpoint == nil { 94 return compute.ResultRejectedResponse{}, fmt.Errorf("unable to dial to self, unless a local compute endpoint is provided") 95 } 96 return p.localEndpoint.ResultRejected(ctx, request) 97 } 98 return proxyRequest[compute.ResultRejectedRequest, compute.ResultRejectedResponse]( 99 ctx, p.host, p.simulatorNodeID, bprotocol.ResultRejectedProtocolID, request) 100 } 101 102 func (p *ComputeProxy) CancelExecution( 103 ctx context.Context, request compute.CancelExecutionRequest) (compute.CancelExecutionResponse, error) { 104 if p.simulatorNodeID == p.host.ID().String() { 105 if p.localEndpoint == nil { 106 return compute.CancelExecutionResponse{}, fmt.Errorf("unable to dial to self, unless a local compute endpoint is provided") 107 } 108 return p.localEndpoint.CancelExecution(ctx, request) 109 } 110 return proxyRequest[compute.CancelExecutionRequest, compute.CancelExecutionResponse]( 111 ctx, p.host, p.simulatorNodeID, bprotocol.CancelProtocolID, request) 112 } 113 114 func proxyRequest[Request any, Response any]( 115 ctx context.Context, 116 h host.Host, 117 destPeerID string, 118 protocolID protocol.ID, 119 request Request) (Response, error) { 120 log.Ctx(ctx).Info().Msgf("Forwarding request %+v to %s", request, destPeerID) 121 // response object 122 response := new(Response) 123 124 // decode the destination peer ID string value 125 peerID, err := peer.Decode(destPeerID) 126 if err != nil { 127 return *response, fmt.Errorf("%s: failed to decode peer ID %s: %w", reflect.TypeOf(request), destPeerID, err) 128 } 129 130 // deserialize the request object 131 data, err := json.Marshal(request) 132 if err != nil { 133 return *response, fmt.Errorf("%s: failed to marshal request: %w", reflect.TypeOf(request), err) 134 } 135 136 // opening a stream to the destination peer 137 stream, err := h.NewStream(ctx, peerID, protocolID) 138 if err != nil { 139 return *response, fmt.Errorf("%s: failed to open stream to peer %s: %w", reflect.TypeOf(request), destPeerID, err) 140 } 141 defer stream.Close() //nolint:errcheck 142 143 // write the request to the stream 144 _, err = stream.Write(data) 145 if err != nil { 146 stream.Reset() //nolint:errcheck 147 return *response, fmt.Errorf("%s: failed to write request to peer %s: %w", reflect.TypeOf(request), destPeerID, err) 148 } 149 150 // Now we read the response that was sent from the dest peer 151 err = json.NewDecoder(stream).Decode(response) 152 if err != nil { 153 stream.Reset() //nolint:errcheck 154 return *response, fmt.Errorf("%s: failed to decode response from peer %s: %w", reflect.TypeOf(request), destPeerID, err) 155 } 156 157 return *response, nil 158 } 159 160 // Compile-time interface check: 161 var _ compute.Endpoint = (*ComputeProxy)(nil)