github.com/0xPolygon/supernets2-node@v0.0.0-20230711153321-2fe574524eaa/aggregator/prover/prover.go (about) 1 package prover 2 3 import ( 4 "context" 5 "errors" 6 "fmt" 7 "net" 8 "time" 9 10 "github.com/0xPolygon/supernets2-node/aggregator/metrics" 11 "github.com/0xPolygon/supernets2-node/aggregator/pb" 12 "github.com/0xPolygon/supernets2-node/config/types" 13 "github.com/0xPolygon/supernets2-node/log" 14 ) 15 16 var ( 17 ErrBadProverResponse = errors.New("Prover returned wrong type for response") //nolint:revive 18 ErrProverInternalError = errors.New("Prover returned INTERNAL_ERROR response") //nolint:revive 19 ErrProverCompletedError = errors.New("Prover returned COMPLETED_ERROR response") //nolint:revive 20 ErrBadRequest = errors.New("Prover returned ERROR for a bad request") //nolint:revive 21 ErrUnspecified = errors.New("Prover returned an UNSPECIFIED response") //nolint:revive 22 ErrUnknown = errors.New("Prover returned an unknown response") //nolint:revive 23 ErrProofCanceled = errors.New("Proof has been canceled") //nolint:revive 24 ) 25 26 // Prover abstraction of the grpc prover client. 27 type Prover struct { 28 name string 29 id string 30 address net.Addr 31 proofStatePollingInterval types.Duration 32 stream pb.AggregatorService_ChannelServer 33 } 34 35 // New returns a new Prover instance. 36 func New(stream pb.AggregatorService_ChannelServer, addr net.Addr, proofStatePollingInterval types.Duration) (*Prover, error) { 37 p := &Prover{ 38 stream: stream, 39 address: addr, 40 proofStatePollingInterval: proofStatePollingInterval, 41 } 42 status, err := p.Status() 43 if err != nil { 44 return nil, fmt.Errorf("Failed to retrieve prover id %w", err) 45 } 46 p.name = status.ProverName 47 p.id = status.ProverId 48 return p, nil 49 } 50 51 // Name returns the Prover name. 52 func (p *Prover) Name() string { return p.name } 53 54 // ID returns the Prover ID. 55 func (p *Prover) ID() string { return p.id } 56 57 // Addr returns the prover IP address. 
func (p *Prover) Addr() string {
	if addr := p.address; addr != nil {
		return addr.String()
	}
	return ""
}

// Status sends a GetStatus request to the prover and returns its status
// response. A response of any other type yields a wrapped ErrBadProverResponse.
func (p *Prover) Status() (*pb.GetStatusResponse, error) {
	res, err := p.call(&pb.AggregatorMessage{
		Request: &pb.AggregatorMessage_GetStatusRequest{
			GetStatusRequest: &pb.GetStatusRequest{},
		},
	})
	if err != nil {
		return nil, err
	}

	msg, ok := res.Response.(*pb.ProverMessage_GetStatusResponse)
	if !ok {
		return nil, fmt.Errorf("%w, wanted %T, got %T", ErrBadProverResponse, &pb.ProverMessage_GetStatusResponse{}, res.Response)
	}
	return msg.GetStatusResponse, nil
}

// IsIdle reports whether the prover's current status is STATUS_IDLE.
// It returns the error from Status if the status cannot be fetched.
func (p *Prover) IsIdle() (bool, error) {
	status, err := p.Status()
	if err != nil {
		return false, err
	}
	idle := status.Status == pb.GetStatusResponse_STATUS_IDLE
	return idle, nil
}

// SupportsForkID reports whether the fork ID announced in the prover's status
// equals forkID. If the status cannot be fetched, it logs a warning and
// reports false.
func (p *Prover) SupportsForkID(forkID uint64) bool {
	status, err := p.Status()
	if err != nil {
		log.Warnf("Error asking status for prover ID %s: %v", p.ID(), err)
		return false
	}

	proverForkID := status.ForkId
	log.Debugf("Prover %s supports fork ID %d", p.ID(), proverForkID)
	return proverForkID == forkID
}

// BatchProof instructs the prover to generate a batch proof for the provided
// input. It returns the ID of the proof being computed.
106 func (p *Prover) BatchProof(input *pb.InputProver) (*string, error) { 107 metrics.WorkingProver() 108 109 req := &pb.AggregatorMessage{ 110 Request: &pb.AggregatorMessage_GenBatchProofRequest{ 111 GenBatchProofRequest: &pb.GenBatchProofRequest{Input: input}, 112 }, 113 } 114 res, err := p.call(req) 115 if err != nil { 116 return nil, err 117 } 118 119 if msg, ok := res.Response.(*pb.ProverMessage_GenBatchProofResponse); ok { 120 switch msg.GenBatchProofResponse.Result { 121 case pb.Result_RESULT_UNSPECIFIED: 122 return nil, fmt.Errorf("failed to generate proof %s, %w, input %v", msg.GenBatchProofResponse.String(), ErrUnspecified, input) 123 case pb.Result_RESULT_OK: 124 return &msg.GenBatchProofResponse.Id, nil 125 case pb.Result_RESULT_ERROR: 126 return nil, fmt.Errorf("failed to generate proof %s, %w, input %v", msg.GenBatchProofResponse.String(), ErrBadRequest, input) 127 case pb.Result_RESULT_INTERNAL_ERROR: 128 return nil, fmt.Errorf("failed to generate proof %s, %w, input %v", msg.GenBatchProofResponse.String(), ErrProverInternalError, input) 129 default: 130 return nil, fmt.Errorf("failed to generate proof %s, %w,input %v", msg.GenBatchProofResponse.String(), ErrUnknown, input) 131 } 132 } 133 134 return nil, fmt.Errorf("%w, wanted %T, got %T", ErrBadProverResponse, &pb.ProverMessage_GenBatchProofResponse{}, res.Response) 135 } 136 137 // AggregatedProof instructs the prover to generate an aggregated proof from 138 // the two inputs provided. It returns the ID of the proof being computed. 
139 func (p *Prover) AggregatedProof(inputProof1, inputProof2 string) (*string, error) { 140 metrics.WorkingProver() 141 142 req := &pb.AggregatorMessage{ 143 Request: &pb.AggregatorMessage_GenAggregatedProofRequest{ 144 GenAggregatedProofRequest: &pb.GenAggregatedProofRequest{ 145 RecursiveProof_1: inputProof1, 146 RecursiveProof_2: inputProof2, 147 }, 148 }, 149 } 150 res, err := p.call(req) 151 if err != nil { 152 return nil, err 153 } 154 155 if msg, ok := res.Response.(*pb.ProverMessage_GenAggregatedProofResponse); ok { 156 switch msg.GenAggregatedProofResponse.Result { 157 case pb.Result_RESULT_UNSPECIFIED: 158 return nil, fmt.Errorf("failed to aggregate proofs %s, %w, input 1 %s, input 2 %s", 159 msg.GenAggregatedProofResponse.String(), ErrUnspecified, inputProof1, inputProof2) 160 case pb.Result_RESULT_OK: 161 return &msg.GenAggregatedProofResponse.Id, nil 162 case pb.Result_RESULT_ERROR: 163 return nil, fmt.Errorf("failed to aggregate proofs %s, %w, input 1 %s, input 2 %s", 164 msg.GenAggregatedProofResponse.String(), ErrBadRequest, inputProof1, inputProof2) 165 case pb.Result_RESULT_INTERNAL_ERROR: 166 return nil, fmt.Errorf("failed to aggregate proofs %s, %w, input 1 %s, input 2 %s", 167 msg.GenAggregatedProofResponse.String(), ErrProverInternalError, inputProof1, inputProof2) 168 default: 169 return nil, fmt.Errorf("failed to aggregate proofs %s, %w, input 1 %s, input 2 %s", 170 msg.GenAggregatedProofResponse.String(), ErrUnknown, inputProof1, inputProof2) 171 } 172 } 173 174 return nil, fmt.Errorf("%w, wanted %T, got %T", ErrBadProverResponse, &pb.ProverMessage_GenAggregatedProofResponse{}, res.Response) 175 } 176 177 // FinalProof instructs the prover to generate a final proof for the given 178 // input. It returns the ID of the proof being computed. 
179 func (p *Prover) FinalProof(inputProof string, aggregatorAddr string) (*string, error) { 180 metrics.WorkingProver() 181 182 req := &pb.AggregatorMessage{ 183 Request: &pb.AggregatorMessage_GenFinalProofRequest{ 184 GenFinalProofRequest: &pb.GenFinalProofRequest{ 185 RecursiveProof: inputProof, 186 AggregatorAddr: aggregatorAddr, 187 }, 188 }, 189 } 190 res, err := p.call(req) 191 if err != nil { 192 return nil, err 193 } 194 195 if msg, ok := res.Response.(*pb.ProverMessage_GenFinalProofResponse); ok { 196 switch msg.GenFinalProofResponse.Result { 197 case pb.Result_RESULT_UNSPECIFIED: 198 return nil, fmt.Errorf("failed to generate final proof %s, %w, input %s", 199 msg.GenFinalProofResponse.String(), ErrUnspecified, inputProof) 200 case pb.Result_RESULT_OK: 201 return &msg.GenFinalProofResponse.Id, nil 202 case pb.Result_RESULT_ERROR: 203 return nil, fmt.Errorf("failed to generate final proof %s, %w, input %s", 204 msg.GenFinalProofResponse.String(), ErrBadRequest, inputProof) 205 case pb.Result_RESULT_INTERNAL_ERROR: 206 return nil, fmt.Errorf("failed to generate final proof %s, %w, input %s", 207 msg.GenFinalProofResponse.String(), ErrProverInternalError, inputProof) 208 default: 209 return nil, fmt.Errorf("failed to generate final proof %s, %w, input %s", 210 msg.GenFinalProofResponse.String(), ErrUnknown, inputProof) 211 } 212 } 213 return nil, fmt.Errorf("%w, wanted %T, got %T", ErrBadProverResponse, &pb.ProverMessage_GenFinalProofResponse{}, res.Response) 214 } 215 216 // CancelProofRequest asks the prover to stop the generation of the proof 217 // matching the provided proofID. 
218 func (p *Prover) CancelProofRequest(proofID string) error { 219 req := &pb.AggregatorMessage{ 220 Request: &pb.AggregatorMessage_CancelRequest{ 221 CancelRequest: &pb.CancelRequest{Id: proofID}, 222 }, 223 } 224 res, err := p.call(req) 225 if err != nil { 226 return err 227 } 228 if msg, ok := res.Response.(*pb.ProverMessage_CancelResponse); ok { 229 switch msg.CancelResponse.Result { 230 case pb.Result_RESULT_UNSPECIFIED: 231 return fmt.Errorf("failed to cancel proof id [%s], %w, %s", 232 proofID, ErrUnspecified, msg.CancelResponse.String()) 233 case pb.Result_RESULT_OK: 234 return nil 235 case pb.Result_RESULT_ERROR: 236 return fmt.Errorf("failed to cancel proof id [%s], %w, %s", 237 proofID, ErrBadRequest, msg.CancelResponse.String()) 238 case pb.Result_RESULT_INTERNAL_ERROR: 239 return fmt.Errorf("failed to cancel proof id [%s], %w, %s", 240 proofID, ErrProverInternalError, msg.CancelResponse.String()) 241 default: 242 return fmt.Errorf("failed to cancel proof id [%s], %w, %s", 243 proofID, ErrUnknown, msg.CancelResponse.String()) 244 } 245 } 246 return fmt.Errorf("%w, wanted %T, got %T", ErrBadProverResponse, &pb.ProverMessage_CancelResponse{}, res.Response) 247 } 248 249 // WaitRecursiveProof waits for a recursive proof to be generated by the prover 250 // and returns it. 251 func (p *Prover) WaitRecursiveProof(ctx context.Context, proofID string) (string, error) { 252 res, err := p.waitProof(ctx, proofID) 253 if err != nil { 254 return "", err 255 } 256 resProof := res.Proof.(*pb.GetProofResponse_RecursiveProof) 257 return resProof.RecursiveProof, nil 258 } 259 260 // WaitFinalProof waits for the final proof to be generated by the prover and 261 // returns it. 
func (p *Prover) WaitFinalProof(ctx context.Context, proofID string) (*pb.FinalProof, error) {
	res, err := p.waitProof(ctx, proofID)
	if err != nil {
		return nil, err
	}
	// Comma-ok assertion: a misbehaving prover must not be able to crash the
	// aggregator by completing with an unexpected proof kind (or none).
	resProof, ok := res.Proof.(*pb.GetProofResponse_FinalProof)
	if !ok {
		return nil, fmt.Errorf("%w, wanted %T, got %T", ErrBadProverResponse, &pb.GetProofResponse_FinalProof{}, res.Proof)
	}
	return resProof.FinalProof, nil
}

// waitProof polls the prover with GetProof requests for proofID and returns the
// prover response once the proof completes successfully.
//
// While the prover reports RESULT_PENDING, it waits proofStatePollingInterval
// between polls. Every other result is turned into an error that wraps the
// matching sentinel. It returns ctx.Err() if ctx is done before a poll or
// during the wait between polls. A request already in flight (p.call) is not
// interrupted by ctx.
//
// The deferred metrics.IdlingProver runs on every return path.
func (p *Prover) waitProof(ctx context.Context, proofID string) (*pb.GetProofResponse, error) {
	defer metrics.IdlingProver()

	req := &pb.AggregatorMessage{
		Request: &pb.AggregatorMessage_GetProofRequest{
			GetProofRequest: &pb.GetProofRequest{
				// TODO(pg): set Timeout field?
				Id: proofID,
			},
		},
	}

	for {
		if err := ctx.Err(); err != nil {
			return nil, err
		}

		res, err := p.call(req)
		if err != nil {
			return nil, err
		}
		msg, ok := res.Response.(*pb.ProverMessage_GetProofResponse)
		if !ok {
			return nil, fmt.Errorf("%w, wanted %T, got %T", ErrBadProverResponse, &pb.ProverMessage_GetProofResponse{}, res.Response)
		}

		switch msg.GetProofResponse.Result {
		case pb.GetProofResponse_RESULT_PENDING:
			// Wait before polling again, but return promptly if ctx is
			// cancelled instead of sleeping through the whole interval.
			timer := time.NewTimer(p.proofStatePollingInterval.Duration)
			select {
			case <-ctx.Done():
				timer.Stop()
				return nil, ctx.Err()
			case <-timer.C:
			}
			continue
		case pb.GetProofResponse_RESULT_UNSPECIFIED:
			return nil, fmt.Errorf("failed to get proof ID: %s, %w, prover response: %s",
				proofID, ErrUnspecified, msg.GetProofResponse.String())
		case pb.GetProofResponse_RESULT_COMPLETED_OK:
			return msg.GetProofResponse, nil
		case pb.GetProofResponse_RESULT_ERROR:
			return nil, fmt.Errorf("failed to get proof with ID %s, %w, prover response: %s",
				proofID, ErrBadRequest, msg.GetProofResponse.String())
		case pb.GetProofResponse_RESULT_COMPLETED_ERROR:
			return nil, fmt.Errorf("failed to get proof with ID %s, %w, prover response: %s",
				proofID, ErrProverCompletedError, msg.GetProofResponse.String())
		case pb.GetProofResponse_RESULT_INTERNAL_ERROR:
			return nil, fmt.Errorf("failed to get proof ID: %s, %w, prover response: %s",
				proofID, ErrProverInternalError, msg.GetProofResponse.String())
		case pb.GetProofResponse_RESULT_CANCEL:
			return nil, fmt.Errorf("proof generation was cancelled for proof ID %s, %w, prover response: %s",
				proofID, ErrProofCanceled, msg.GetProofResponse.String())
		default:
			return nil, fmt.Errorf("failed to get proof ID: %s, %w, prover response: %s",
				proofID, ErrUnknown, msg.GetProofResponse.String())
		}
	}
}

// call sends req to the prover and then blocks until the next message arrives
// on the stream, which it returns as the response.
//
// NOTE(review): Send and Recv are not synchronized here, so concurrent calls on
// the same Prover could pair a request with another call's response. Callers
// presumably serialize use of a Prover; verify.
func (p *Prover) call(req *pb.AggregatorMessage) (*pb.ProverMessage, error) {
	if err := p.stream.Send(req); err != nil {
		return nil, err
	}
	res, err := p.stream.Recv()
	if err != nil {
		return nil, err
	}
	return res, nil
}