github.com/0xPolygon/supernets2-node@v0.0.0-20230711153321-2fe574524eaa/aggregator/prover/prover.go (about)

     1  package prover
     2  
     3  import (
     4  	"context"
     5  	"errors"
     6  	"fmt"
     7  	"net"
     8  	"time"
     9  
    10  	"github.com/0xPolygon/supernets2-node/aggregator/metrics"
    11  	"github.com/0xPolygon/supernets2-node/aggregator/pb"
    12  	"github.com/0xPolygon/supernets2-node/config/types"
    13  	"github.com/0xPolygon/supernets2-node/log"
    14  )
    15  
    16  var (
    17  	ErrBadProverResponse    = errors.New("Prover returned wrong type for response")  //nolint:revive
    18  	ErrProverInternalError  = errors.New("Prover returned INTERNAL_ERROR response")  //nolint:revive
    19  	ErrProverCompletedError = errors.New("Prover returned COMPLETED_ERROR response") //nolint:revive
    20  	ErrBadRequest           = errors.New("Prover returned ERROR for a bad request")  //nolint:revive
    21  	ErrUnspecified          = errors.New("Prover returned an UNSPECIFIED response")  //nolint:revive
    22  	ErrUnknown              = errors.New("Prover returned an unknown response")      //nolint:revive
    23  	ErrProofCanceled        = errors.New("Proof has been canceled")                  //nolint:revive
    24  )
    25  
    26  // Prover abstraction of the grpc prover client.
    27  type Prover struct {
    28  	name                      string
    29  	id                        string
    30  	address                   net.Addr
    31  	proofStatePollingInterval types.Duration
    32  	stream                    pb.AggregatorService_ChannelServer
    33  }
    34  
    35  // New returns a new Prover instance.
    36  func New(stream pb.AggregatorService_ChannelServer, addr net.Addr, proofStatePollingInterval types.Duration) (*Prover, error) {
    37  	p := &Prover{
    38  		stream:                    stream,
    39  		address:                   addr,
    40  		proofStatePollingInterval: proofStatePollingInterval,
    41  	}
    42  	status, err := p.Status()
    43  	if err != nil {
    44  		return nil, fmt.Errorf("Failed to retrieve prover id %w", err)
    45  	}
    46  	p.name = status.ProverName
    47  	p.id = status.ProverId
    48  	return p, nil
    49  }
    50  
    51  // Name returns the Prover name.
    52  func (p *Prover) Name() string { return p.name }
    53  
    54  // ID returns the Prover ID.
    55  func (p *Prover) ID() string { return p.id }
    56  
    57  // Addr returns the prover IP address.
    58  func (p *Prover) Addr() string {
    59  	if p.address == nil {
    60  		return ""
    61  	}
    62  	return p.address.String()
    63  }
    64  
    65  // Status gets the prover status.
    66  func (p *Prover) Status() (*pb.GetStatusResponse, error) {
    67  	req := &pb.AggregatorMessage{
    68  		Request: &pb.AggregatorMessage_GetStatusRequest{
    69  			GetStatusRequest: &pb.GetStatusRequest{},
    70  		},
    71  	}
    72  	res, err := p.call(req)
    73  	if err != nil {
    74  		return nil, err
    75  	}
    76  	if msg, ok := res.Response.(*pb.ProverMessage_GetStatusResponse); ok {
    77  		return msg.GetStatusResponse, nil
    78  	}
    79  	return nil, fmt.Errorf("%w, wanted %T, got %T", ErrBadProverResponse, &pb.ProverMessage_GetStatusResponse{}, res.Response)
    80  }
    81  
    82  // IsIdle returns true if the prover is idling.
    83  func (p *Prover) IsIdle() (bool, error) {
    84  	status, err := p.Status()
    85  	if err != nil {
    86  		return false, err
    87  	}
    88  	return status.Status == pb.GetStatusResponse_STATUS_IDLE, nil
    89  }
    90  
    91  // SupportsForkID returns true if the prover supports the given fork id.
    92  func (p *Prover) SupportsForkID(forkID uint64) bool {
    93  	status, err := p.Status()
    94  	if err != nil {
    95  		log.Warnf("Error asking status for prover ID %s: %v", p.ID(), err)
    96  		return false
    97  	}
    98  
    99  	log.Debugf("Prover %s supports fork ID %d", p.ID(), status.ForkId)
   100  
   101  	return status.ForkId == forkID
   102  }
   103  
   104  // BatchProof instructs the prover to generate a batch proof for the provided
   105  // input. It returns the ID of the proof being computed.
   106  func (p *Prover) BatchProof(input *pb.InputProver) (*string, error) {
   107  	metrics.WorkingProver()
   108  
   109  	req := &pb.AggregatorMessage{
   110  		Request: &pb.AggregatorMessage_GenBatchProofRequest{
   111  			GenBatchProofRequest: &pb.GenBatchProofRequest{Input: input},
   112  		},
   113  	}
   114  	res, err := p.call(req)
   115  	if err != nil {
   116  		return nil, err
   117  	}
   118  
   119  	if msg, ok := res.Response.(*pb.ProverMessage_GenBatchProofResponse); ok {
   120  		switch msg.GenBatchProofResponse.Result {
   121  		case pb.Result_RESULT_UNSPECIFIED:
   122  			return nil, fmt.Errorf("failed to generate proof %s, %w, input %v", msg.GenBatchProofResponse.String(), ErrUnspecified, input)
   123  		case pb.Result_RESULT_OK:
   124  			return &msg.GenBatchProofResponse.Id, nil
   125  		case pb.Result_RESULT_ERROR:
   126  			return nil, fmt.Errorf("failed to generate proof %s, %w, input %v", msg.GenBatchProofResponse.String(), ErrBadRequest, input)
   127  		case pb.Result_RESULT_INTERNAL_ERROR:
   128  			return nil, fmt.Errorf("failed to generate proof %s, %w, input %v", msg.GenBatchProofResponse.String(), ErrProverInternalError, input)
   129  		default:
   130  			return nil, fmt.Errorf("failed to generate proof %s, %w,input %v", msg.GenBatchProofResponse.String(), ErrUnknown, input)
   131  		}
   132  	}
   133  
   134  	return nil, fmt.Errorf("%w, wanted %T, got %T", ErrBadProverResponse, &pb.ProverMessage_GenBatchProofResponse{}, res.Response)
   135  }
   136  
   137  // AggregatedProof instructs the prover to generate an aggregated proof from
   138  // the two inputs provided. It returns the ID of the proof being computed.
   139  func (p *Prover) AggregatedProof(inputProof1, inputProof2 string) (*string, error) {
   140  	metrics.WorkingProver()
   141  
   142  	req := &pb.AggregatorMessage{
   143  		Request: &pb.AggregatorMessage_GenAggregatedProofRequest{
   144  			GenAggregatedProofRequest: &pb.GenAggregatedProofRequest{
   145  				RecursiveProof_1: inputProof1,
   146  				RecursiveProof_2: inputProof2,
   147  			},
   148  		},
   149  	}
   150  	res, err := p.call(req)
   151  	if err != nil {
   152  		return nil, err
   153  	}
   154  
   155  	if msg, ok := res.Response.(*pb.ProverMessage_GenAggregatedProofResponse); ok {
   156  		switch msg.GenAggregatedProofResponse.Result {
   157  		case pb.Result_RESULT_UNSPECIFIED:
   158  			return nil, fmt.Errorf("failed to aggregate proofs %s, %w, input 1 %s, input 2 %s",
   159  				msg.GenAggregatedProofResponse.String(), ErrUnspecified, inputProof1, inputProof2)
   160  		case pb.Result_RESULT_OK:
   161  			return &msg.GenAggregatedProofResponse.Id, nil
   162  		case pb.Result_RESULT_ERROR:
   163  			return nil, fmt.Errorf("failed to aggregate proofs %s, %w, input 1 %s, input 2 %s",
   164  				msg.GenAggregatedProofResponse.String(), ErrBadRequest, inputProof1, inputProof2)
   165  		case pb.Result_RESULT_INTERNAL_ERROR:
   166  			return nil, fmt.Errorf("failed to aggregate proofs %s, %w, input 1 %s, input 2 %s",
   167  				msg.GenAggregatedProofResponse.String(), ErrProverInternalError, inputProof1, inputProof2)
   168  		default:
   169  			return nil, fmt.Errorf("failed to aggregate proofs %s, %w, input 1 %s, input 2 %s",
   170  				msg.GenAggregatedProofResponse.String(), ErrUnknown, inputProof1, inputProof2)
   171  		}
   172  	}
   173  
   174  	return nil, fmt.Errorf("%w, wanted %T, got %T", ErrBadProverResponse, &pb.ProverMessage_GenAggregatedProofResponse{}, res.Response)
   175  }
   176  
   177  // FinalProof instructs the prover to generate a final proof for the given
   178  // input. It returns the ID of the proof being computed.
   179  func (p *Prover) FinalProof(inputProof string, aggregatorAddr string) (*string, error) {
   180  	metrics.WorkingProver()
   181  
   182  	req := &pb.AggregatorMessage{
   183  		Request: &pb.AggregatorMessage_GenFinalProofRequest{
   184  			GenFinalProofRequest: &pb.GenFinalProofRequest{
   185  				RecursiveProof: inputProof,
   186  				AggregatorAddr: aggregatorAddr,
   187  			},
   188  		},
   189  	}
   190  	res, err := p.call(req)
   191  	if err != nil {
   192  		return nil, err
   193  	}
   194  
   195  	if msg, ok := res.Response.(*pb.ProverMessage_GenFinalProofResponse); ok {
   196  		switch msg.GenFinalProofResponse.Result {
   197  		case pb.Result_RESULT_UNSPECIFIED:
   198  			return nil, fmt.Errorf("failed to generate final proof %s, %w, input %s",
   199  				msg.GenFinalProofResponse.String(), ErrUnspecified, inputProof)
   200  		case pb.Result_RESULT_OK:
   201  			return &msg.GenFinalProofResponse.Id, nil
   202  		case pb.Result_RESULT_ERROR:
   203  			return nil, fmt.Errorf("failed to generate final proof %s, %w, input %s",
   204  				msg.GenFinalProofResponse.String(), ErrBadRequest, inputProof)
   205  		case pb.Result_RESULT_INTERNAL_ERROR:
   206  			return nil, fmt.Errorf("failed to generate final proof %s, %w, input %s",
   207  				msg.GenFinalProofResponse.String(), ErrProverInternalError, inputProof)
   208  		default:
   209  			return nil, fmt.Errorf("failed to generate final proof %s, %w, input %s",
   210  				msg.GenFinalProofResponse.String(), ErrUnknown, inputProof)
   211  		}
   212  	}
   213  	return nil, fmt.Errorf("%w, wanted %T, got %T", ErrBadProverResponse, &pb.ProverMessage_GenFinalProofResponse{}, res.Response)
   214  }
   215  
   216  // CancelProofRequest asks the prover to stop the generation of the proof
   217  // matching the provided proofID.
   218  func (p *Prover) CancelProofRequest(proofID string) error {
   219  	req := &pb.AggregatorMessage{
   220  		Request: &pb.AggregatorMessage_CancelRequest{
   221  			CancelRequest: &pb.CancelRequest{Id: proofID},
   222  		},
   223  	}
   224  	res, err := p.call(req)
   225  	if err != nil {
   226  		return err
   227  	}
   228  	if msg, ok := res.Response.(*pb.ProverMessage_CancelResponse); ok {
   229  		switch msg.CancelResponse.Result {
   230  		case pb.Result_RESULT_UNSPECIFIED:
   231  			return fmt.Errorf("failed to cancel proof id [%s], %w, %s",
   232  				proofID, ErrUnspecified, msg.CancelResponse.String())
   233  		case pb.Result_RESULT_OK:
   234  			return nil
   235  		case pb.Result_RESULT_ERROR:
   236  			return fmt.Errorf("failed to cancel proof id [%s], %w, %s",
   237  				proofID, ErrBadRequest, msg.CancelResponse.String())
   238  		case pb.Result_RESULT_INTERNAL_ERROR:
   239  			return fmt.Errorf("failed to cancel proof id [%s], %w, %s",
   240  				proofID, ErrProverInternalError, msg.CancelResponse.String())
   241  		default:
   242  			return fmt.Errorf("failed to cancel proof id [%s], %w, %s",
   243  				proofID, ErrUnknown, msg.CancelResponse.String())
   244  		}
   245  	}
   246  	return fmt.Errorf("%w, wanted %T, got %T", ErrBadProverResponse, &pb.ProverMessage_CancelResponse{}, res.Response)
   247  }
   248  
   249  // WaitRecursiveProof waits for a recursive proof to be generated by the prover
   250  // and returns it.
   251  func (p *Prover) WaitRecursiveProof(ctx context.Context, proofID string) (string, error) {
   252  	res, err := p.waitProof(ctx, proofID)
   253  	if err != nil {
   254  		return "", err
   255  	}
   256  	resProof := res.Proof.(*pb.GetProofResponse_RecursiveProof)
   257  	return resProof.RecursiveProof, nil
   258  }
   259  
   260  // WaitFinalProof waits for the final proof to be generated by the prover and
   261  // returns it.
   262  func (p *Prover) WaitFinalProof(ctx context.Context, proofID string) (*pb.FinalProof, error) {
   263  	res, err := p.waitProof(ctx, proofID)
   264  	if err != nil {
   265  		return nil, err
   266  	}
   267  	resProof := res.Proof.(*pb.GetProofResponse_FinalProof)
   268  	return resProof.FinalProof, nil
   269  }
   270  
   271  // waitProof waits for a proof to be generated by the prover and returns the
   272  // prover response.
   273  func (p *Prover) waitProof(ctx context.Context, proofID string) (*pb.GetProofResponse, error) {
   274  	defer metrics.IdlingProver()
   275  
   276  	req := &pb.AggregatorMessage{
   277  		Request: &pb.AggregatorMessage_GetProofRequest{
   278  			GetProofRequest: &pb.GetProofRequest{
   279  				// TODO(pg): set Timeout field?
   280  				Id: proofID,
   281  			},
   282  		},
   283  	}
   284  
   285  	for {
   286  		select {
   287  		case <-ctx.Done():
   288  			return nil, ctx.Err()
   289  		default:
   290  			res, err := p.call(req)
   291  			if err != nil {
   292  				return nil, err
   293  			}
   294  			if msg, ok := res.Response.(*pb.ProverMessage_GetProofResponse); ok {
   295  				switch msg.GetProofResponse.Result {
   296  				case pb.GetProofResponse_RESULT_PENDING:
   297  					time.Sleep(p.proofStatePollingInterval.Duration)
   298  					continue
   299  				case pb.GetProofResponse_RESULT_UNSPECIFIED:
   300  					return nil, fmt.Errorf("failed to get proof ID: %s, %w, prover response: %s",
   301  						proofID, ErrUnspecified, msg.GetProofResponse.String())
   302  				case pb.GetProofResponse_RESULT_COMPLETED_OK:
   303  					return msg.GetProofResponse, nil
   304  				case pb.GetProofResponse_RESULT_ERROR:
   305  					return nil, fmt.Errorf("failed to get proof with ID %s, %w, prover response: %s",
   306  						proofID, ErrBadRequest, msg.GetProofResponse.String())
   307  				case pb.GetProofResponse_RESULT_COMPLETED_ERROR:
   308  					return nil, fmt.Errorf("failed to get proof with ID %s, %w, prover response: %s",
   309  						proofID, ErrProverCompletedError, msg.GetProofResponse.String())
   310  				case pb.GetProofResponse_RESULT_INTERNAL_ERROR:
   311  					return nil, fmt.Errorf("failed to get proof ID: %s, %w, prover response: %s",
   312  						proofID, ErrProverInternalError, msg.GetProofResponse.String())
   313  				case pb.GetProofResponse_RESULT_CANCEL:
   314  					return nil, fmt.Errorf("proof generation was cancelled for proof ID %s, %w, prover response: %s",
   315  						proofID, ErrProofCanceled, msg.GetProofResponse.String())
   316  				default:
   317  					return nil, fmt.Errorf("failed to get proof ID: %s, %w, prover response: %s",
   318  						proofID, ErrUnknown, msg.GetProofResponse.String())
   319  				}
   320  			}
   321  			return nil, fmt.Errorf("%w, wanted %T, got %T", ErrBadProverResponse, &pb.ProverMessage_GetProofResponse{}, res.Response)
   322  		}
   323  	}
   324  }
   325  
   326  // call sends a message to the prover and waits to receive the response over
   327  // the connection stream.
   328  func (p *Prover) call(req *pb.AggregatorMessage) (*pb.ProverMessage, error) {
   329  	if err := p.stream.Send(req); err != nil {
   330  		return nil, err
   331  	}
   332  	res, err := p.stream.Recv()
   333  	if err != nil {
   334  		return nil, err
   335  	}
   336  	return res, nil
   337  }