github.com/badrootd/nibiru-cometbft@v0.37.5-0.20240307173500-2a75559eee9b/privval/file.go (about)

     1  package privval
     2  
     3  import (
     4  	"bytes"
     5  	"errors"
     6  	"fmt"
     7  	"os"
     8  	"time"
     9  
    10  	"github.com/cosmos/gogoproto/proto"
    11  
    12  	"github.com/badrootd/nibiru-cometbft/crypto"
    13  	"github.com/badrootd/nibiru-cometbft/crypto/ed25519"
    14  	cmtbytes "github.com/badrootd/nibiru-cometbft/libs/bytes"
    15  	cmtjson "github.com/badrootd/nibiru-cometbft/libs/json"
    16  	cmtos "github.com/badrootd/nibiru-cometbft/libs/os"
    17  	"github.com/badrootd/nibiru-cometbft/libs/protoio"
    18  	"github.com/badrootd/nibiru-cometbft/libs/tempfile"
    19  	cmtproto "github.com/badrootd/nibiru-cometbft/proto/tendermint/types"
    20  	"github.com/badrootd/nibiru-cometbft/types"
    21  	cmttime "github.com/badrootd/nibiru-cometbft/types/time"
    22  )
    23  
    24  // TODO: type ?
    25  const (
    26  	stepNone      int8 = 0 // Used to distinguish the initial state
    27  	stepPropose   int8 = 1
    28  	stepPrevote   int8 = 2
    29  	stepPrecommit int8 = 3
    30  )
    31  
    32  // A vote is either stepPrevote or stepPrecommit.
    33  func voteToStep(vote *cmtproto.Vote) int8 {
    34  	switch vote.Type {
    35  	case cmtproto.PrevoteType:
    36  		return stepPrevote
    37  	case cmtproto.PrecommitType:
    38  		return stepPrecommit
    39  	default:
    40  		panic(fmt.Sprintf("Unknown vote type: %v", vote.Type))
    41  	}
    42  }
    43  
    44  //-------------------------------------------------------------------------------
    45  
    46  // FilePVKey stores the immutable part of PrivValidator.
    47  type FilePVKey struct {
    48  	Address types.Address  `json:"address"`
    49  	PubKey  crypto.PubKey  `json:"pub_key"`
    50  	PrivKey crypto.PrivKey `json:"priv_key"`
    51  
    52  	filePath string
    53  }
    54  
    55  // Save persists the FilePVKey to its filePath.
    56  func (pvKey FilePVKey) Save() {
    57  	outFile := pvKey.filePath
    58  	if outFile == "" {
    59  		panic("cannot save PrivValidator key: filePath not set")
    60  	}
    61  
    62  	jsonBytes, err := cmtjson.MarshalIndent(pvKey, "", "  ")
    63  	if err != nil {
    64  		panic(err)
    65  	}
    66  
    67  	if err := tempfile.WriteFileAtomic(outFile, jsonBytes, 0600); err != nil {
    68  		panic(err)
    69  	}
    70  }
    71  
    72  //-------------------------------------------------------------------------------
    73  
    74  // FilePVLastSignState stores the mutable part of PrivValidator.
    75  type FilePVLastSignState struct {
    76  	Height    int64             `json:"height"`
    77  	Round     int32             `json:"round"`
    78  	Step      int8              `json:"step"`
    79  	Signature []byte            `json:"signature,omitempty"`
    80  	SignBytes cmtbytes.HexBytes `json:"signbytes,omitempty"`
    81  
    82  	filePath string
    83  }
    84  
    85  // CheckHRS checks the given height, round, step (HRS) against that of the
    86  // FilePVLastSignState. It returns an error if the arguments constitute a regression,
    87  // or if they match but the SignBytes are empty.
    88  // The returned boolean indicates whether the last Signature should be reused -
    89  // it returns true if the HRS matches the arguments and the SignBytes are not empty (indicating
    90  // we have already signed for this HRS, and can reuse the existing signature).
    91  // It panics if the HRS matches the arguments, there's a SignBytes, but no Signature.
    92  func (lss *FilePVLastSignState) CheckHRS(height int64, round int32, step int8) (bool, error) {
    93  
    94  	if lss.Height > height {
    95  		return false, fmt.Errorf("height regression. Got %v, last height %v", height, lss.Height)
    96  	}
    97  
    98  	if lss.Height == height {
    99  		if lss.Round > round {
   100  			return false, fmt.Errorf("round regression at height %v. Got %v, last round %v", height, round, lss.Round)
   101  		}
   102  
   103  		if lss.Round == round {
   104  			if lss.Step > step {
   105  				return false, fmt.Errorf(
   106  					"step regression at height %v round %v. Got %v, last step %v",
   107  					height,
   108  					round,
   109  					step,
   110  					lss.Step,
   111  				)
   112  			} else if lss.Step == step {
   113  				if lss.SignBytes != nil {
   114  					if lss.Signature == nil {
   115  						panic("pv: Signature is nil but SignBytes is not!")
   116  					}
   117  					return true, nil
   118  				}
   119  				return false, errors.New("no SignBytes found")
   120  			}
   121  		}
   122  	}
   123  	return false, nil
   124  }
   125  
   126  // Save persists the FilePvLastSignState to its filePath.
   127  func (lss *FilePVLastSignState) Save() {
   128  	outFile := lss.filePath
   129  	if outFile == "" {
   130  		panic("cannot save FilePVLastSignState: filePath not set")
   131  	}
   132  	jsonBytes, err := cmtjson.MarshalIndent(lss, "", "  ")
   133  	if err != nil {
   134  		panic(err)
   135  	}
   136  	err = tempfile.WriteFileAtomic(outFile, jsonBytes, 0600)
   137  	if err != nil {
   138  		panic(err)
   139  	}
   140  }
   141  
   142  //-------------------------------------------------------------------------------
   143  
   144  // FilePV implements PrivValidator using data persisted to disk
   145  // to prevent double signing.
   146  // NOTE: the directories containing pv.Key.filePath and pv.LastSignState.filePath must already exist.
   147  // It includes the LastSignature and LastSignBytes so we don't lose the signature
   148  // if the process crashes after signing but before the resulting consensus message is processed.
   149  type FilePV struct {
   150  	Key           FilePVKey
   151  	LastSignState FilePVLastSignState
   152  }
   153  
   154  // NewFilePV generates a new validator from the given key and paths.
   155  func NewFilePV(privKey crypto.PrivKey, keyFilePath, stateFilePath string) *FilePV {
   156  	return &FilePV{
   157  		Key: FilePVKey{
   158  			Address:  privKey.PubKey().Address(),
   159  			PubKey:   privKey.PubKey(),
   160  			PrivKey:  privKey,
   161  			filePath: keyFilePath,
   162  		},
   163  		LastSignState: FilePVLastSignState{
   164  			Step:     stepNone,
   165  			filePath: stateFilePath,
   166  		},
   167  	}
   168  }
   169  
   170  // GenFilePV generates a new validator with randomly generated private key
   171  // and sets the filePaths, but does not call Save().
   172  func GenFilePV(keyFilePath, stateFilePath string) *FilePV {
   173  	return NewFilePV(ed25519.GenPrivKey(), keyFilePath, stateFilePath)
   174  }
   175  
   176  // LoadFilePV loads a FilePV from the filePaths.  The FilePV handles double
   177  // signing prevention by persisting data to the stateFilePath.  If either file path
   178  // does not exist, the program will exit.
   179  func LoadFilePV(keyFilePath, stateFilePath string) *FilePV {
   180  	return loadFilePV(keyFilePath, stateFilePath, true)
   181  }
   182  
   183  // LoadFilePVEmptyState loads a FilePV from the given keyFilePath, with an empty LastSignState.
   184  // If the keyFilePath does not exist, the program will exit.
   185  func LoadFilePVEmptyState(keyFilePath, stateFilePath string) *FilePV {
   186  	return loadFilePV(keyFilePath, stateFilePath, false)
   187  }
   188  
   189  // If loadState is true, we load from the stateFilePath. Otherwise, we use an empty LastSignState.
   190  func loadFilePV(keyFilePath, stateFilePath string, loadState bool) *FilePV {
   191  	keyJSONBytes, err := os.ReadFile(keyFilePath)
   192  	if err != nil {
   193  		cmtos.Exit(err.Error())
   194  	}
   195  	pvKey := FilePVKey{}
   196  	err = cmtjson.Unmarshal(keyJSONBytes, &pvKey)
   197  	if err != nil {
   198  		cmtos.Exit(fmt.Sprintf("Error reading PrivValidator key from %v: %v\n", keyFilePath, err))
   199  	}
   200  
   201  	// overwrite pubkey and address for convenience
   202  	pvKey.PubKey = pvKey.PrivKey.PubKey()
   203  	pvKey.Address = pvKey.PubKey.Address()
   204  	pvKey.filePath = keyFilePath
   205  
   206  	pvState := FilePVLastSignState{}
   207  
   208  	if loadState {
   209  		stateJSONBytes, err := os.ReadFile(stateFilePath)
   210  		if err != nil {
   211  			cmtos.Exit(err.Error())
   212  		}
   213  		err = cmtjson.Unmarshal(stateJSONBytes, &pvState)
   214  		if err != nil {
   215  			cmtos.Exit(fmt.Sprintf("Error reading PrivValidator state from %v: %v\n", stateFilePath, err))
   216  		}
   217  	}
   218  
   219  	pvState.filePath = stateFilePath
   220  
   221  	return &FilePV{
   222  		Key:           pvKey,
   223  		LastSignState: pvState,
   224  	}
   225  }
   226  
   227  // LoadOrGenFilePV loads a FilePV from the given filePaths
   228  // or else generates a new one and saves it to the filePaths.
   229  func LoadOrGenFilePV(keyFilePath, stateFilePath string) *FilePV {
   230  	var pv *FilePV
   231  	if cmtos.FileExists(keyFilePath) {
   232  		pv = LoadFilePV(keyFilePath, stateFilePath)
   233  	} else {
   234  		pv = GenFilePV(keyFilePath, stateFilePath)
   235  		pv.Save()
   236  	}
   237  	return pv
   238  }
   239  
   240  // GetAddress returns the address of the validator.
   241  // Implements PrivValidator.
   242  func (pv *FilePV) GetAddress() types.Address {
   243  	return pv.Key.Address
   244  }
   245  
   246  // GetPubKey returns the public key of the validator.
   247  // Implements PrivValidator.
   248  func (pv *FilePV) GetPubKey() (crypto.PubKey, error) {
   249  	return pv.Key.PubKey, nil
   250  }
   251  
   252  // SignVote signs a canonical representation of the vote, along with the
   253  // chainID. Implements PrivValidator.
   254  func (pv *FilePV) SignVote(chainID string, vote *cmtproto.Vote) error {
   255  	if err := pv.signVote(chainID, vote); err != nil {
   256  		return fmt.Errorf("error signing vote: %v", err)
   257  	}
   258  	return nil
   259  }
   260  
   261  // SignProposal signs a canonical representation of the proposal, along with
   262  // the chainID. Implements PrivValidator.
   263  func (pv *FilePV) SignProposal(chainID string, proposal *cmtproto.Proposal) error {
   264  	if err := pv.signProposal(chainID, proposal); err != nil {
   265  		return fmt.Errorf("error signing proposal: %v", err)
   266  	}
   267  	return nil
   268  }
   269  
   270  // Save persists the FilePV to disk.
   271  func (pv *FilePV) Save() {
   272  	pv.Key.Save()
   273  	pv.LastSignState.Save()
   274  }
   275  
   276  // Reset resets all fields in the FilePV.
   277  // NOTE: Unsafe!
   278  func (pv *FilePV) Reset() {
   279  	var sig []byte
   280  	pv.LastSignState.Height = 0
   281  	pv.LastSignState.Round = 0
   282  	pv.LastSignState.Step = 0
   283  	pv.LastSignState.Signature = sig
   284  	pv.LastSignState.SignBytes = nil
   285  	pv.Save()
   286  }
   287  
   288  // String returns a string representation of the FilePV.
   289  func (pv *FilePV) String() string {
   290  	return fmt.Sprintf(
   291  		"PrivValidator{%v LH:%v, LR:%v, LS:%v}",
   292  		pv.GetAddress(),
   293  		pv.LastSignState.Height,
   294  		pv.LastSignState.Round,
   295  		pv.LastSignState.Step,
   296  	)
   297  }
   298  
   299  //------------------------------------------------------------------------------------
   300  
   301  // signVote checks if the vote is good to sign and sets the vote signature.
   302  // It may need to set the timestamp as well if the vote is otherwise the same as
   303  // a previously signed vote (ie. we crashed after signing but before the vote hit the WAL).
   304  func (pv *FilePV) signVote(chainID string, vote *cmtproto.Vote) error {
   305  	height, round, step := vote.Height, vote.Round, voteToStep(vote)
   306  
   307  	lss := pv.LastSignState
   308  
   309  	sameHRS, err := lss.CheckHRS(height, round, step)
   310  	if err != nil {
   311  		return err
   312  	}
   313  
   314  	signBytes := types.VoteSignBytes(chainID, vote)
   315  
   316  	// We might crash before writing to the wal,
   317  	// causing us to try to re-sign for the same HRS.
   318  	// If signbytes are the same, use the last signature.
   319  	// If they only differ by timestamp, use last timestamp and signature
   320  	// Otherwise, return error
   321  	if sameHRS {
   322  		if bytes.Equal(signBytes, lss.SignBytes) {
   323  			vote.Signature = lss.Signature
   324  		} else if timestamp, ok := checkVotesOnlyDifferByTimestamp(lss.SignBytes, signBytes); ok {
   325  			vote.Timestamp = timestamp
   326  			vote.Signature = lss.Signature
   327  		} else {
   328  			err = fmt.Errorf("conflicting data")
   329  		}
   330  		return err
   331  	}
   332  
   333  	// It passed the checks. Sign the vote
   334  	sig, err := pv.Key.PrivKey.Sign(signBytes)
   335  	if err != nil {
   336  		return err
   337  	}
   338  	pv.saveSigned(height, round, step, signBytes, sig)
   339  	vote.Signature = sig
   340  	return nil
   341  }
   342  
   343  // signProposal checks if the proposal is good to sign and sets the proposal signature.
   344  // It may need to set the timestamp as well if the proposal is otherwise the same as
   345  // a previously signed proposal ie. we crashed after signing but before the proposal hit the WAL).
   346  func (pv *FilePV) signProposal(chainID string, proposal *cmtproto.Proposal) error {
   347  	height, round, step := proposal.Height, proposal.Round, stepPropose
   348  
   349  	lss := pv.LastSignState
   350  
   351  	sameHRS, err := lss.CheckHRS(height, round, step)
   352  	if err != nil {
   353  		return err
   354  	}
   355  
   356  	signBytes := types.ProposalSignBytes(chainID, proposal)
   357  
   358  	// We might crash before writing to the wal,
   359  	// causing us to try to re-sign for the same HRS.
   360  	// If signbytes are the same, use the last signature.
   361  	// If they only differ by timestamp, use last timestamp and signature
   362  	// Otherwise, return error
   363  	if sameHRS {
   364  		if bytes.Equal(signBytes, lss.SignBytes) {
   365  			proposal.Signature = lss.Signature
   366  		} else if timestamp, ok := checkProposalsOnlyDifferByTimestamp(lss.SignBytes, signBytes); ok {
   367  			proposal.Timestamp = timestamp
   368  			proposal.Signature = lss.Signature
   369  		} else {
   370  			err = fmt.Errorf("conflicting data")
   371  		}
   372  		return err
   373  	}
   374  
   375  	// It passed the checks. Sign the proposal
   376  	sig, err := pv.Key.PrivKey.Sign(signBytes)
   377  	if err != nil {
   378  		return err
   379  	}
   380  	pv.saveSigned(height, round, step, signBytes, sig)
   381  	proposal.Signature = sig
   382  	return nil
   383  }
   384  
   385  // Persist height/round/step and signature
   386  func (pv *FilePV) saveSigned(height int64, round int32, step int8,
   387  	signBytes []byte, sig []byte) {
   388  
   389  	pv.LastSignState.Height = height
   390  	pv.LastSignState.Round = round
   391  	pv.LastSignState.Step = step
   392  	pv.LastSignState.Signature = sig
   393  	pv.LastSignState.SignBytes = signBytes
   394  	pv.LastSignState.Save()
   395  }
   396  
   397  //-----------------------------------------------------------------------------------------
   398  
   399  // returns the timestamp from the lastSignBytes.
   400  // returns true if the only difference in the votes is their timestamp.
   401  func checkVotesOnlyDifferByTimestamp(lastSignBytes, newSignBytes []byte) (time.Time, bool) {
   402  	var lastVote, newVote cmtproto.CanonicalVote
   403  	if err := protoio.UnmarshalDelimited(lastSignBytes, &lastVote); err != nil {
   404  		panic(fmt.Sprintf("LastSignBytes cannot be unmarshalled into vote: %v", err))
   405  	}
   406  	if err := protoio.UnmarshalDelimited(newSignBytes, &newVote); err != nil {
   407  		panic(fmt.Sprintf("signBytes cannot be unmarshalled into vote: %v", err))
   408  	}
   409  
   410  	lastTime := lastVote.Timestamp
   411  	// set the times to the same value and check equality
   412  	now := cmttime.Now()
   413  	lastVote.Timestamp = now
   414  	newVote.Timestamp = now
   415  
   416  	return lastTime, proto.Equal(&newVote, &lastVote)
   417  }
   418  
   419  // returns the timestamp from the lastSignBytes.
   420  // returns true if the only difference in the proposals is their timestamp
   421  func checkProposalsOnlyDifferByTimestamp(lastSignBytes, newSignBytes []byte) (time.Time, bool) {
   422  	var lastProposal, newProposal cmtproto.CanonicalProposal
   423  	if err := protoio.UnmarshalDelimited(lastSignBytes, &lastProposal); err != nil {
   424  		panic(fmt.Sprintf("LastSignBytes cannot be unmarshalled into proposal: %v", err))
   425  	}
   426  	if err := protoio.UnmarshalDelimited(newSignBytes, &newProposal); err != nil {
   427  		panic(fmt.Sprintf("signBytes cannot be unmarshalled into proposal: %v", err))
   428  	}
   429  
   430  	lastTime := lastProposal.Timestamp
   431  	// set the times to the same value and check equality
   432  	now := cmttime.Now()
   433  	lastProposal.Timestamp = now
   434  	newProposal.Timestamp = now
   435  
   436  	return lastTime, proto.Equal(&newProposal, &lastProposal)
   437  }