github.com/filecoin-project/bacalhau@v0.3.23-0.20230228154132-45c989550ace/pkg/transport/simulator/compute_proxy.go (about)

     1  package simulator
     2  
     3  import (
     4  	"context"
     5  	"encoding/json"
     6  	"fmt"
     7  	"reflect"
     8  
     9  	"github.com/filecoin-project/bacalhau/pkg/compute"
    10  	"github.com/filecoin-project/bacalhau/pkg/transport/bprotocol"
    11  	"github.com/libp2p/go-libp2p/core/host"
    12  	"github.com/libp2p/go-libp2p/core/peer"
    13  	"github.com/libp2p/go-libp2p/core/protocol"
    14  	"github.com/rs/zerolog/log"
    15  )
    16  
    17  type ComputeProxyParams struct {
    18  	SimulatorNodeID string
    19  	Host            host.Host
    20  	LocalEndpoint   compute.Endpoint // optional in case this host is also a compute node and to allow local calls
    21  
    22  }
    23  
    24  // ComputeProxy is a proxy to a compute node endpoint that will forward requests to remote compute nodes, or
    25  // to a local compute node if the target peer ID is the same as the local host, and a LocalEndpoint implementation
    26  // is provided.
    27  type ComputeProxy struct {
    28  	simulatorNodeID string
    29  	host            host.Host
    30  	localEndpoint   compute.Endpoint
    31  }
    32  
    33  func NewComputeProxy(params ComputeProxyParams) *ComputeProxy {
    34  	proxy := &ComputeProxy{
    35  		simulatorNodeID: params.SimulatorNodeID,
    36  		host:            params.Host,
    37  		localEndpoint:   params.LocalEndpoint,
    38  	}
    39  	log.Info().Msgf("ComputeProxy created with simulator node %s", params.SimulatorNodeID)
    40  	return proxy
    41  }
    42  
    43  func (p *ComputeProxy) RegisterLocalComputeEndpoint(endpoint compute.Endpoint) {
    44  	p.localEndpoint = endpoint
    45  }
    46  
    47  func (p *ComputeProxy) AskForBid(ctx context.Context, request compute.AskForBidRequest) (compute.AskForBidResponse, error) {
    48  	if p.simulatorNodeID == p.host.ID().String() {
    49  		if p.localEndpoint == nil {
    50  			return compute.AskForBidResponse{}, fmt.Errorf("unable to dial to self, unless a local compute endpoint is provided")
    51  		}
    52  		return p.localEndpoint.AskForBid(ctx, request)
    53  	}
    54  	return proxyRequest[compute.AskForBidRequest, compute.AskForBidResponse](
    55  		ctx, p.host, p.simulatorNodeID, bprotocol.AskForBidProtocolID, request)
    56  }
    57  
    58  func (p *ComputeProxy) BidAccepted(ctx context.Context, request compute.BidAcceptedRequest) (compute.BidAcceptedResponse, error) {
    59  	if p.simulatorNodeID == p.host.ID().String() {
    60  		if p.localEndpoint == nil {
    61  			return compute.BidAcceptedResponse{}, fmt.Errorf("unable to dial to self, unless a local compute endpoint is provided")
    62  		}
    63  		return p.localEndpoint.BidAccepted(ctx, request)
    64  	}
    65  	return proxyRequest[compute.BidAcceptedRequest, compute.BidAcceptedResponse](
    66  		ctx, p.host, p.simulatorNodeID, bprotocol.BidAcceptedProtocolID, request)
    67  }
    68  
    69  func (p *ComputeProxy) BidRejected(ctx context.Context, request compute.BidRejectedRequest) (compute.BidRejectedResponse, error) {
    70  	if p.simulatorNodeID == p.host.ID().String() {
    71  		if p.localEndpoint == nil {
    72  			return compute.BidRejectedResponse{}, fmt.Errorf("unable to dial to self, unless a local compute endpoint is provided")
    73  		}
    74  		return p.localEndpoint.BidRejected(ctx, request)
    75  	}
    76  	return proxyRequest[compute.BidRejectedRequest, compute.BidRejectedResponse](
    77  		ctx, p.host, p.simulatorNodeID, bprotocol.BidRejectedProtocolID, request)
    78  }
    79  
    80  func (p *ComputeProxy) ResultAccepted(ctx context.Context, request compute.ResultAcceptedRequest) (compute.ResultAcceptedResponse, error) {
    81  	if p.simulatorNodeID == p.host.ID().String() {
    82  		if p.localEndpoint == nil {
    83  			return compute.ResultAcceptedResponse{}, fmt.Errorf("unable to dial to self, unless a local compute endpoint is provided")
    84  		}
    85  		return p.localEndpoint.ResultAccepted(ctx, request)
    86  	}
    87  	return proxyRequest[compute.ResultAcceptedRequest, compute.ResultAcceptedResponse](
    88  		ctx, p.host, p.simulatorNodeID, bprotocol.ResultAcceptedProtocolID, request)
    89  }
    90  
    91  func (p *ComputeProxy) ResultRejected(ctx context.Context, request compute.ResultRejectedRequest) (compute.ResultRejectedResponse, error) {
    92  	if p.simulatorNodeID == p.host.ID().String() {
    93  		if p.localEndpoint == nil {
    94  			return compute.ResultRejectedResponse{}, fmt.Errorf("unable to dial to self, unless a local compute endpoint is provided")
    95  		}
    96  		return p.localEndpoint.ResultRejected(ctx, request)
    97  	}
    98  	return proxyRequest[compute.ResultRejectedRequest, compute.ResultRejectedResponse](
    99  		ctx, p.host, p.simulatorNodeID, bprotocol.ResultRejectedProtocolID, request)
   100  }
   101  
   102  func (p *ComputeProxy) CancelExecution(
   103  	ctx context.Context, request compute.CancelExecutionRequest) (compute.CancelExecutionResponse, error) {
   104  	if p.simulatorNodeID == p.host.ID().String() {
   105  		if p.localEndpoint == nil {
   106  			return compute.CancelExecutionResponse{}, fmt.Errorf("unable to dial to self, unless a local compute endpoint is provided")
   107  		}
   108  		return p.localEndpoint.CancelExecution(ctx, request)
   109  	}
   110  	return proxyRequest[compute.CancelExecutionRequest, compute.CancelExecutionResponse](
   111  		ctx, p.host, p.simulatorNodeID, bprotocol.CancelProtocolID, request)
   112  }
   113  
   114  func proxyRequest[Request any, Response any](
   115  	ctx context.Context,
   116  	h host.Host,
   117  	destPeerID string,
   118  	protocolID protocol.ID,
   119  	request Request) (Response, error) {
   120  	log.Ctx(ctx).Info().Msgf("Forwarding request %+v to %s", request, destPeerID)
   121  	// response object
   122  	response := new(Response)
   123  
   124  	// decode the destination peer ID string value
   125  	peerID, err := peer.Decode(destPeerID)
   126  	if err != nil {
   127  		return *response, fmt.Errorf("%s: failed to decode peer ID %s: %w", reflect.TypeOf(request), destPeerID, err)
   128  	}
   129  
   130  	// deserialize the request object
   131  	data, err := json.Marshal(request)
   132  	if err != nil {
   133  		return *response, fmt.Errorf("%s: failed to marshal request: %w", reflect.TypeOf(request), err)
   134  	}
   135  
   136  	// opening a stream to the destination peer
   137  	stream, err := h.NewStream(ctx, peerID, protocolID)
   138  	if err != nil {
   139  		return *response, fmt.Errorf("%s: failed to open stream to peer %s: %w", reflect.TypeOf(request), destPeerID, err)
   140  	}
   141  	defer stream.Close() //nolint:errcheck
   142  
   143  	// write the request to the stream
   144  	_, err = stream.Write(data)
   145  	if err != nil {
   146  		stream.Reset() //nolint:errcheck
   147  		return *response, fmt.Errorf("%s: failed to write request to peer %s: %w", reflect.TypeOf(request), destPeerID, err)
   148  	}
   149  
   150  	// Now we read the response that was sent from the dest peer
   151  	err = json.NewDecoder(stream).Decode(response)
   152  	if err != nil {
   153  		stream.Reset() //nolint:errcheck
   154  		return *response, fmt.Errorf("%s: failed to decode response from peer %s: %w", reflect.TypeOf(request), destPeerID, err)
   155  	}
   156  
   157  	return *response, nil
   158  }
   159  
   160  // Compile-time interface check:
   161  var _ compute.Endpoint = (*ComputeProxy)(nil)