github.com/filecoin-project/bacalhau@v0.3.23-0.20230228154132-45c989550ace/pkg/transport/bprotocol/compute_proxy.go (about)

     1  package bprotocol
     2  
     3  import (
     4  	"context"
     5  	"encoding/json"
     6  	"fmt"
     7  	"reflect"
     8  
     9  	"github.com/filecoin-project/bacalhau/pkg/compute"
    10  	"github.com/libp2p/go-libp2p/core/host"
    11  	"github.com/libp2p/go-libp2p/core/peer"
    12  	"github.com/libp2p/go-libp2p/core/protocol"
    13  )
    14  
    15  type ComputeProxyParams struct {
    16  	Host          host.Host
    17  	LocalEndpoint compute.Endpoint // optional in case this host is also a compute node and to allow local calls
    18  
    19  }
    20  
    21  // ComputeProxy is a proxy to a compute node endpoint that will forward requests to remote compute nodes, or
    22  // to a local compute node if the target peer ID is the same as the local host, and a LocalEndpoint implementation
    23  // is provided.
    24  type ComputeProxy struct {
    25  	host          host.Host
    26  	localEndpoint compute.Endpoint
    27  }
    28  
    29  func NewComputeProxy(params ComputeProxyParams) *ComputeProxy {
    30  	proxy := &ComputeProxy{
    31  		host:          params.Host,
    32  		localEndpoint: params.LocalEndpoint,
    33  	}
    34  	return proxy
    35  }
    36  
// RegisterLocalComputeEndpoint sets the endpoint that serves requests whose
// target peer ID equals the local host ID, replacing any endpoint set earlier.
// NOTE(review): the field is assigned without synchronization; presumably this
// is called during setup before requests are served — confirm against callers.
func (p *ComputeProxy) RegisterLocalComputeEndpoint(endpoint compute.Endpoint) {
	p.localEndpoint = endpoint
}
    40  
    41  func (p *ComputeProxy) AskForBid(ctx context.Context, request compute.AskForBidRequest) (compute.AskForBidResponse, error) {
    42  	if request.TargetPeerID == p.host.ID().String() {
    43  		if p.localEndpoint == nil {
    44  			return compute.AskForBidResponse{}, fmt.Errorf("unable to dial to self, unless a local compute endpoint is provided")
    45  		}
    46  		return p.localEndpoint.AskForBid(ctx, request)
    47  	}
    48  	return proxyRequest[compute.AskForBidRequest, compute.AskForBidResponse](
    49  		ctx, p.host, request.TargetPeerID, AskForBidProtocolID, request)
    50  }
    51  
    52  func (p *ComputeProxy) BidAccepted(ctx context.Context, request compute.BidAcceptedRequest) (compute.BidAcceptedResponse, error) {
    53  	if request.TargetPeerID == p.host.ID().String() {
    54  		if p.localEndpoint == nil {
    55  			return compute.BidAcceptedResponse{}, fmt.Errorf("unable to dial to self, unless a local compute endpoint is provided")
    56  		}
    57  		return p.localEndpoint.BidAccepted(ctx, request)
    58  	}
    59  	return proxyRequest[compute.BidAcceptedRequest, compute.BidAcceptedResponse](
    60  		ctx, p.host, request.TargetPeerID, BidAcceptedProtocolID, request)
    61  }
    62  
    63  func (p *ComputeProxy) BidRejected(ctx context.Context, request compute.BidRejectedRequest) (compute.BidRejectedResponse, error) {
    64  	if request.TargetPeerID == p.host.ID().String() {
    65  		if p.localEndpoint == nil {
    66  			return compute.BidRejectedResponse{}, fmt.Errorf("unable to dial to self, unless a local compute endpoint is provided")
    67  		}
    68  		return p.localEndpoint.BidRejected(ctx, request)
    69  	}
    70  	return proxyRequest[compute.BidRejectedRequest, compute.BidRejectedResponse](
    71  		ctx, p.host, request.TargetPeerID, BidRejectedProtocolID, request)
    72  }
    73  
    74  func (p *ComputeProxy) ResultAccepted(ctx context.Context, request compute.ResultAcceptedRequest) (compute.ResultAcceptedResponse, error) {
    75  	if request.TargetPeerID == p.host.ID().String() {
    76  		if p.localEndpoint == nil {
    77  			return compute.ResultAcceptedResponse{}, fmt.Errorf("unable to dial to self, unless a local compute endpoint is provided")
    78  		}
    79  		return p.localEndpoint.ResultAccepted(ctx, request)
    80  	}
    81  	return proxyRequest[compute.ResultAcceptedRequest, compute.ResultAcceptedResponse](
    82  		ctx, p.host, request.TargetPeerID, ResultAcceptedProtocolID, request)
    83  }
    84  
    85  func (p *ComputeProxy) ResultRejected(ctx context.Context, request compute.ResultRejectedRequest) (compute.ResultRejectedResponse, error) {
    86  	if request.TargetPeerID == p.host.ID().String() {
    87  		if p.localEndpoint == nil {
    88  			return compute.ResultRejectedResponse{}, fmt.Errorf("unable to dial to self, unless a local compute endpoint is provided")
    89  		}
    90  		return p.localEndpoint.ResultRejected(ctx, request)
    91  	}
    92  	return proxyRequest[compute.ResultRejectedRequest, compute.ResultRejectedResponse](
    93  		ctx, p.host, request.TargetPeerID, ResultRejectedProtocolID, request)
    94  }
    95  
    96  func (p *ComputeProxy) CancelExecution(
    97  	ctx context.Context, request compute.CancelExecutionRequest) (compute.CancelExecutionResponse, error) {
    98  	if request.TargetPeerID == p.host.ID().String() {
    99  		if p.localEndpoint == nil {
   100  			return compute.CancelExecutionResponse{}, fmt.Errorf("unable to dial to self, unless a local compute endpoint is provided")
   101  		}
   102  		return p.localEndpoint.CancelExecution(ctx, request)
   103  	}
   104  	return proxyRequest[compute.CancelExecutionRequest, compute.CancelExecutionResponse](
   105  		ctx, p.host, request.TargetPeerID, CancelProtocolID, request)
   106  }
   107  
   108  func proxyRequest[Request any, Response any](
   109  	ctx context.Context,
   110  	h host.Host,
   111  	destPeerID string,
   112  	protocolID protocol.ID,
   113  	request Request) (Response, error) {
   114  	// response object
   115  	response := new(Response)
   116  
   117  	// decode the destination peer ID string value
   118  	peerID, err := peer.Decode(destPeerID)
   119  	if err != nil {
   120  		return *response, fmt.Errorf("%s: failed to decode peer ID %s: %w", reflect.TypeOf(request), destPeerID, err)
   121  	}
   122  
   123  	// deserialize the request object
   124  	data, err := json.Marshal(request)
   125  	if err != nil {
   126  		return *response, fmt.Errorf("%s: failed to marshal request: %w", reflect.TypeOf(request), err)
   127  	}
   128  
   129  	// opening a stream to the destination peer
   130  	stream, err := h.NewStream(ctx, peerID, protocolID)
   131  	if err != nil {
   132  		return *response, fmt.Errorf("%s: failed to open stream to peer %s: %w", reflect.TypeOf(request), destPeerID, err)
   133  	}
   134  	defer stream.Close() //nolint:errcheck
   135  	if scopingErr := stream.Scope().SetService(ComputeServiceName); scopingErr != nil {
   136  		stream.Reset() //nolint:errcheck
   137  		return *response, fmt.Errorf("%s: failed to attach stream to compute service: %w", reflect.TypeOf(request), scopingErr)
   138  	}
   139  
   140  	// write the request to the stream
   141  	_, err = stream.Write(data)
   142  	if err != nil {
   143  		stream.Reset() //nolint:errcheck
   144  		return *response, fmt.Errorf("%s: failed to write request to peer %s: %w", reflect.TypeOf(request), destPeerID, err)
   145  	}
   146  
   147  	// Now we read the response that was sent from the dest peer
   148  	err = json.NewDecoder(stream).Decode(response)
   149  	if err != nil {
   150  		stream.Reset() //nolint:errcheck
   151  		return *response, fmt.Errorf("%s: failed to decode response from peer %s: %w", reflect.TypeOf(request), destPeerID, err)
   152  	}
   153  
   154  	return *response, nil
   155  }
   156  
// Compile-time interface check: the build fails if *ComputeProxy stops
// satisfying compute.Endpoint.
var _ compute.Endpoint = (*ComputeProxy)(nil)