github.com/filecoin-project/bacalhau@v0.3.23-0.20230228154132-45c989550ace/pkg/transport/bprotocol/compute_proxy.go (about) 1 package bprotocol 2 3 import ( 4 "context" 5 "encoding/json" 6 "fmt" 7 "reflect" 8 9 "github.com/filecoin-project/bacalhau/pkg/compute" 10 "github.com/libp2p/go-libp2p/core/host" 11 "github.com/libp2p/go-libp2p/core/peer" 12 "github.com/libp2p/go-libp2p/core/protocol" 13 ) 14 15 type ComputeProxyParams struct { 16 Host host.Host 17 LocalEndpoint compute.Endpoint // optional in case this host is also a compute node and to allow local calls 18 19 } 20 21 // ComputeProxy is a proxy to a compute node endpoint that will forward requests to remote compute nodes, or 22 // to a local compute node if the target peer ID is the same as the local host, and a LocalEndpoint implementation 23 // is provided. 24 type ComputeProxy struct { 25 host host.Host 26 localEndpoint compute.Endpoint 27 } 28 29 func NewComputeProxy(params ComputeProxyParams) *ComputeProxy { 30 proxy := &ComputeProxy{ 31 host: params.Host, 32 localEndpoint: params.LocalEndpoint, 33 } 34 return proxy 35 } 36 37 func (p *ComputeProxy) RegisterLocalComputeEndpoint(endpoint compute.Endpoint) { 38 p.localEndpoint = endpoint 39 } 40 41 func (p *ComputeProxy) AskForBid(ctx context.Context, request compute.AskForBidRequest) (compute.AskForBidResponse, error) { 42 if request.TargetPeerID == p.host.ID().String() { 43 if p.localEndpoint == nil { 44 return compute.AskForBidResponse{}, fmt.Errorf("unable to dial to self, unless a local compute endpoint is provided") 45 } 46 return p.localEndpoint.AskForBid(ctx, request) 47 } 48 return proxyRequest[compute.AskForBidRequest, compute.AskForBidResponse]( 49 ctx, p.host, request.TargetPeerID, AskForBidProtocolID, request) 50 } 51 52 func (p *ComputeProxy) BidAccepted(ctx context.Context, request compute.BidAcceptedRequest) (compute.BidAcceptedResponse, error) { 53 if request.TargetPeerID == p.host.ID().String() { 54 if p.localEndpoint == nil { 55 return compute.BidAcceptedResponse{}, fmt.Errorf("unable to dial to self, unless a local compute endpoint is provided") 56 } 57 return p.localEndpoint.BidAccepted(ctx, request) 58 } 59 return proxyRequest[compute.BidAcceptedRequest, compute.BidAcceptedResponse]( 60 ctx, p.host, request.TargetPeerID, BidAcceptedProtocolID, request) 61 } 62 63 func (p *ComputeProxy) BidRejected(ctx context.Context, request compute.BidRejectedRequest) (compute.BidRejectedResponse, error) { 64 if request.TargetPeerID == p.host.ID().String() { 65 if p.localEndpoint == nil { 66 return compute.BidRejectedResponse{}, fmt.Errorf("unable to dial to self, unless a local compute endpoint is provided") 67 } 68 return p.localEndpoint.BidRejected(ctx, request) 69 } 70 return proxyRequest[compute.BidRejectedRequest, compute.BidRejectedResponse]( 71 ctx, p.host, request.TargetPeerID, BidRejectedProtocolID, request) 72 } 73 74 func (p *ComputeProxy) ResultAccepted(ctx context.Context, request compute.ResultAcceptedRequest) (compute.ResultAcceptedResponse, error) { 75 if request.TargetPeerID == p.host.ID().String() { 76 if p.localEndpoint == nil { 77 return compute.ResultAcceptedResponse{}, fmt.Errorf("unable to dial to self, unless a local compute endpoint is provided") 78 } 79 return p.localEndpoint.ResultAccepted(ctx, request) 80 } 81 return proxyRequest[compute.ResultAcceptedRequest, compute.ResultAcceptedResponse]( 82 ctx, p.host, request.TargetPeerID, ResultAcceptedProtocolID, request) 83 } 84 85 func (p *ComputeProxy) ResultRejected(ctx context.Context, request compute.ResultRejectedRequest) (compute.ResultRejectedResponse, error) { 86 if request.TargetPeerID == p.host.ID().String() { 87 if p.localEndpoint == nil { 88 return compute.ResultRejectedResponse{}, fmt.Errorf("unable to dial to self, unless a local compute endpoint is provided") 89 } 90 return p.localEndpoint.ResultRejected(ctx, request) 91 } 92 return proxyRequest[compute.ResultRejectedRequest, compute.ResultRejectedResponse]( 93 ctx, p.host, request.TargetPeerID, ResultRejectedProtocolID, request) 94 } 95 96 func (p *ComputeProxy) CancelExecution( 97 ctx context.Context, request compute.CancelExecutionRequest) (compute.CancelExecutionResponse, error) { 98 if request.TargetPeerID == p.host.ID().String() { 99 if p.localEndpoint == nil { 100 return compute.CancelExecutionResponse{}, fmt.Errorf("unable to dial to self, unless a local compute endpoint is provided") 101 } 102 return p.localEndpoint.CancelExecution(ctx, request) 103 } 104 return proxyRequest[compute.CancelExecutionRequest, compute.CancelExecutionResponse]( 105 ctx, p.host, request.TargetPeerID, CancelProtocolID, request) 106 } 107 108 func proxyRequest[Request any, Response any]( 109 ctx context.Context, 110 h host.Host, 111 destPeerID string, 112 protocolID protocol.ID, 113 request Request) (Response, error) { 114 // response object 115 response := new(Response) 116 117 // decode the destination peer ID string value 118 peerID, err := peer.Decode(destPeerID) 119 if err != nil { 120 return *response, fmt.Errorf("%s: failed to decode peer ID %s: %w", reflect.TypeOf(request), destPeerID, err) 121 } 122 123 // deserialize the request object 124 data, err := json.Marshal(request) 125 if err != nil { 126 return *response, fmt.Errorf("%s: failed to marshal request: %w", reflect.TypeOf(request), err) 127 } 128 129 // opening a stream to the destination peer 130 stream, err := h.NewStream(ctx, peerID, protocolID) 131 if err != nil { 132 return *response, fmt.Errorf("%s: failed to open stream to peer %s: %w", reflect.TypeOf(request), destPeerID, err) 133 } 134 defer stream.Close() //nolint:errcheck 135 if scopingErr := stream.Scope().SetService(ComputeServiceName); scopingErr != nil { 136 stream.Reset() //nolint:errcheck 137 return *response, fmt.Errorf("%s: failed to attach stream to compute service: %w", reflect.TypeOf(request), scopingErr) 138 } 139 140 // write the request to the stream 141 _, err = stream.Write(data) 142 if err != nil { 143 stream.Reset() //nolint:errcheck 144 return *response, fmt.Errorf("%s: failed to write request to peer %s: %w", reflect.TypeOf(request), destPeerID, err) 145 } 146 147 // Now we read the response that was sent from the dest peer 148 err = json.NewDecoder(stream).Decode(response) 149 if err != nil { 150 stream.Reset() //nolint:errcheck 151 return *response, fmt.Errorf("%s: failed to decode response from peer %s: %w", reflect.TypeOf(request), destPeerID, err) 152 } 153 154 return *response, nil 155 } 156 157 // Compile-time interface check: 158 var _ compute.Endpoint = (*ComputeProxy)(nil)