github.com/MetalBlockchain/metalgo@v1.11.9/vms/avm/state/diff.go (about)

     1  // Copyright (C) 2019-2024, Ava Labs, Inc. All rights reserved.
     2  // See the file LICENSE for licensing terms.
     3  
     4  package state
     5  
     6  import (
     7  	"errors"
     8  	"fmt"
     9  	"time"
    10  
    11  	"github.com/MetalBlockchain/metalgo/database"
    12  	"github.com/MetalBlockchain/metalgo/ids"
    13  	"github.com/MetalBlockchain/metalgo/vms/avm/block"
    14  	"github.com/MetalBlockchain/metalgo/vms/avm/txs"
    15  	"github.com/MetalBlockchain/metalgo/vms/components/avax"
    16  )
    17  
    18  var (
    19  	_ Diff     = (*diff)(nil)
    20  	_ Versions = stateGetter{}
    21  
    22  	ErrMissingParentState = errors.New("missing parent state")
    23  )
    24  
    25  type Diff interface {
    26  	Chain
    27  
    28  	Apply(Chain)
    29  }
    30  
    31  type diff struct {
    32  	parentID      ids.ID
    33  	stateVersions Versions
    34  
    35  	// map of modified UTXOID -> *UTXO if the UTXO is nil, it has been removed
    36  	modifiedUTXOs map[ids.ID]*avax.UTXO
    37  	addedTxs      map[ids.ID]*txs.Tx     // map of txID -> tx
    38  	addedBlockIDs map[uint64]ids.ID      // map of height -> blockID
    39  	addedBlocks   map[ids.ID]block.Block // map of blockID -> block
    40  
    41  	lastAccepted ids.ID
    42  	timestamp    time.Time
    43  }
    44  
    45  func NewDiff(
    46  	parentID ids.ID,
    47  	stateVersions Versions,
    48  ) (Diff, error) {
    49  	parentState, ok := stateVersions.GetState(parentID)
    50  	if !ok {
    51  		return nil, fmt.Errorf("%w: %s", ErrMissingParentState, parentID)
    52  	}
    53  	return &diff{
    54  		parentID:      parentID,
    55  		stateVersions: stateVersions,
    56  		modifiedUTXOs: make(map[ids.ID]*avax.UTXO),
    57  		addedTxs:      make(map[ids.ID]*txs.Tx),
    58  		addedBlockIDs: make(map[uint64]ids.ID),
    59  		addedBlocks:   make(map[ids.ID]block.Block),
    60  		lastAccepted:  parentState.GetLastAccepted(),
    61  		timestamp:     parentState.GetTimestamp(),
    62  	}, nil
    63  }
    64  
    65  type stateGetter struct {
    66  	state Chain
    67  }
    68  
    69  func (s stateGetter) GetState(ids.ID) (Chain, bool) {
    70  	return s.state, true
    71  }
    72  
    73  func NewDiffOn(parentState Chain) (Diff, error) {
    74  	return NewDiff(ids.Empty, stateGetter{
    75  		state: parentState,
    76  	})
    77  }
    78  
    79  func (d *diff) GetUTXO(utxoID ids.ID) (*avax.UTXO, error) {
    80  	if utxo, modified := d.modifiedUTXOs[utxoID]; modified {
    81  		if utxo == nil {
    82  			return nil, database.ErrNotFound
    83  		}
    84  		return utxo, nil
    85  	}
    86  
    87  	parentState, ok := d.stateVersions.GetState(d.parentID)
    88  	if !ok {
    89  		return nil, fmt.Errorf("%w: %s", ErrMissingParentState, d.parentID)
    90  	}
    91  	return parentState.GetUTXO(utxoID)
    92  }
    93  
    94  func (d *diff) AddUTXO(utxo *avax.UTXO) {
    95  	d.modifiedUTXOs[utxo.InputID()] = utxo
    96  }
    97  
    98  func (d *diff) DeleteUTXO(utxoID ids.ID) {
    99  	d.modifiedUTXOs[utxoID] = nil
   100  }
   101  
   102  func (d *diff) GetTx(txID ids.ID) (*txs.Tx, error) {
   103  	if tx, exists := d.addedTxs[txID]; exists {
   104  		return tx, nil
   105  	}
   106  
   107  	parentState, ok := d.stateVersions.GetState(d.parentID)
   108  	if !ok {
   109  		return nil, fmt.Errorf("%w: %s", ErrMissingParentState, d.parentID)
   110  	}
   111  	return parentState.GetTx(txID)
   112  }
   113  
   114  func (d *diff) AddTx(tx *txs.Tx) {
   115  	d.addedTxs[tx.ID()] = tx
   116  }
   117  
   118  func (d *diff) GetBlockIDAtHeight(height uint64) (ids.ID, error) {
   119  	if blkID, exists := d.addedBlockIDs[height]; exists {
   120  		return blkID, nil
   121  	}
   122  
   123  	parentState, ok := d.stateVersions.GetState(d.parentID)
   124  	if !ok {
   125  		return ids.Empty, fmt.Errorf("%w: %s", ErrMissingParentState, d.parentID)
   126  	}
   127  	return parentState.GetBlockIDAtHeight(height)
   128  }
   129  
   130  func (d *diff) GetBlock(blkID ids.ID) (block.Block, error) {
   131  	if blk, exists := d.addedBlocks[blkID]; exists {
   132  		return blk, nil
   133  	}
   134  
   135  	parentState, ok := d.stateVersions.GetState(d.parentID)
   136  	if !ok {
   137  		return nil, fmt.Errorf("%w: %s", ErrMissingParentState, d.parentID)
   138  	}
   139  	return parentState.GetBlock(blkID)
   140  }
   141  
   142  func (d *diff) AddBlock(blk block.Block) {
   143  	blkID := blk.ID()
   144  	d.addedBlockIDs[blk.Height()] = blkID
   145  	d.addedBlocks[blkID] = blk
   146  }
   147  
   148  func (d *diff) GetLastAccepted() ids.ID {
   149  	return d.lastAccepted
   150  }
   151  
   152  func (d *diff) SetLastAccepted(lastAccepted ids.ID) {
   153  	d.lastAccepted = lastAccepted
   154  }
   155  
   156  func (d *diff) GetTimestamp() time.Time {
   157  	return d.timestamp
   158  }
   159  
   160  func (d *diff) SetTimestamp(t time.Time) {
   161  	d.timestamp = t
   162  }
   163  
   164  func (d *diff) Apply(state Chain) {
   165  	for utxoID, utxo := range d.modifiedUTXOs {
   166  		if utxo != nil {
   167  			state.AddUTXO(utxo)
   168  		} else {
   169  			state.DeleteUTXO(utxoID)
   170  		}
   171  	}
   172  
   173  	for _, tx := range d.addedTxs {
   174  		state.AddTx(tx)
   175  	}
   176  
   177  	for _, blk := range d.addedBlocks {
   178  		state.AddBlock(blk)
   179  	}
   180  
   181  	state.SetLastAccepted(d.lastAccepted)
   182  	state.SetTimestamp(d.timestamp)
   183  }