github.com/MetalBlockchain/metalgo@v1.11.9/vms/platformvm/state/diff.go (about)

     1  // Copyright (C) 2019-2024, Ava Labs, Inc. All rights reserved.
     2  // See the file LICENSE for licensing terms.
     3  
     4  package state
     5  
     6  import (
     7  	"errors"
     8  	"fmt"
     9  	"time"
    10  
    11  	"github.com/MetalBlockchain/metalgo/database"
    12  	"github.com/MetalBlockchain/metalgo/ids"
    13  	"github.com/MetalBlockchain/metalgo/vms/components/avax"
    14  	"github.com/MetalBlockchain/metalgo/vms/platformvm/fx"
    15  	"github.com/MetalBlockchain/metalgo/vms/platformvm/status"
    16  	"github.com/MetalBlockchain/metalgo/vms/platformvm/txs"
    17  )
    18  
    19  var (
    20  	_ Diff     = (*diff)(nil)
    21  	_ Versions = stateGetter{}
    22  
    23  	ErrMissingParentState = errors.New("missing parent state")
    24  )
    25  
    26  type Diff interface {
    27  	Chain
    28  
    29  	Apply(Chain) error
    30  }
    31  
// diff is the Diff implementation. It holds only the changes made on top of
// the parent state identified by parentID (resolved through stateVersions);
// getters consult these fields first and fall back to the parent otherwise.
// Apart from subnetOwners (allocated in NewDiff), the maps are nil until the
// corresponding setter first writes to them.
type diff struct {
	// Parent state this diff is layered on; stateVersions is asked for it
	// each time a read is not satisfied by this diff.
	parentID      ids.ID
	stateVersions Versions

	// Chain time; NewDiff initializes it from the parent's timestamp.
	timestamp time.Time

	// Subnet ID --> supply of native asset of the subnet
	currentSupply map[ids.ID]uint64

	// Changes to the current staker set; merged with the parent's iterators
	// when iterating.
	currentStakerDiffs diffStakers
	// map of subnetID -> nodeID -> total accrued delegatee rewards
	modifiedDelegateeRewards map[ids.ID]map[ids.NodeID]uint64
	pendingStakerDiffs       diffStakers

	// Subnet IDs added in this diff, in the order AddSubnet was called.
	addedSubnetIDs []ids.ID
	// Subnet ID --> Owner of the subnet
	subnetOwners map[ids.ID]fx.Owner
	// Subnet ID --> Tx that transforms the subnet
	transformedSubnets map[ids.ID]*txs.Tx

	// Subnet ID --> CreateChainTx txs added in this diff, in insertion order.
	addedChains map[ids.ID][]*txs.Tx

	// Tx ID --> reward UTXOs added for that tx in this diff.
	addedRewardUTXOs map[ids.ID][]*avax.UTXO

	// Tx ID --> tx and its status, as added in this diff.
	addedTxs map[ids.ID]*txAndStatus

	// UTXO ID --> UTXO modified in this diff. A nil value means the UTXO has
	// been removed.
	modifiedUTXOs map[ids.ID]*avax.UTXO
}
    61  
    62  func NewDiff(
    63  	parentID ids.ID,
    64  	stateVersions Versions,
    65  ) (Diff, error) {
    66  	parentState, ok := stateVersions.GetState(parentID)
    67  	if !ok {
    68  		return nil, fmt.Errorf("%w: %s", ErrMissingParentState, parentID)
    69  	}
    70  	return &diff{
    71  		parentID:      parentID,
    72  		stateVersions: stateVersions,
    73  		timestamp:     parentState.GetTimestamp(),
    74  		subnetOwners:  make(map[ids.ID]fx.Owner),
    75  	}, nil
    76  }
    77  
    78  type stateGetter struct {
    79  	state Chain
    80  }
    81  
    82  func (s stateGetter) GetState(ids.ID) (Chain, bool) {
    83  	return s.state, true
    84  }
    85  
    86  func NewDiffOn(parentState Chain) (Diff, error) {
    87  	return NewDiff(ids.Empty, stateGetter{
    88  		state: parentState,
    89  	})
    90  }
    91  
    92  func (d *diff) GetTimestamp() time.Time {
    93  	return d.timestamp
    94  }
    95  
    96  func (d *diff) SetTimestamp(timestamp time.Time) {
    97  	d.timestamp = timestamp
    98  }
    99  
   100  func (d *diff) GetCurrentSupply(subnetID ids.ID) (uint64, error) {
   101  	supply, ok := d.currentSupply[subnetID]
   102  	if ok {
   103  		return supply, nil
   104  	}
   105  
   106  	// If the subnet supply wasn't modified in this diff, ask the parent state.
   107  	parentState, ok := d.stateVersions.GetState(d.parentID)
   108  	if !ok {
   109  		return 0, fmt.Errorf("%w: %s", ErrMissingParentState, d.parentID)
   110  	}
   111  	return parentState.GetCurrentSupply(subnetID)
   112  }
   113  
   114  func (d *diff) SetCurrentSupply(subnetID ids.ID, currentSupply uint64) {
   115  	if d.currentSupply == nil {
   116  		d.currentSupply = map[ids.ID]uint64{
   117  			subnetID: currentSupply,
   118  		}
   119  	} else {
   120  		d.currentSupply[subnetID] = currentSupply
   121  	}
   122  }
   123  
   124  func (d *diff) GetCurrentValidator(subnetID ids.ID, nodeID ids.NodeID) (*Staker, error) {
   125  	// If the validator was modified in this diff, return the modified
   126  	// validator.
   127  	newValidator, status := d.currentStakerDiffs.GetValidator(subnetID, nodeID)
   128  	switch status {
   129  	case added:
   130  		return newValidator, nil
   131  	case deleted:
   132  		return nil, database.ErrNotFound
   133  	default:
   134  		// If the validator wasn't modified in this diff, ask the parent state.
   135  		parentState, ok := d.stateVersions.GetState(d.parentID)
   136  		if !ok {
   137  			return nil, fmt.Errorf("%w: %s", ErrMissingParentState, d.parentID)
   138  		}
   139  		return parentState.GetCurrentValidator(subnetID, nodeID)
   140  	}
   141  }
   142  
   143  func (d *diff) SetDelegateeReward(subnetID ids.ID, nodeID ids.NodeID, amount uint64) error {
   144  	if d.modifiedDelegateeRewards == nil {
   145  		d.modifiedDelegateeRewards = make(map[ids.ID]map[ids.NodeID]uint64)
   146  	}
   147  	nodes, ok := d.modifiedDelegateeRewards[subnetID]
   148  	if !ok {
   149  		nodes = make(map[ids.NodeID]uint64)
   150  		d.modifiedDelegateeRewards[subnetID] = nodes
   151  	}
   152  	nodes[nodeID] = amount
   153  	return nil
   154  }
   155  
   156  func (d *diff) GetDelegateeReward(subnetID ids.ID, nodeID ids.NodeID) (uint64, error) {
   157  	amount, modified := d.modifiedDelegateeRewards[subnetID][nodeID]
   158  	if modified {
   159  		return amount, nil
   160  	}
   161  	parentState, ok := d.stateVersions.GetState(d.parentID)
   162  	if !ok {
   163  		return 0, fmt.Errorf("%w: %s", ErrMissingParentState, d.parentID)
   164  	}
   165  	return parentState.GetDelegateeReward(subnetID, nodeID)
   166  }
   167  
   168  func (d *diff) PutCurrentValidator(staker *Staker) {
   169  	d.currentStakerDiffs.PutValidator(staker)
   170  }
   171  
   172  func (d *diff) DeleteCurrentValidator(staker *Staker) {
   173  	d.currentStakerDiffs.DeleteValidator(staker)
   174  }
   175  
   176  func (d *diff) GetCurrentDelegatorIterator(subnetID ids.ID, nodeID ids.NodeID) (StakerIterator, error) {
   177  	parentState, ok := d.stateVersions.GetState(d.parentID)
   178  	if !ok {
   179  		return nil, fmt.Errorf("%w: %s", ErrMissingParentState, d.parentID)
   180  	}
   181  
   182  	parentIterator, err := parentState.GetCurrentDelegatorIterator(subnetID, nodeID)
   183  	if err != nil {
   184  		return nil, err
   185  	}
   186  
   187  	return d.currentStakerDiffs.GetDelegatorIterator(parentIterator, subnetID, nodeID), nil
   188  }
   189  
   190  func (d *diff) PutCurrentDelegator(staker *Staker) {
   191  	d.currentStakerDiffs.PutDelegator(staker)
   192  }
   193  
   194  func (d *diff) DeleteCurrentDelegator(staker *Staker) {
   195  	d.currentStakerDiffs.DeleteDelegator(staker)
   196  }
   197  
   198  func (d *diff) GetCurrentStakerIterator() (StakerIterator, error) {
   199  	parentState, ok := d.stateVersions.GetState(d.parentID)
   200  	if !ok {
   201  		return nil, fmt.Errorf("%w: %s", ErrMissingParentState, d.parentID)
   202  	}
   203  
   204  	parentIterator, err := parentState.GetCurrentStakerIterator()
   205  	if err != nil {
   206  		return nil, err
   207  	}
   208  
   209  	return d.currentStakerDiffs.GetStakerIterator(parentIterator), nil
   210  }
   211  
   212  func (d *diff) GetPendingValidator(subnetID ids.ID, nodeID ids.NodeID) (*Staker, error) {
   213  	// If the validator was modified in this diff, return the modified
   214  	// validator.
   215  	newValidator, status := d.pendingStakerDiffs.GetValidator(subnetID, nodeID)
   216  	switch status {
   217  	case added:
   218  		return newValidator, nil
   219  	case deleted:
   220  		return nil, database.ErrNotFound
   221  	default:
   222  		// If the validator wasn't modified in this diff, ask the parent state.
   223  		parentState, ok := d.stateVersions.GetState(d.parentID)
   224  		if !ok {
   225  			return nil, fmt.Errorf("%w: %s", ErrMissingParentState, d.parentID)
   226  		}
   227  		return parentState.GetPendingValidator(subnetID, nodeID)
   228  	}
   229  }
   230  
   231  func (d *diff) PutPendingValidator(staker *Staker) {
   232  	d.pendingStakerDiffs.PutValidator(staker)
   233  }
   234  
   235  func (d *diff) DeletePendingValidator(staker *Staker) {
   236  	d.pendingStakerDiffs.DeleteValidator(staker)
   237  }
   238  
   239  func (d *diff) GetPendingDelegatorIterator(subnetID ids.ID, nodeID ids.NodeID) (StakerIterator, error) {
   240  	parentState, ok := d.stateVersions.GetState(d.parentID)
   241  	if !ok {
   242  		return nil, fmt.Errorf("%w: %s", ErrMissingParentState, d.parentID)
   243  	}
   244  
   245  	parentIterator, err := parentState.GetPendingDelegatorIterator(subnetID, nodeID)
   246  	if err != nil {
   247  		return nil, err
   248  	}
   249  
   250  	return d.pendingStakerDiffs.GetDelegatorIterator(parentIterator, subnetID, nodeID), nil
   251  }
   252  
   253  func (d *diff) PutPendingDelegator(staker *Staker) {
   254  	d.pendingStakerDiffs.PutDelegator(staker)
   255  }
   256  
   257  func (d *diff) DeletePendingDelegator(staker *Staker) {
   258  	d.pendingStakerDiffs.DeleteDelegator(staker)
   259  }
   260  
   261  func (d *diff) GetPendingStakerIterator() (StakerIterator, error) {
   262  	parentState, ok := d.stateVersions.GetState(d.parentID)
   263  	if !ok {
   264  		return nil, fmt.Errorf("%w: %s", ErrMissingParentState, d.parentID)
   265  	}
   266  
   267  	parentIterator, err := parentState.GetPendingStakerIterator()
   268  	if err != nil {
   269  		return nil, err
   270  	}
   271  
   272  	return d.pendingStakerDiffs.GetStakerIterator(parentIterator), nil
   273  }
   274  
   275  func (d *diff) AddSubnet(subnetID ids.ID) {
   276  	d.addedSubnetIDs = append(d.addedSubnetIDs, subnetID)
   277  }
   278  
   279  func (d *diff) GetSubnetOwner(subnetID ids.ID) (fx.Owner, error) {
   280  	owner, exists := d.subnetOwners[subnetID]
   281  	if exists {
   282  		return owner, nil
   283  	}
   284  
   285  	// If the subnet owner was not assigned in this diff, ask the parent state.
   286  	parentState, ok := d.stateVersions.GetState(d.parentID)
   287  	if !ok {
   288  		return nil, ErrMissingParentState
   289  	}
   290  	return parentState.GetSubnetOwner(subnetID)
   291  }
   292  
   293  func (d *diff) SetSubnetOwner(subnetID ids.ID, owner fx.Owner) {
   294  	d.subnetOwners[subnetID] = owner
   295  }
   296  
   297  func (d *diff) GetSubnetTransformation(subnetID ids.ID) (*txs.Tx, error) {
   298  	tx, exists := d.transformedSubnets[subnetID]
   299  	if exists {
   300  		return tx, nil
   301  	}
   302  
   303  	// If the subnet wasn't transformed in this diff, ask the parent state.
   304  	parentState, ok := d.stateVersions.GetState(d.parentID)
   305  	if !ok {
   306  		return nil, ErrMissingParentState
   307  	}
   308  	return parentState.GetSubnetTransformation(subnetID)
   309  }
   310  
   311  func (d *diff) AddSubnetTransformation(transformSubnetTxIntf *txs.Tx) {
   312  	transformSubnetTx := transformSubnetTxIntf.Unsigned.(*txs.TransformSubnetTx)
   313  	if d.transformedSubnets == nil {
   314  		d.transformedSubnets = map[ids.ID]*txs.Tx{
   315  			transformSubnetTx.Subnet: transformSubnetTxIntf,
   316  		}
   317  	} else {
   318  		d.transformedSubnets[transformSubnetTx.Subnet] = transformSubnetTxIntf
   319  	}
   320  }
   321  
   322  func (d *diff) AddChain(createChainTx *txs.Tx) {
   323  	tx := createChainTx.Unsigned.(*txs.CreateChainTx)
   324  	if d.addedChains == nil {
   325  		d.addedChains = map[ids.ID][]*txs.Tx{
   326  			tx.SubnetID: {createChainTx},
   327  		}
   328  	} else {
   329  		d.addedChains[tx.SubnetID] = append(d.addedChains[tx.SubnetID], createChainTx)
   330  	}
   331  }
   332  
   333  func (d *diff) GetTx(txID ids.ID) (*txs.Tx, status.Status, error) {
   334  	if tx, exists := d.addedTxs[txID]; exists {
   335  		return tx.tx, tx.status, nil
   336  	}
   337  
   338  	parentState, ok := d.stateVersions.GetState(d.parentID)
   339  	if !ok {
   340  		return nil, status.Unknown, fmt.Errorf("%w: %s", ErrMissingParentState, d.parentID)
   341  	}
   342  	return parentState.GetTx(txID)
   343  }
   344  
   345  func (d *diff) AddTx(tx *txs.Tx, status status.Status) {
   346  	txID := tx.ID()
   347  	txStatus := &txAndStatus{
   348  		tx:     tx,
   349  		status: status,
   350  	}
   351  	if d.addedTxs == nil {
   352  		d.addedTxs = map[ids.ID]*txAndStatus{
   353  			txID: txStatus,
   354  		}
   355  	} else {
   356  		d.addedTxs[txID] = txStatus
   357  	}
   358  }
   359  
   360  func (d *diff) AddRewardUTXO(txID ids.ID, utxo *avax.UTXO) {
   361  	if d.addedRewardUTXOs == nil {
   362  		d.addedRewardUTXOs = make(map[ids.ID][]*avax.UTXO)
   363  	}
   364  	d.addedRewardUTXOs[txID] = append(d.addedRewardUTXOs[txID], utxo)
   365  }
   366  
   367  func (d *diff) GetUTXO(utxoID ids.ID) (*avax.UTXO, error) {
   368  	utxo, modified := d.modifiedUTXOs[utxoID]
   369  	if !modified {
   370  		parentState, ok := d.stateVersions.GetState(d.parentID)
   371  		if !ok {
   372  			return nil, fmt.Errorf("%w: %s", ErrMissingParentState, d.parentID)
   373  		}
   374  		return parentState.GetUTXO(utxoID)
   375  	}
   376  	if utxo == nil {
   377  		return nil, database.ErrNotFound
   378  	}
   379  	return utxo, nil
   380  }
   381  
   382  func (d *diff) AddUTXO(utxo *avax.UTXO) {
   383  	if d.modifiedUTXOs == nil {
   384  		d.modifiedUTXOs = map[ids.ID]*avax.UTXO{
   385  			utxo.InputID(): utxo,
   386  		}
   387  	} else {
   388  		d.modifiedUTXOs[utxo.InputID()] = utxo
   389  	}
   390  }
   391  
   392  func (d *diff) DeleteUTXO(utxoID ids.ID) {
   393  	if d.modifiedUTXOs == nil {
   394  		d.modifiedUTXOs = map[ids.ID]*avax.UTXO{
   395  			utxoID: nil,
   396  		}
   397  	} else {
   398  		d.modifiedUTXOs[utxoID] = nil
   399  	}
   400  }
   401  
   402  func (d *diff) Apply(baseState Chain) error {
   403  	baseState.SetTimestamp(d.timestamp)
   404  	for subnetID, supply := range d.currentSupply {
   405  		baseState.SetCurrentSupply(subnetID, supply)
   406  	}
   407  	for _, subnetValidatorDiffs := range d.currentStakerDiffs.validatorDiffs {
   408  		for _, validatorDiff := range subnetValidatorDiffs {
   409  			switch validatorDiff.validatorStatus {
   410  			case added:
   411  				baseState.PutCurrentValidator(validatorDiff.validator)
   412  			case deleted:
   413  				baseState.DeleteCurrentValidator(validatorDiff.validator)
   414  			}
   415  
   416  			addedDelegatorIterator := NewTreeIterator(validatorDiff.addedDelegators)
   417  			for addedDelegatorIterator.Next() {
   418  				baseState.PutCurrentDelegator(addedDelegatorIterator.Value())
   419  			}
   420  			addedDelegatorIterator.Release()
   421  
   422  			for _, delegator := range validatorDiff.deletedDelegators {
   423  				baseState.DeleteCurrentDelegator(delegator)
   424  			}
   425  		}
   426  	}
   427  	for subnetID, nodes := range d.modifiedDelegateeRewards {
   428  		for nodeID, amount := range nodes {
   429  			if err := baseState.SetDelegateeReward(subnetID, nodeID, amount); err != nil {
   430  				return err
   431  			}
   432  		}
   433  	}
   434  	for _, subnetValidatorDiffs := range d.pendingStakerDiffs.validatorDiffs {
   435  		for _, validatorDiff := range subnetValidatorDiffs {
   436  			switch validatorDiff.validatorStatus {
   437  			case added:
   438  				baseState.PutPendingValidator(validatorDiff.validator)
   439  			case deleted:
   440  				baseState.DeletePendingValidator(validatorDiff.validator)
   441  			}
   442  
   443  			addedDelegatorIterator := NewTreeIterator(validatorDiff.addedDelegators)
   444  			for addedDelegatorIterator.Next() {
   445  				baseState.PutPendingDelegator(addedDelegatorIterator.Value())
   446  			}
   447  			addedDelegatorIterator.Release()
   448  
   449  			for _, delegator := range validatorDiff.deletedDelegators {
   450  				baseState.DeletePendingDelegator(delegator)
   451  			}
   452  		}
   453  	}
   454  	for _, subnetID := range d.addedSubnetIDs {
   455  		baseState.AddSubnet(subnetID)
   456  	}
   457  	for _, tx := range d.transformedSubnets {
   458  		baseState.AddSubnetTransformation(tx)
   459  	}
   460  	for _, chains := range d.addedChains {
   461  		for _, chain := range chains {
   462  			baseState.AddChain(chain)
   463  		}
   464  	}
   465  	for _, tx := range d.addedTxs {
   466  		baseState.AddTx(tx.tx, tx.status)
   467  	}
   468  	for txID, utxos := range d.addedRewardUTXOs {
   469  		for _, utxo := range utxos {
   470  			baseState.AddRewardUTXO(txID, utxo)
   471  		}
   472  	}
   473  	for utxoID, utxo := range d.modifiedUTXOs {
   474  		if utxo != nil {
   475  			baseState.AddUTXO(utxo)
   476  		} else {
   477  			baseState.DeleteUTXO(utxoID)
   478  		}
   479  	}
   480  	for subnetID, owner := range d.subnetOwners {
   481  		baseState.SetSubnetOwner(subnetID, owner)
   482  	}
   483  	return nil
   484  }