istio.io/istio@v0.0.0-20240520182934-d79c90f27776/pkg/ledger/smt_test.go (about)

     1  // Copyright 2019 Istio Authors
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package ledger
    16  
    17  import (
    18  	"bytes"
    19  	"crypto/rand"
    20  	"fmt"
    21  	"runtime"
    22  	"sort"
    23  	"strings"
    24  	"testing"
    25  	"time"
    26  
    27  	"istio.io/istio/pkg/cache"
    28  	"istio.io/istio/pkg/test/util/assert"
    29  )
    30  
    31  func TestSmtEmptyTrie(t *testing.T) {
    32  	smt := newSMT(hasher, nil, time.Minute)
    33  	if !bytes.Equal([]byte{}, smt.root) {
    34  		t.Fatal("empty trie root hash not correct")
    35  	}
    36  }
    37  
    38  func TestSmtUpdateAndGet(t *testing.T) {
    39  	smt := newSMT(hasher, nil, time.Minute)
    40  	smt.atomicUpdate = false
    41  
    42  	// Add data to empty trie
    43  	keys := getFreshData(10)
    44  	values := getFreshData(10)
    45  	ch := make(chan result, 1)
    46  	smt.update(smt.root, keys, values, nil, 0, smt.trieHeight, false, true, ch)
    47  	res := <-ch
    48  	root := res.update
    49  
    50  	// Check all keys have been stored
    51  	for i, key := range keys {
    52  		value, _ := smt.get(root, key, nil, 0, smt.trieHeight)
    53  		if !bytes.Equal(values[i], value) {
    54  			t.Fatal("value not updated")
    55  		}
    56  	}
    57  
    58  	// Append to the trie
    59  	newKeys := getFreshData(5)
    60  	newValues := getFreshData(5)
    61  	ch = make(chan result, 1)
    62  	smt.update(root, newKeys, newValues, nil, 0, smt.trieHeight, false, true, ch)
    63  	res = <-ch
    64  	newRoot := res.update
    65  	if bytes.Equal(root, newRoot) {
    66  		t.Fatal("trie not updated")
    67  	}
    68  	for i, newKey := range newKeys {
    69  		newValue, _ := smt.get(newRoot, newKey, nil, 0, smt.trieHeight)
    70  		if !bytes.Equal(newValues[i], newValue) {
    71  			t.Fatal("failed to get value")
    72  		}
    73  	}
    74  	// Check old keys are still stored
    75  	for i, key := range keys {
    76  		value, _ := smt.get(newRoot, key, nil, 0, smt.trieHeight)
    77  		if !bytes.Equal(values[i], value) {
    78  			t.Fatal("failed to get value")
    79  		}
    80  	}
    81  }
    82  
    83  func TestTrieAtomicUpdate(t *testing.T) {
    84  	smt := newSMT(hasher, nil, time.Minute)
    85  	keys := getFreshData(10)
    86  	values := getFreshData(10)
    87  	root, _ := smt.Update(keys, values)
    88  
    89  	// check keys of previous atomic update are accessible in
    90  	// updated nodes with root.
    91  	smt.atomicUpdate = false
    92  	for i, key := range keys {
    93  		value, _ := smt.get(root, key, nil, 0, smt.trieHeight)
    94  		if !bytes.Equal(values[i], value) {
    95  			t.Fatal("failed to get value")
    96  		}
    97  	}
    98  }
    99  
   100  func TestSmtPublicUpdateAndGet(t *testing.T) {
   101  	smt := newSMT(hasher, nil, time.Minute)
   102  	// Add data to empty trie
   103  	keys := getFreshData(5)
   104  	values := getFreshData(5)
   105  	root, _ := smt.Update(keys, values)
   106  
   107  	// Check all keys have been stored
   108  	for i, key := range keys {
   109  		value, _ := smt.Get(key)
   110  		if !bytes.Equal(values[i], value) {
   111  			t.Fatal("trie not updated")
   112  		}
   113  	}
   114  	if !bytes.Equal(root, smt.root) {
   115  		t.Fatal("root not stored")
   116  	}
   117  
   118  	newValues := getFreshData(5)
   119  	_, err := smt.Update(keys, newValues)
   120  	assert.NoError(t, err)
   121  
   122  	// Check all keys have been modified
   123  	for i, key := range keys {
   124  		value, _ := smt.Get(key)
   125  		if !bytes.Equal(newValues[i], value) {
   126  			t.Fatal("trie not updated")
   127  		}
   128  	}
   129  
   130  	newKeys := getFreshData(5)
   131  	newValues = getFreshData(5)
   132  	_, err = smt.Update(newKeys, newValues)
   133  	assert.NoError(t, err)
   134  	for i, key := range newKeys {
   135  		value, _ := smt.Get(key)
   136  		if !bytes.Equal(newValues[i], value) {
   137  			t.Fatal("trie not updated")
   138  		}
   139  	}
   140  }
   141  
   142  func TestSmtDelete(t *testing.T) {
   143  	smt := newSMT(hasher, nil, time.Minute)
   144  	// Add data to empty trie
   145  	keys := getFreshData(10)
   146  	values := getFreshData(10)
   147  	ch := make(chan result, 1)
   148  	smt.update(smt.root, keys, values, nil, 0, smt.trieHeight, false, true, ch)
   149  	res := <-ch
   150  	root := res.update
   151  	value, _ := smt.get(root, keys[0], nil, 0, smt.trieHeight)
   152  	if !bytes.Equal(values[0], value) {
   153  		t.Fatal("trie not updated")
   154  	}
   155  
   156  	// Delete from trie
   157  	// To delete a key, just set it's value to Default leaf hash.
   158  	ch = make(chan result, 1)
   159  	smt.update(root, keys[0:1], [][]byte{defaultLeaf}, nil, 0, smt.trieHeight, false, true, ch)
   160  	res = <-ch
   161  	newRoot := res.update
   162  	newValue, _ := smt.get(newRoot, keys[0], nil, 0, smt.trieHeight)
   163  	if len(newValue) != 0 {
   164  		t.Fatal("Failed to delete from trie")
   165  	}
   166  	// Remove deleted key from keys and check root with a clean trie.
   167  	smt2 := newSMT(hasher, nil, time.Minute)
   168  	ch = make(chan result, 1)
   169  	smt2.update(smt2.root, keys[1:], values[1:], nil, 0, smt.trieHeight, false, true, ch)
   170  	res = <-ch
   171  	cleanRoot := res.update
   172  	if !bytes.Equal(newRoot, cleanRoot) {
   173  		t.Fatal("roots mismatch")
   174  	}
   175  
   176  	// Empty the trie
   177  	var newValues [][]byte
   178  	for i := 0; i < 10; i++ {
   179  		newValues = append(newValues, defaultLeaf)
   180  	}
   181  	ch = make(chan result, 1)
   182  	smt.update(root, keys, newValues, nil, 0, smt.trieHeight, false, true, ch)
   183  	res = <-ch
   184  	root = res.update
   185  	if len(root) != 0 {
   186  		t.Fatal("empty trie root hash not correct")
   187  	}
   188  	// Test deleting an already empty key
   189  	smt = newSMT(hasher, nil, time.Minute)
   190  	keys = getFreshData(2)
   191  	values = getFreshData(2)
   192  	root, _ = smt.Update(keys, values)
   193  	key0 := make([]byte, 8)
   194  	key1 := make([]byte, 8)
   195  	_, err := smt.Update([][]byte{key0, key1}, [][]byte{defaultLeaf, defaultLeaf})
   196  	assert.NoError(t, err)
   197  	if !bytes.Equal(root, smt.root) {
   198  		t.Fatal("deleting a default key shouldn't modify the tree")
   199  	}
   200  }
   201  
   202  // test updating and deleting at the same time
   203  func TestTrieUpdateAndDelete(t *testing.T) {
   204  	smt := newSMT(hasher, nil, time.Minute)
   205  	key0 := make([]byte, 8)
   206  	values := getFreshData(1)
   207  	root, _ := smt.Update([][]byte{key0}, values)
   208  	smt.atomicUpdate = false
   209  	_, _, k, v, isShortcut, _ := smt.loadChildren(root, smt.trieHeight, 0, nil)
   210  	if !isShortcut || !bytes.Equal(k[:hashLength], key0) || !bytes.Equal(v[:hashLength], values[0]) {
   211  		t.Fatal("leaf shortcut didn't move up to root")
   212  	}
   213  
   214  	key1 := make([]byte, 8)
   215  	// set the last bit
   216  	bitSet(key1, 63)
   217  	keys := [][]byte{key0, key1}
   218  	values = [][]byte{defaultLeaf, getFreshData(1)[0]}
   219  	_, err := smt.Update(keys, values)
   220  	assert.NoError(t, err)
   221  }
   222  
   223  func bitSet(bits []byte, i int) {
   224  	bits[i/8] |= 1 << uint(7-i%8)
   225  }
   226  
   227  func TestSmtRaisesError(t *testing.T) {
   228  	smt := newSMT(hasher, nil, time.Minute)
   229  	// Add data to empty trie
   230  	keys := getFreshData(10)
   231  	values := getFreshData(10)
   232  	_, err := smt.Update(keys, values)
   233  	assert.NoError(t, err)
   234  	smt.db.updatedNodes = byteCache{cache: cache.NewTTL(forever, time.Minute)}
   235  	smt.loadDefaultHashes()
   236  
   237  	// Check errors are raised is a keys is not in cache nor db
   238  	for _, key := range keys {
   239  		_, err := smt.Get(key)
   240  		assert.Error(t, err)
   241  		assert.Equal(t, strings.Contains(err.Error(), "is unavailable in the disk db"), true,
   242  			"Error not created if database doesn't have a node")
   243  	}
   244  }
   245  
   246  // nolint: gosec
   247  // test only code
   248  func getFreshData(size int) [][]byte {
   249  	length := 8
   250  	var data [][]byte
   251  	for i := 0; i < size; i++ {
   252  		key := make([]byte, 8)
   253  		_, err := rand.Read(key)
   254  		if err != nil {
   255  			panic(err)
   256  		}
   257  		data = append(data, hasher(key)[:length])
   258  	}
   259  	sort.Sort(dataArray(data))
   260  	return data
   261  }
   262  
   263  func benchmark10MAccounts10Ktps(smt *smt, b *testing.B) {
   264  	fmt.Println("\nLoading b.N x 1000 accounts")
   265  	for index := 0; index < b.N; index++ {
   266  		newkeys := getFreshData(1000)
   267  		newvalues := getFreshData(1000)
   268  		start := time.Now()
   269  		smt.Update(newkeys, newvalues)
   270  		end := time.Now()
   271  		end2 := time.Now()
   272  		for i, key := range newkeys {
   273  			val, _ := smt.Get(key)
   274  			if !bytes.Equal(val, newvalues[i]) {
   275  				b.Fatal("new key not included")
   276  			}
   277  		}
   278  		end3 := time.Now()
   279  		elapsed := end.Sub(start)
   280  		elapsed2 := end2.Sub(end)
   281  		elapsed3 := end3.Sub(end2)
   282  		var m runtime.MemStats
   283  		runtime.ReadMemStats(&m)
   284  		fmt.Println(index, " : update time : ", elapsed, "commit time : ", elapsed2,
   285  			"\n1000 Get time : ", elapsed3,
   286  			"\nRAM : ", m.Sys/1024/1024, " MiB")
   287  	}
   288  }
   289  
   290  // go test -run=xxx -bench=. -benchmem -test.benchtime=20s
   291  func BenchmarkCacheHeightLimit(b *testing.B) {
   292  	smt := newSMT(hasher, cache.NewTTL(forever, time.Minute), time.Minute)
   293  	benchmark10MAccounts10Ktps(smt, b)
   294  }