github.com/Bytom/bytom@v1.1.2-0.20210127130405-ae40204c0b09/protocol/validation/tx.go (about)

     1  package validation
     2  
     3  import (
     4  	"fmt"
     5  	"math"
     6  	"runtime"
     7  	"sync"
     8  
     9  	"github.com/bytom/bytom/consensus"
    10  	"github.com/bytom/bytom/consensus/segwit"
    11  	"github.com/bytom/bytom/errors"
    12  	"github.com/bytom/bytom/math/checked"
    13  	"github.com/bytom/bytom/protocol/bc"
    14  	"github.com/bytom/bytom/protocol/vm"
    15  )
    16  
    17  const ruleAA = 142500
    18  
    19  // validate transaction error
    20  var (
    21  	ErrTxVersion                 = errors.New("invalid transaction version")
    22  	ErrWrongTransactionSize      = errors.New("invalid transaction size")
    23  	ErrBadTimeRange              = errors.New("invalid transaction time range")
    24  	ErrEmptyInputIDs             = errors.New("got the empty InputIDs")
    25  	ErrNotStandardTx             = errors.New("not standard transaction")
    26  	ErrWrongCoinbaseTransaction  = errors.New("wrong coinbase transaction")
    27  	ErrWrongCoinbaseAsset        = errors.New("wrong coinbase assetID")
    28  	ErrCoinbaseArbitraryOversize = errors.New("coinbase arbitrary size is larger than limit")
    29  	ErrEmptyResults              = errors.New("transaction has no results")
    30  	ErrMismatchedAssetID         = errors.New("mismatched assetID")
    31  	ErrMismatchedPosition        = errors.New("mismatched value source/dest position")
    32  	ErrMismatchedReference       = errors.New("mismatched reference")
    33  	ErrMismatchedValue           = errors.New("mismatched value")
    34  	ErrMissingField              = errors.New("missing required field")
    35  	ErrNoSource                  = errors.New("no source for value")
    36  	ErrOverflow                  = errors.New("arithmetic overflow/underflow")
    37  	ErrPosition                  = errors.New("invalid source or destination position")
    38  	ErrUnbalanced                = errors.New("unbalanced asset amount between input and output")
    39  	ErrOverGasCredit             = errors.New("all gas credit has been spend")
    40  	ErrGasCalculate              = errors.New("gas usage calculate got a math error")
    41  )
    42  
    43  // GasState record the gas usage status
    44  type GasState struct {
    45  	BTMValue   uint64
    46  	GasLeft    int64
    47  	GasUsed    int64
    48  	GasValid   bool
    49  	StorageGas int64
    50  }
    51  
    52  func (g *GasState) setGas(BTMValue int64, txSize int64) error {
    53  	if BTMValue < 0 {
    54  		return errors.Wrap(ErrGasCalculate, "input BTM is negative")
    55  	}
    56  
    57  	g.BTMValue = uint64(BTMValue)
    58  
    59  	var ok bool
    60  	if g.GasLeft, ok = checked.DivInt64(BTMValue, consensus.VMGasRate); !ok {
    61  		return errors.Wrap(ErrGasCalculate, "setGas calc gas amount")
    62  	}
    63  
    64  	if g.GasLeft > consensus.MaxGasAmount {
    65  		g.GasLeft = consensus.MaxGasAmount
    66  	}
    67  
    68  	if g.StorageGas, ok = checked.MulInt64(txSize, consensus.StorageGasRate); !ok {
    69  		return errors.Wrap(ErrGasCalculate, "setGas calc tx storage gas")
    70  	}
    71  	return nil
    72  }
    73  
    74  func (g *GasState) setGasValid() error {
    75  	var ok bool
    76  	if g.GasLeft, ok = checked.SubInt64(g.GasLeft, g.StorageGas); !ok || g.GasLeft < 0 {
    77  		return errors.Wrap(ErrGasCalculate, "setGasValid calc gasLeft")
    78  	}
    79  
    80  	if g.GasUsed, ok = checked.AddInt64(g.GasUsed, g.StorageGas); !ok {
    81  		return errors.Wrap(ErrGasCalculate, "setGasValid calc gasUsed")
    82  	}
    83  
    84  	g.GasValid = true
    85  	return nil
    86  }
    87  
    88  func (g *GasState) updateUsage(gasLeft int64) error {
    89  	if gasLeft < 0 {
    90  		return errors.Wrap(ErrGasCalculate, "updateUsage input negative gas")
    91  	}
    92  
    93  	if gasUsed, ok := checked.SubInt64(g.GasLeft, gasLeft); ok {
    94  		g.GasUsed += gasUsed
    95  		g.GasLeft = gasLeft
    96  	} else {
    97  		return errors.Wrap(ErrGasCalculate, "updateUsage calc gas diff")
    98  	}
    99  
   100  	if !g.GasValid && (g.GasUsed > consensus.DefaultGasCredit || g.StorageGas > g.GasLeft) {
   101  		return ErrOverGasCredit
   102  	}
   103  	return nil
   104  }
   105  
   106  // validationState contains the context that must propagate through
   107  // the transaction graph when validating entries.
   108  type validationState struct {
   109  	block     *bc.Block
   110  	tx        *bc.Tx
   111  	gasStatus *GasState
   112  	entryID   bc.Hash           // The ID of the nearest enclosing entry
   113  	sourcePos uint64            // The source position, for validate ValueSources
   114  	destPos   uint64            // The destination position, for validate ValueDestinations
   115  	cache     map[bc.Hash]error // Memoized per-entry validation results
   116  }
   117  
   118  func checkValid(vs *validationState, e bc.Entry) (err error) {
   119  	var ok bool
   120  	entryID := bc.EntryID(e)
   121  	if err, ok = vs.cache[entryID]; ok {
   122  		return err
   123  	}
   124  
   125  	defer func() {
   126  		vs.cache[entryID] = err
   127  	}()
   128  
   129  	switch e := e.(type) {
   130  	case *bc.TxHeader:
   131  		for i, resID := range e.ResultIds {
   132  			resultEntry := vs.tx.Entries[*resID]
   133  			vs2 := *vs
   134  			vs2.entryID = *resID
   135  			if err = checkValid(&vs2, resultEntry); err != nil {
   136  				return errors.Wrapf(err, "checking result %d", i)
   137  			}
   138  		}
   139  
   140  		if e.Version == 1 && len(e.ResultIds) == 0 {
   141  			return ErrEmptyResults
   142  		}
   143  
   144  	case *bc.Mux:
   145  		parity := make(map[bc.AssetID]int64)
   146  		for i, src := range e.Sources {
   147  			if src.Value.Amount > math.MaxInt64 {
   148  				return errors.WithDetailf(ErrOverflow, "amount %d exceeds maximum value 2^63", src.Value.Amount)
   149  			}
   150  			sum, ok := checked.AddInt64(parity[*src.Value.AssetId], int64(src.Value.Amount))
   151  			if !ok {
   152  				return errors.WithDetailf(ErrOverflow, "adding %d units of asset %x from mux source %d to total %d overflows int64", src.Value.Amount, src.Value.AssetId.Bytes(), i, parity[*src.Value.AssetId])
   153  			}
   154  			parity[*src.Value.AssetId] = sum
   155  		}
   156  
   157  		for i, dest := range e.WitnessDestinations {
   158  			sum, ok := parity[*dest.Value.AssetId]
   159  			if !ok {
   160  				return errors.WithDetailf(ErrNoSource, "mux destination %d, asset %x, has no corresponding source", i, dest.Value.AssetId.Bytes())
   161  			}
   162  			if dest.Value.Amount > math.MaxInt64 {
   163  				return errors.WithDetailf(ErrOverflow, "amount %d exceeds maximum value 2^63", dest.Value.Amount)
   164  			}
   165  			diff, ok := checked.SubInt64(sum, int64(dest.Value.Amount))
   166  			if !ok {
   167  				return errors.WithDetailf(ErrOverflow, "subtracting %d units of asset %x from mux destination %d from total %d underflows int64", dest.Value.Amount, dest.Value.AssetId.Bytes(), i, sum)
   168  			}
   169  			parity[*dest.Value.AssetId] = diff
   170  		}
   171  
   172  		for assetID, amount := range parity {
   173  			if assetID == *consensus.BTMAssetID {
   174  				if err = vs.gasStatus.setGas(amount, int64(vs.tx.SerializedSize)); err != nil {
   175  					return err
   176  				}
   177  			} else if amount != 0 {
   178  				return errors.WithDetailf(ErrUnbalanced, "asset %x sources - destinations = %d (should be 0)", assetID.Bytes(), amount)
   179  			}
   180  		}
   181  
   182  		for _, BTMInputID := range vs.tx.GasInputIDs {
   183  			e, ok := vs.tx.Entries[BTMInputID]
   184  			if !ok {
   185  				return errors.Wrapf(bc.ErrMissingEntry, "entry for bytom input %x not found", BTMInputID)
   186  			}
   187  
   188  			vs2 := *vs
   189  			vs2.entryID = BTMInputID
   190  			if err := checkValid(&vs2, e); err != nil {
   191  				return errors.Wrap(err, "checking gas input")
   192  			}
   193  		}
   194  
   195  		for i, dest := range e.WitnessDestinations {
   196  			vs2 := *vs
   197  			vs2.destPos = uint64(i)
   198  			if err = checkValidDest(&vs2, dest); err != nil {
   199  				return errors.Wrapf(err, "checking mux destination %d", i)
   200  			}
   201  		}
   202  
   203  		if err := vs.gasStatus.setGasValid(); err != nil {
   204  			return err
   205  		}
   206  
   207  		for i, src := range e.Sources {
   208  			vs2 := *vs
   209  			vs2.sourcePos = uint64(i)
   210  			if err = checkValidSrc(&vs2, src); err != nil {
   211  				return errors.Wrapf(err, "checking mux source %d", i)
   212  			}
   213  		}
   214  
   215  	case *bc.Output:
   216  		vs2 := *vs
   217  		vs2.sourcePos = 0
   218  		if err = checkValidSrc(&vs2, e.Source); err != nil {
   219  			return errors.Wrap(err, "checking output source")
   220  		}
   221  
   222  	case *bc.Retirement:
   223  		vs2 := *vs
   224  		vs2.sourcePos = 0
   225  		if err = checkValidSrc(&vs2, e.Source); err != nil {
   226  			return errors.Wrap(err, "checking retirement source")
   227  		}
   228  
   229  	case *bc.Issuance:
   230  		computedAssetID := e.WitnessAssetDefinition.ComputeAssetID()
   231  		if computedAssetID != *e.Value.AssetId {
   232  			return errors.WithDetailf(ErrMismatchedAssetID, "asset ID is %x, issuance wants %x", computedAssetID.Bytes(), e.Value.AssetId.Bytes())
   233  		}
   234  
   235  		gasLeft, err := vm.Verify(NewTxVMContext(vs, e, e.WitnessAssetDefinition.IssuanceProgram, e.WitnessArguments), vs.gasStatus.GasLeft)
   236  		if err != nil {
   237  			return errors.Wrap(err, "checking issuance program")
   238  		}
   239  		if err = vs.gasStatus.updateUsage(gasLeft); err != nil {
   240  			return err
   241  		}
   242  
   243  		destVS := *vs
   244  		destVS.destPos = 0
   245  		if err = checkValidDest(&destVS, e.WitnessDestination); err != nil {
   246  			return errors.Wrap(err, "checking issuance destination")
   247  		}
   248  
   249  	case *bc.Spend:
   250  		if e.SpentOutputId == nil {
   251  			return errors.Wrap(ErrMissingField, "spend without spent output ID")
   252  		}
   253  		spentOutput, err := vs.tx.Output(*e.SpentOutputId)
   254  		if err != nil {
   255  			return errors.Wrap(err, "getting spend prevout")
   256  		}
   257  
   258  		gasLeft, err := vm.Verify(NewTxVMContext(vs, e, spentOutput.ControlProgram, e.WitnessArguments), vs.gasStatus.GasLeft)
   259  		if err != nil {
   260  			return errors.Wrap(err, "checking control program")
   261  		}
   262  		if err = vs.gasStatus.updateUsage(gasLeft); err != nil {
   263  			return err
   264  		}
   265  
   266  		eq, err := spentOutput.Source.Value.Equal(e.WitnessDestination.Value)
   267  		if err != nil {
   268  			return err
   269  		}
   270  		if !eq {
   271  			return errors.WithDetailf(
   272  				ErrMismatchedValue,
   273  				"previous output is for %d unit(s) of %x, spend wants %d unit(s) of %x",
   274  				spentOutput.Source.Value.Amount,
   275  				spentOutput.Source.Value.AssetId.Bytes(),
   276  				e.WitnessDestination.Value.Amount,
   277  				e.WitnessDestination.Value.AssetId.Bytes(),
   278  			)
   279  		}
   280  
   281  		vs2 := *vs
   282  		vs2.destPos = 0
   283  		if err = checkValidDest(&vs2, e.WitnessDestination); err != nil {
   284  			return errors.Wrap(err, "checking spend destination")
   285  		}
   286  
   287  	case *bc.Coinbase:
   288  		if vs.block == nil || len(vs.block.Transactions) == 0 || vs.block.Transactions[0] != vs.tx {
   289  			return ErrWrongCoinbaseTransaction
   290  		}
   291  
   292  		if *e.WitnessDestination.Value.AssetId != *consensus.BTMAssetID {
   293  			return ErrWrongCoinbaseAsset
   294  		}
   295  
   296  		if e.Arbitrary != nil && len(e.Arbitrary) > consensus.CoinbaseArbitrarySizeLimit {
   297  			return ErrCoinbaseArbitraryOversize
   298  		}
   299  
   300  		vs2 := *vs
   301  		vs2.destPos = 0
   302  		if err = checkValidDest(&vs2, e.WitnessDestination); err != nil {
   303  			return errors.Wrap(err, "checking coinbase destination")
   304  		}
   305  		vs.gasStatus.StorageGas = 0
   306  
   307  	default:
   308  		return fmt.Errorf("entry has unexpected type %T", e)
   309  	}
   310  
   311  	return nil
   312  }
   313  
   314  func checkValidSrc(vstate *validationState, vs *bc.ValueSource) error {
   315  	if vs == nil {
   316  		return errors.Wrap(ErrMissingField, "empty value source")
   317  	}
   318  	if vs.Ref == nil {
   319  		return errors.Wrap(ErrMissingField, "missing ref on value source")
   320  	}
   321  	if vs.Value == nil || vs.Value.AssetId == nil {
   322  		return errors.Wrap(ErrMissingField, "missing value on value source")
   323  	}
   324  
   325  	e, ok := vstate.tx.Entries[*vs.Ref]
   326  	if !ok {
   327  		return errors.Wrapf(bc.ErrMissingEntry, "entry for value source %x not found", vs.Ref.Bytes())
   328  	}
   329  
   330  	vstate2 := *vstate
   331  	vstate2.entryID = *vs.Ref
   332  	if err := checkValid(&vstate2, e); err != nil {
   333  		return errors.Wrap(err, "checking value source")
   334  	}
   335  
   336  	var dest *bc.ValueDestination
   337  	switch ref := e.(type) {
   338  	case *bc.Coinbase:
   339  		if vs.Position != 0 {
   340  			return errors.Wrapf(ErrPosition, "invalid position %d for coinbase source", vs.Position)
   341  		}
   342  		dest = ref.WitnessDestination
   343  
   344  	case *bc.Issuance:
   345  		if vs.Position != 0 {
   346  			return errors.Wrapf(ErrPosition, "invalid position %d for issuance source", vs.Position)
   347  		}
   348  		dest = ref.WitnessDestination
   349  
   350  	case *bc.Spend:
   351  		if vs.Position != 0 {
   352  			return errors.Wrapf(ErrPosition, "invalid position %d for spend source", vs.Position)
   353  		}
   354  		dest = ref.WitnessDestination
   355  
   356  	case *bc.Mux:
   357  		if vs.Position >= uint64(len(ref.WitnessDestinations)) {
   358  			return errors.Wrapf(ErrPosition, "invalid position %d for %d-destination mux source", vs.Position, len(ref.WitnessDestinations))
   359  		}
   360  		dest = ref.WitnessDestinations[vs.Position]
   361  
   362  	default:
   363  		return errors.Wrapf(bc.ErrEntryType, "value source is %T, should be coinbase, issuance, spend, or mux", e)
   364  	}
   365  
   366  	if dest.Ref == nil || *dest.Ref != vstate.entryID {
   367  		return errors.Wrapf(ErrMismatchedReference, "value source for %x has disagreeing destination %x", vstate.entryID.Bytes(), dest.Ref.Bytes())
   368  	}
   369  
   370  	if dest.Position != vstate.sourcePos {
   371  		return errors.Wrapf(ErrMismatchedPosition, "value source position %d disagrees with %d", dest.Position, vstate.sourcePos)
   372  	}
   373  
   374  	eq, err := dest.Value.Equal(vs.Value)
   375  	if err != nil {
   376  		return errors.Sub(ErrMissingField, err)
   377  	}
   378  	if !eq {
   379  		return errors.Wrapf(ErrMismatchedValue, "source value %v disagrees with %v", dest.Value, vs.Value)
   380  	}
   381  
   382  	return nil
   383  }
   384  
   385  func checkValidDest(vs *validationState, vd *bc.ValueDestination) error {
   386  	if vd == nil {
   387  		return errors.Wrap(ErrMissingField, "empty value destination")
   388  	}
   389  	if vd.Ref == nil {
   390  		return errors.Wrap(ErrMissingField, "missing ref on value destination")
   391  	}
   392  	if vd.Value == nil || vd.Value.AssetId == nil {
   393  		return errors.Wrap(ErrMissingField, "missing value on value destination")
   394  	}
   395  
   396  	e, ok := vs.tx.Entries[*vd.Ref]
   397  	if !ok {
   398  		return errors.Wrapf(bc.ErrMissingEntry, "entry for value destination %x not found", vd.Ref.Bytes())
   399  	}
   400  
   401  	var src *bc.ValueSource
   402  	switch ref := e.(type) {
   403  	case *bc.Output:
   404  		if vd.Position != 0 {
   405  			return errors.Wrapf(ErrPosition, "invalid position %d for output destination", vd.Position)
   406  		}
   407  		src = ref.Source
   408  
   409  	case *bc.Retirement:
   410  		if vd.Position != 0 {
   411  			return errors.Wrapf(ErrPosition, "invalid position %d for retirement destination", vd.Position)
   412  		}
   413  		src = ref.Source
   414  
   415  	case *bc.Mux:
   416  		if vd.Position >= uint64(len(ref.Sources)) {
   417  			return errors.Wrapf(ErrPosition, "invalid position %d for %d-source mux destination", vd.Position, len(ref.Sources))
   418  		}
   419  		src = ref.Sources[vd.Position]
   420  
   421  	default:
   422  		return errors.Wrapf(bc.ErrEntryType, "value destination is %T, should be output, retirement, or mux", e)
   423  	}
   424  
   425  	if src.Ref == nil || *src.Ref != vs.entryID {
   426  		return errors.Wrapf(ErrMismatchedReference, "value destination for %x has disagreeing source %x", vs.entryID.Bytes(), src.Ref.Bytes())
   427  	}
   428  
   429  	if src.Position != vs.destPos {
   430  		return errors.Wrapf(ErrMismatchedPosition, "value destination position %d disagrees with %d", src.Position, vs.destPos)
   431  	}
   432  
   433  	eq, err := src.Value.Equal(vd.Value)
   434  	if err != nil {
   435  		return errors.Sub(ErrMissingField, err)
   436  	}
   437  	if !eq {
   438  		return errors.Wrapf(ErrMismatchedValue, "destination value %v disagrees with %v", src.Value, vd.Value)
   439  	}
   440  
   441  	return nil
   442  }
   443  
   444  func checkStandardTx(tx *bc.Tx, blockHeight uint64) error {
   445  	for _, id := range tx.InputIDs {
   446  		if blockHeight >= ruleAA && id.IsZero() {
   447  			return ErrEmptyInputIDs
   448  		}
   449  	}
   450  
   451  	for _, id := range tx.GasInputIDs {
   452  		spend, err := tx.Spend(id)
   453  		if err != nil {
   454  			continue
   455  		}
   456  		spentOutput, err := tx.Output(*spend.SpentOutputId)
   457  		if err != nil {
   458  			return err
   459  		}
   460  
   461  		if !segwit.IsP2WScript(spentOutput.ControlProgram.Code) {
   462  			return ErrNotStandardTx
   463  		}
   464  	}
   465  
   466  	for _, id := range tx.ResultIds {
   467  		e, ok := tx.Entries[*id]
   468  		if !ok {
   469  			return errors.Wrapf(bc.ErrMissingEntry, "id %x", id.Bytes())
   470  		}
   471  
   472  		output, ok := e.(*bc.Output)
   473  		if !ok || *output.Source.Value.AssetId != *consensus.BTMAssetID {
   474  			continue
   475  		}
   476  
   477  		if !segwit.IsP2WScript(output.ControlProgram.Code) {
   478  			return ErrNotStandardTx
   479  		}
   480  	}
   481  	return nil
   482  }
   483  
   484  func checkTimeRange(tx *bc.Tx, block *bc.Block) error {
   485  	if tx.TimeRange == 0 {
   486  		return nil
   487  	}
   488  
   489  	if tx.TimeRange < block.Height {
   490  		return ErrBadTimeRange
   491  	}
   492  	return nil
   493  }
   494  
   495  // ValidateTx validates a transaction.
   496  func ValidateTx(tx *bc.Tx, block *bc.Block) (*GasState, error) {
   497  	gasStatus := &GasState{GasValid: false}
   498  	if block.Version == 1 && tx.Version != 1 {
   499  		return gasStatus, errors.WithDetailf(ErrTxVersion, "block version %d, transaction version %d", block.Version, tx.Version)
   500  	}
   501  	if tx.SerializedSize == 0 {
   502  		return gasStatus, ErrWrongTransactionSize
   503  	}
   504  	if err := checkTimeRange(tx, block); err != nil {
   505  		return gasStatus, err
   506  	}
   507  	if err := checkStandardTx(tx, block.Height); err != nil {
   508  		return gasStatus, err
   509  	}
   510  
   511  	vs := &validationState{
   512  		block:     block,
   513  		tx:        tx,
   514  		entryID:   tx.ID,
   515  		gasStatus: gasStatus,
   516  		cache:     make(map[bc.Hash]error),
   517  	}
   518  	return vs.gasStatus, checkValid(vs, tx.TxHeader)
   519  }
   520  
   521  type validateTxWork struct {
   522  	i     int
   523  	tx    *bc.Tx
   524  	block *bc.Block
   525  }
   526  
   527  // ValidateTxResult is the result of async tx validate
   528  type ValidateTxResult struct {
   529  	i         int
   530  	gasStatus *GasState
   531  	err       error
   532  }
   533  
   534  // GetGasState return the gasStatus
   535  func (r *ValidateTxResult) GetGasState() *GasState {
   536  	return r.gasStatus
   537  }
   538  
   539  // GetError return the err
   540  func (r *ValidateTxResult) GetError() error {
   541  	return r.err
   542  }
   543  
   544  func validateTxWorker(workCh chan *validateTxWork, resultCh chan *ValidateTxResult, wg *sync.WaitGroup) {
   545  	for work := range workCh {
   546  		gasStatus, err := ValidateTx(work.tx, work.block)
   547  		resultCh <- &ValidateTxResult{i: work.i, gasStatus: gasStatus, err: err}
   548  	}
   549  	wg.Done()
   550  }
   551  
   552  // ValidateTxs validates txs in async mode
   553  func ValidateTxs(txs []*bc.Tx, block *bc.Block) []*ValidateTxResult {
   554  	txSize := len(txs)
   555  	validateWorkerNum := runtime.NumCPU()
   556  	//init the goroutine validate worker
   557  	var wg sync.WaitGroup
   558  	workCh := make(chan *validateTxWork, txSize)
   559  	resultCh := make(chan *ValidateTxResult, txSize)
   560  	for i := 0; i <= validateWorkerNum && i < txSize; i++ {
   561  		wg.Add(1)
   562  		go validateTxWorker(workCh, resultCh, &wg)
   563  	}
   564  
   565  	//sent the works
   566  	for i, tx := range txs {
   567  		workCh <- &validateTxWork{i: i, tx: tx, block: block}
   568  	}
   569  	close(workCh)
   570  
   571  	//collect validate results
   572  	results := make([]*ValidateTxResult, txSize)
   573  	for i := 0; i < txSize; i++ {
   574  		result := <-resultCh
   575  		results[result.i] = result
   576  	}
   577  
   578  	wg.Wait()
   579  	return results
   580  }