github.com/bytom/bytom@v1.1.2-0.20221014091027-bbcba3df6075/protocol/validation/vmcontext.go (about)

     1  package validation
     2  
     3  import (
     4  	"bytes"
     5  
     6  	"github.com/bytom/bytom/consensus/bcrp"
     7  	"github.com/bytom/bytom/consensus/segwit"
     8  	"github.com/bytom/bytom/crypto/sha3pool"
     9  	"github.com/bytom/bytom/errors"
    10  	"github.com/bytom/bytom/protocol/bc"
    11  	"github.com/bytom/bytom/protocol/vm"
    12  )
    13  
    14  // NewTxVMContext generates the vm.Context for BVM
    15  func NewTxVMContext(vs *validationState, entry bc.Entry, prog *bc.Program, stateData [][]byte, args [][]byte) *vm.Context {
    16  	var (
    17  		tx          = vs.tx
    18  		blockHeight = vs.block.BlockHeader.GetHeight()
    19  		numResults  = uint64(len(tx.ResultIds))
    20  		entryID     = bc.EntryID(entry) // TODO(bobg): pass this in, don't recompute it
    21  
    22  		assetID       *[]byte
    23  		amount        *uint64
    24  		destPos       *uint64
    25  		spentOutputID *[]byte
    26  	)
    27  
    28  	switch e := entry.(type) {
    29  	case *bc.Issuance:
    30  		a1 := e.Value.AssetId.Bytes()
    31  		assetID = &a1
    32  		amount = &e.Value.Amount
    33  		destPos = &e.WitnessDestination.Position
    34  
    35  	case *bc.Spend:
    36  		spentOutput := tx.Entries[*e.SpentOutputId].(*bc.OriginalOutput)
    37  		a1 := spentOutput.Source.Value.AssetId.Bytes()
    38  		assetID = &a1
    39  		amount = &spentOutput.Source.Value.Amount
    40  		destPos = &e.WitnessDestination.Position
    41  		s := e.SpentOutputId.Bytes()
    42  		spentOutputID = &s
    43  	}
    44  
    45  	var txSigHash *[]byte
    46  	txSigHashFn := func() []byte {
    47  		if txSigHash == nil {
    48  			hasher := sha3pool.Get256()
    49  			defer sha3pool.Put256(hasher)
    50  
    51  			entryID.WriteTo(hasher)
    52  			tx.ID.WriteTo(hasher)
    53  
    54  			var hash bc.Hash
    55  			hash.ReadFrom(hasher)
    56  			hashBytes := hash.Bytes()
    57  			txSigHash = &hashBytes
    58  		}
    59  		return *txSigHash
    60  	}
    61  
    62  	ec := &entryContext{
    63  		entry:   entry,
    64  		entries: tx.Entries,
    65  	}
    66  
    67  	result := &vm.Context{
    68  		VMVersion: prog.VmVersion,
    69  		Code:      convertProgram(prog.Code, vs.converter),
    70  		StateData: stateData,
    71  		Arguments: args,
    72  
    73  		EntryID: entryID.Bytes(),
    74  
    75  		TxVersion:   &tx.Version,
    76  		BlockHeight: &blockHeight,
    77  
    78  		TxSigHash:     txSigHashFn,
    79  		NumResults:    &numResults,
    80  		AssetID:       assetID,
    81  		Amount:        amount,
    82  		DestPos:       destPos,
    83  		SpentOutputID: spentOutputID,
    84  		CheckOutput:   ec.checkOutput,
    85  	}
    86  
    87  	return result
    88  }
    89  
    90  func convertProgram(prog []byte, converter ProgramConverterFunc) []byte {
    91  	if segwit.IsP2WPKHScript(prog) {
    92  		if witnessProg, err := segwit.ConvertP2PKHSigProgram([]byte(prog)); err == nil {
    93  			return witnessProg
    94  		}
    95  	} else if segwit.IsP2WSHScript(prog) {
    96  		if witnessProg, err := segwit.ConvertP2SHProgram([]byte(prog)); err == nil {
    97  			return witnessProg
    98  		}
    99  	} else if bcrp.IsCallContractScript(prog) {
   100  		if contractProg, err := converter(prog); err == nil {
   101  			return contractProg
   102  		}
   103  	}
   104  	return prog
   105  }
   106  
   107  type entryContext struct {
   108  	entry   bc.Entry
   109  	entries map[bc.Hash]bc.Entry
   110  }
   111  
   112  func (ec *entryContext) checkOutput(index uint64, amount uint64, assetID []byte, vmVersion uint64, code []byte, state [][]byte, expansion bool) (bool, error) {
   113  	checkEntry := func(e bc.Entry) (bool, error) {
   114  		check := func(prog *bc.Program, value *bc.AssetAmount, stateData [][]byte) bool {
   115  			return (prog.VmVersion == vmVersion &&
   116  				bytes.Equal(prog.Code, code) &&
   117  				bytes.Equal(value.AssetId.Bytes(), assetID) &&
   118  				value.Amount == amount &&
   119  				bytesEqual(stateData, state))
   120  		}
   121  
   122  		switch e := e.(type) {
   123  		case *bc.OriginalOutput:
   124  			return check(e.ControlProgram, e.Source.Value, e.StateData), nil
   125  
   126  		case *bc.VoteOutput:
   127  			return check(e.ControlProgram, e.Source.Value, e.StateData), nil
   128  
   129  		case *bc.Retirement:
   130  			var prog bc.Program
   131  			if expansion {
   132  				// The spec requires prog.Code to be the empty string only
   133  				// when !expansion. When expansion is true, we prepopulate
   134  				// prog.Code to give check() a freebie match.
   135  				//
   136  				// (The spec always requires prog.VmVersion to be zero.)
   137  				prog.Code = code
   138  			}
   139  			return check(&prog, e.Source.Value, [][]byte{}), nil
   140  		}
   141  
   142  		return false, vm.ErrContext
   143  	}
   144  
   145  	checkMux := func(m *bc.Mux) (bool, error) {
   146  		if index >= uint64(len(m.WitnessDestinations)) {
   147  			return false, errors.Wrapf(vm.ErrBadValue, "index %d >= %d", index, len(m.WitnessDestinations))
   148  		}
   149  		eID := m.WitnessDestinations[index].Ref
   150  		e, ok := ec.entries[*eID]
   151  		if !ok {
   152  			return false, errors.Wrapf(bc.ErrMissingEntry, "entry for mux destination %d, id %x, not found", index, eID.Bytes())
   153  		}
   154  		return checkEntry(e)
   155  	}
   156  
   157  	switch e := ec.entry.(type) {
   158  	case *bc.Mux:
   159  		return checkMux(e)
   160  
   161  	case *bc.Issuance:
   162  		d, ok := ec.entries[*e.WitnessDestination.Ref]
   163  		if !ok {
   164  			return false, errors.Wrapf(bc.ErrMissingEntry, "entry for issuance destination %x not found", e.WitnessDestination.Ref.Bytes())
   165  		}
   166  		if m, ok := d.(*bc.Mux); ok {
   167  			return checkMux(m)
   168  		}
   169  		if index != 0 {
   170  			return false, errors.Wrapf(vm.ErrBadValue, "index %d >= 1", index)
   171  		}
   172  		return checkEntry(d)
   173  
   174  	case *bc.Spend:
   175  		d, ok := ec.entries[*e.WitnessDestination.Ref]
   176  		if !ok {
   177  			return false, errors.Wrapf(bc.ErrMissingEntry, "entry for spend destination %x not found", e.WitnessDestination.Ref.Bytes())
   178  		}
   179  		if m, ok := d.(*bc.Mux); ok {
   180  			return checkMux(m)
   181  		}
   182  		if index != 0 {
   183  			return false, errors.Wrapf(vm.ErrBadValue, "index %d >= 1", index)
   184  		}
   185  		return checkEntry(d)
   186  
   187  	case *bc.VetoInput:
   188  		d, ok := ec.entries[*e.WitnessDestination.Ref]
   189  		if !ok {
   190  			return false, errors.Wrapf(bc.ErrMissingEntry, "entry for vetoInput destination %x not found", e.WitnessDestination.Ref.Bytes())
   191  		}
   192  		if m, ok := d.(*bc.Mux); ok {
   193  			return checkMux(m)
   194  		}
   195  		if index != 0 {
   196  			return false, errors.Wrapf(vm.ErrBadValue, "index %d >= 1", index)
   197  		}
   198  		return checkEntry(d)
   199  
   200  	}
   201  
   202  	return false, vm.ErrContext
   203  }
   204  
   205  func bytesEqual(a, b [][]byte) bool {
   206  	if (a == nil) != (b == nil) {
   207  		return false
   208  	}
   209  
   210  	if len(a) != len(b) {
   211  		return false
   212  	}
   213  
   214  	for i, v := range a {
   215  		if !bytes.Equal(v, b[i]) {
   216  			return false
   217  		}
   218  	}
   219  
   220  	return true
   221  }