github.com/nspcc-dev/neo-go@v0.105.2-0.20240517133400-6be757af3eba/pkg/core/transaction/witness_condition_test.go (about)

     1  package transaction
     2  
     3  import (
     4  	"encoding/json"
     5  	"errors"
     6  	"testing"
     7  
     8  	"github.com/nspcc-dev/neo-go/pkg/crypto/keys"
     9  	"github.com/nspcc-dev/neo-go/pkg/io"
    10  	"github.com/nspcc-dev/neo-go/pkg/util"
    11  	"github.com/nspcc-dev/neo-go/pkg/vm/stackitem"
    12  	"github.com/stretchr/testify/assert"
    13  	"github.com/stretchr/testify/require"
    14  )
    15  
    16  type InvalidCondition struct{}
    17  
    18  func (c InvalidCondition) Type() WitnessConditionType {
    19  	return 0xff
    20  }
    21  func (c InvalidCondition) Match(_ MatchContext) (bool, error) {
    22  	return true, nil
    23  }
    24  func (c InvalidCondition) EncodeBinary(w *io.BinWriter) {
    25  	w.WriteB(byte(c.Type()))
    26  }
    27  func (c InvalidCondition) DecodeBinarySpecific(r *io.BinReader, _ int) {
    28  }
    29  func (c InvalidCondition) MarshalJSON() ([]byte, error) {
    30  	aux := conditionAux{
    31  		Type: c.Type().String(),
    32  	}
    33  	return json.Marshal(aux)
    34  }
    35  func (c InvalidCondition) ToStackItem() stackitem.Item {
    36  	panic("invalid")
    37  }
    38  
    39  // Copy implements the WitnessCondition interface and returns a deep copy of the condition.
    40  func (c InvalidCondition) Copy() WitnessCondition {
    41  	return c
    42  }
    43  
    44  type condCase struct {
    45  	condition         WitnessCondition
    46  	success           bool
    47  	expectedStackItem []stackitem.Item
    48  }
    49  
    50  func TestWitnessConditionSerDes(t *testing.T) {
    51  	var someBool bool
    52  	pk, err := keys.NewPrivateKey()
    53  	require.NoError(t, err)
    54  	var cases = []condCase{
    55  		{(*ConditionBoolean)(&someBool), true, []stackitem.Item{stackitem.Make(WitnessBoolean), stackitem.Make(someBool)}},
    56  		{&ConditionNot{(*ConditionBoolean)(&someBool)}, true, []stackitem.Item{stackitem.Make(WitnessNot), stackitem.NewArray([]stackitem.Item{stackitem.Make(WitnessBoolean), stackitem.Make(someBool)})}},
    57  		{&ConditionAnd{(*ConditionBoolean)(&someBool), (*ConditionBoolean)(&someBool)}, true, []stackitem.Item{stackitem.Make(WitnessAnd), stackitem.Make([]stackitem.Item{
    58  			stackitem.NewArray([]stackitem.Item{stackitem.Make(WitnessBoolean), stackitem.Make(someBool)}),
    59  			stackitem.NewArray([]stackitem.Item{stackitem.Make(WitnessBoolean), stackitem.Make(someBool)}),
    60  		})}},
    61  		{&ConditionOr{(*ConditionBoolean)(&someBool), (*ConditionBoolean)(&someBool)}, true, []stackitem.Item{stackitem.Make(WitnessOr), stackitem.Make([]stackitem.Item{
    62  			stackitem.NewArray([]stackitem.Item{stackitem.Make(WitnessBoolean), stackitem.Make(someBool)}),
    63  			stackitem.NewArray([]stackitem.Item{stackitem.Make(WitnessBoolean), stackitem.Make(someBool)}),
    64  		})}},
    65  		{&ConditionScriptHash{1, 2, 3}, true, []stackitem.Item{stackitem.Make(WitnessScriptHash), stackitem.Make(util.Uint160{1, 2, 3}.BytesBE())}},
    66  		{(*ConditionGroup)(pk.PublicKey()), true, []stackitem.Item{stackitem.Make(WitnessGroup), stackitem.Make(pk.PublicKey().Bytes())}},
    67  		{ConditionCalledByEntry{}, true, []stackitem.Item{stackitem.Make(WitnessCalledByEntry)}},
    68  		{&ConditionCalledByContract{1, 2, 3}, true, []stackitem.Item{stackitem.Make(WitnessCalledByContract), stackitem.Make(util.Uint160{1, 2, 3}.BytesBE())}},
    69  		{(*ConditionCalledByGroup)(pk.PublicKey()), true, []stackitem.Item{stackitem.Make(WitnessCalledByGroup), stackitem.Make(pk.PublicKey().Bytes())}},
    70  		{InvalidCondition{}, false, nil},
    71  		{&ConditionAnd{}, false, nil},
    72  		{&ConditionOr{}, false, nil},
    73  		{&ConditionNot{&ConditionNot{&ConditionNot{(*ConditionBoolean)(&someBool)}}}, false, nil},
    74  	}
    75  	var maxSubCondAnd = &ConditionAnd{}
    76  	var maxSubCondOr = &ConditionAnd{}
    77  	for i := 0; i < maxSubitems+1; i++ {
    78  		*maxSubCondAnd = append(*maxSubCondAnd, (*ConditionBoolean)(&someBool))
    79  		*maxSubCondOr = append(*maxSubCondOr, (*ConditionBoolean)(&someBool))
    80  	}
    81  	cases = append(cases, condCase{maxSubCondAnd, false, nil})
    82  	cases = append(cases, condCase{maxSubCondOr, false, nil})
    83  	t.Run("binary", func(t *testing.T) {
    84  		for i, c := range cases {
    85  			w := io.NewBufBinWriter()
    86  			c.condition.EncodeBinary(w.BinWriter)
    87  			require.NoError(t, w.Err)
    88  			b := w.Bytes()
    89  
    90  			r := io.NewBinReaderFromBuf(b)
    91  			res := DecodeBinaryCondition(r)
    92  			if !c.success {
    93  				require.Nil(t, res)
    94  				require.Errorf(t, r.Err, "case %d", i)
    95  				continue
    96  			}
    97  			require.NoErrorf(t, r.Err, "case %d", i)
    98  			require.Equal(t, c.condition, res)
    99  		}
   100  	})
   101  	t.Run("json", func(t *testing.T) {
   102  		for i, c := range cases {
   103  			jj, err := c.condition.MarshalJSON()
   104  			require.NoError(t, err)
   105  			res, err := UnmarshalConditionJSON(jj)
   106  			if !c.success {
   107  				require.Errorf(t, err, "case %d, json %s", i, jj)
   108  				continue
   109  			}
   110  			require.NoErrorf(t, err, "case %d, json %s", i, jj)
   111  			require.Equal(t, c.condition, res)
   112  		}
   113  	})
   114  	t.Run("stackitem", func(t *testing.T) {
   115  		for i, c := range cases[1:] {
   116  			if c.expectedStackItem != nil {
   117  				expected := stackitem.NewArray(c.expectedStackItem)
   118  				actual := c.condition.ToStackItem()
   119  				assert.Equal(t, expected, actual, i)
   120  			}
   121  		}
   122  	})
   123  }
   124  
   125  func TestWitnessConditionZeroDeser(t *testing.T) {
   126  	r := io.NewBinReaderFromBuf([]byte{})
   127  	res := DecodeBinaryCondition(r)
   128  	require.Nil(t, res)
   129  	require.Error(t, r.Err)
   130  }
   131  
   132  func TestWitnessConditionJSONErrors(t *testing.T) {
   133  	var cases = []string{
   134  		`[]`,
   135  		`{}`,
   136  		`{"type":"Boolean"}`,
   137  		`{"type":"Not"}`,
   138  		`{"type":"And"}`,
   139  		`{"type":"Or"}`,
   140  		`{"type":"ScriptHash"}`,
   141  		`{"type":"Group"}`,
   142  		`{"type":"CalledByContract"}`,
   143  		`{"type":"CalledByGroup"}`,
   144  		`{"type":"Boolean", "expression":42}`,
   145  		`{"type":"Not", "expression":true}`,
   146  		`{"type":"And", "expressions":[{"type":"CalledByGroup"},{"type":"Not", "expression":true}]}`,
   147  		`{"type":"Or", "expressions":{"type":"CalledByGroup"}}`,
   148  		`{"type":"Or", "expressions":[{"type":"CalledByGroup"},{"type":"Not", "expression":false}]}`,
   149  		`{"type":"ScriptHash", "hash":"1122"}`,
   150  		`{"type":"Group", "group":"032211"}`,
   151  		`{"type":"CalledByContract", "hash":"1122"}`,
   152  		`{"type":"CalledByGroup", "group":"032211"}`,
   153  	}
   154  	for i := range cases {
   155  		res, err := UnmarshalConditionJSON([]byte(cases[i]))
   156  		require.Errorf(t, err, "case %d, json %s", i, cases[i])
   157  		require.Nil(t, res)
   158  	}
   159  }
   160  
   161  type TestMC struct {
   162  	calling util.Uint160
   163  	current util.Uint160
   164  	entry   util.Uint160
   165  	goodKey *keys.PublicKey
   166  	badKey  *keys.PublicKey
   167  }
   168  
   169  func (t *TestMC) GetCallingScriptHash() util.Uint160 {
   170  	return t.calling
   171  }
   172  func (t *TestMC) GetCurrentScriptHash() util.Uint160 {
   173  	return t.current
   174  }
   175  func (t *TestMC) GetEntryScriptHash() util.Uint160 {
   176  	return t.entry
   177  }
   178  func (t *TestMC) IsCalledByEntry() bool {
   179  	return t.entry.Equals(t.calling) || t.calling.Equals(util.Uint160{})
   180  }
   181  func (t *TestMC) CallingScriptHasGroup(k *keys.PublicKey) (bool, error) {
   182  	res, err := t.CurrentScriptHasGroup(k)
   183  	return !res, err // To differentiate from current we invert the logic value.
   184  }
   185  func (t *TestMC) CurrentScriptHasGroup(k *keys.PublicKey) (bool, error) {
   186  	if k.Equal(t.goodKey) {
   187  		return true, nil
   188  	}
   189  	if k.Equal(t.badKey) {
   190  		return false, errors.New("baaad key")
   191  	}
   192  	return false, nil
   193  }
   194  
   195  func TestWitnessConditionMatch(t *testing.T) {
   196  	pkGood, err := keys.NewPrivateKey()
   197  	require.NoError(t, err)
   198  	pkBad, err := keys.NewPrivateKey()
   199  	require.NoError(t, err)
   200  	pkNeutral, err := keys.NewPrivateKey()
   201  	require.NoError(t, err)
   202  	entrySC := util.Uint160{1, 2, 3}
   203  	currentSC := util.Uint160{4, 5, 6}
   204  	tmc := &TestMC{
   205  		calling: entrySC,
   206  		entry:   entrySC,
   207  		current: currentSC,
   208  		goodKey: pkGood.PublicKey(),
   209  		badKey:  pkBad.PublicKey(),
   210  	}
   211  
   212  	t.Run("boolean", func(t *testing.T) {
   213  		var b bool
   214  		var c = (*ConditionBoolean)(&b)
   215  		res, err := c.Match(tmc)
   216  		require.NoError(t, err)
   217  		require.False(t, res)
   218  		b = true
   219  		res, err = c.Match(tmc)
   220  		require.NoError(t, err)
   221  		require.True(t, res)
   222  	})
   223  	t.Run("not", func(t *testing.T) {
   224  		var b bool
   225  		var cInner = (*ConditionBoolean)(&b)
   226  		var cInner2 = (*ConditionGroup)(pkBad.PublicKey())
   227  		var c = &ConditionNot{cInner}
   228  		var c2 = &ConditionNot{cInner2}
   229  
   230  		res, err := c.Match(tmc)
   231  		require.NoError(t, err)
   232  		require.True(t, res)
   233  		b = true
   234  		res, err = c.Match(tmc)
   235  		require.NoError(t, err)
   236  		require.False(t, res)
   237  		_, err = c2.Match(tmc)
   238  		require.Error(t, err)
   239  	})
   240  	t.Run("and", func(t *testing.T) {
   241  		var bFalse, bTrue bool
   242  		var cInnerFalse = (*ConditionBoolean)(&bFalse)
   243  		var cInnerTrue = (*ConditionBoolean)(&bTrue)
   244  		var cInnerBad = (*ConditionGroup)(pkBad.PublicKey())
   245  		var c = &ConditionAnd{cInnerTrue, cInnerFalse, cInnerFalse}
   246  		var cBad = &ConditionAnd{cInnerTrue, cInnerBad}
   247  
   248  		bTrue = true
   249  		res, err := c.Match(tmc)
   250  		require.NoError(t, err)
   251  		require.False(t, res)
   252  		bFalse = true
   253  		res, err = c.Match(tmc)
   254  		require.NoError(t, err)
   255  		require.True(t, res)
   256  
   257  		_, err = cBad.Match(tmc)
   258  		require.Error(t, err)
   259  	})
   260  	t.Run("or", func(t *testing.T) {
   261  		var bFalse, bTrue bool
   262  		var cInnerFalse = (*ConditionBoolean)(&bFalse)
   263  		var cInnerTrue = (*ConditionBoolean)(&bTrue)
   264  		var cInnerBad = (*ConditionGroup)(pkBad.PublicKey())
   265  		var c = &ConditionOr{cInnerTrue, cInnerFalse, cInnerFalse}
   266  		var cBad = &ConditionOr{cInnerTrue, cInnerBad}
   267  
   268  		bTrue = true
   269  		res, err := c.Match(tmc)
   270  		require.NoError(t, err)
   271  		require.True(t, res)
   272  		bTrue = false
   273  		res, err = c.Match(tmc)
   274  		require.NoError(t, err)
   275  		require.False(t, res)
   276  
   277  		_, err = cBad.Match(tmc)
   278  		require.Error(t, err)
   279  	})
   280  	t.Run("script hash", func(t *testing.T) {
   281  		var cEntry = (*ConditionScriptHash)(&entrySC)
   282  		var cCurrent = (*ConditionScriptHash)(&currentSC)
   283  
   284  		res, err := cEntry.Match(tmc)
   285  		require.NoError(t, err)
   286  		require.False(t, res)
   287  		res, err = cCurrent.Match(tmc)
   288  		require.NoError(t, err)
   289  		require.True(t, res)
   290  	})
   291  	t.Run("group", func(t *testing.T) {
   292  		var cBad = (*ConditionGroup)(pkBad.PublicKey())
   293  		var cGood = (*ConditionGroup)(pkGood.PublicKey())
   294  		var cNeutral = (*ConditionGroup)(pkNeutral.PublicKey())
   295  
   296  		res, err := cGood.Match(tmc)
   297  		require.NoError(t, err)
   298  		require.True(t, res)
   299  
   300  		res, err = cNeutral.Match(tmc)
   301  		require.NoError(t, err)
   302  		require.False(t, res)
   303  
   304  		_, err = cBad.Match(tmc)
   305  		require.Error(t, err)
   306  	})
   307  	t.Run("called by entry", func(t *testing.T) {
   308  		var c = ConditionCalledByEntry{}
   309  
   310  		res, err := c.Match(tmc)
   311  		require.NoError(t, err)
   312  		require.True(t, res)
   313  
   314  		tmc2 := *tmc
   315  		tmc2.entry = util.Uint160{0, 9, 8}
   316  		res, err = c.Match(&tmc2)
   317  		require.NoError(t, err)
   318  		require.False(t, res)
   319  
   320  		tmc3 := *tmc
   321  		tmc3.calling = util.Uint160{}
   322  		tmc3.current = tmc3.entry
   323  		res, err = c.Match(&tmc3)
   324  		require.NoError(t, err)
   325  		require.True(t, res)
   326  	})
   327  	t.Run("called by contract", func(t *testing.T) {
   328  		var cEntry = (*ConditionCalledByContract)(&entrySC)
   329  		var cCurrent = (*ConditionCalledByContract)(&currentSC)
   330  
   331  		res, err := cEntry.Match(tmc)
   332  		require.NoError(t, err)
   333  		require.True(t, res)
   334  		res, err = cCurrent.Match(tmc)
   335  		require.NoError(t, err)
   336  		require.False(t, res)
   337  	})
   338  	t.Run("called by group", func(t *testing.T) {
   339  		var cBad = (*ConditionCalledByGroup)(pkBad.PublicKey())
   340  		var cGood = (*ConditionCalledByGroup)(pkGood.PublicKey())
   341  		var cNeutral = (*ConditionCalledByGroup)(pkNeutral.PublicKey())
   342  
   343  		res, err := cGood.Match(tmc)
   344  		require.NoError(t, err)
   345  		require.False(t, res)
   346  
   347  		res, err = cNeutral.Match(tmc)
   348  		require.NoError(t, err)
   349  		require.True(t, res)
   350  
   351  		_, err = cBad.Match(tmc)
   352  		require.Error(t, err)
   353  	})
   354  }
   355  
   356  func TestWitnessConditionCopy(t *testing.T) {
   357  	var someBool = true
   358  	boolCondition := (*ConditionBoolean)(&someBool)
   359  	pk, err := keys.NewPrivateKey()
   360  	require.NoError(t, err)
   361  
   362  	conditions := []WitnessCondition{
   363  		boolCondition,
   364  		&ConditionNot{Condition: boolCondition},
   365  		&ConditionAnd{boolCondition, boolCondition},
   366  		&ConditionOr{boolCondition, boolCondition},
   367  		&ConditionScriptHash{1, 2, 3},
   368  		(*ConditionGroup)(pk.PublicKey()),
   369  		ConditionCalledByEntry{},
   370  		&ConditionCalledByContract{1, 2, 3},
   371  		(*ConditionCalledByGroup)(pk.PublicKey()),
   372  		&ConditionNot{Condition: &ConditionNot{Condition: &ConditionNot{Condition: boolCondition}}},
   373  	}
   374  	for _, cond := range conditions {
   375  		copied := cond.Copy()
   376  		require.Equal(t, cond, copied)
   377  
   378  		switch c := copied.(type) {
   379  		case *ConditionBoolean:
   380  			require.NotSame(t, c, cond.(*ConditionBoolean))
   381  		case *ConditionScriptHash:
   382  			c[0]++
   383  			require.NotEqual(t, c, cond.(*ConditionScriptHash))
   384  		case *ConditionGroup:
   385  			c = (*ConditionGroup)(pk.PublicKey())
   386  			require.NotSame(t, c, cond.(*ConditionGroup))
   387  			newPk, _ := keys.NewPrivateKey()
   388  			copied = (*ConditionGroup)(newPk.PublicKey())
   389  			require.NotEqual(t, copied, cond)
   390  		case *ConditionCalledByContract:
   391  			c[0]++
   392  			require.NotEqual(t, c, cond.(*ConditionCalledByContract))
   393  		case *ConditionCalledByGroup:
   394  			c = (*ConditionCalledByGroup)(pk.PublicKey())
   395  			require.NotSame(t, c, cond.(*ConditionCalledByGroup))
   396  			newPk, _ := keys.NewPrivateKey()
   397  			copied = (*ConditionCalledByGroup)(newPk.PublicKey())
   398  			require.NotEqual(t, copied, cond)
   399  		case *ConditionNot:
   400  			require.NotSame(t, copied, cond)
   401  			copied.(*ConditionNot).Condition = ConditionCalledByEntry{}
   402  			require.NotEqual(t, copied, cond)
   403  		case *ConditionAnd:
   404  			require.NotSame(t, copied, cond)
   405  			(*(copied.(*ConditionAnd)))[0] = ConditionCalledByEntry{}
   406  			require.NotEqual(t, copied, cond)
   407  		case *ConditionOr:
   408  			require.NotSame(t, copied, cond)
   409  			(*(copied.(*ConditionOr)))[0] = ConditionCalledByEntry{}
   410  			require.NotEqual(t, copied, cond)
   411  		case *ConditionCalledByEntry:
   412  			require.NotSame(t, copied, cond)
   413  			copied = ConditionCalledByEntry{}
   414  			require.NotEqual(t, copied, cond)
   415  		}
   416  	}
   417  }