github.com/onflow/atree@v0.6.0/utils_test.go (about)

     1  /*
     2   * Atree - Scalable Arrays and Ordered Maps
     3   *
     4   * Copyright 2021 Dapper Labs, Inc.
     5   *
     6   * Licensed under the Apache License, Version 2.0 (the "License");
     7   * you may not use this file except in compliance with the License.
     8   * You may obtain a copy of the License at
     9   *
    10   *   http://www.apache.org/licenses/LICENSE-2.0
    11   *
    12   * Unless required by applicable law or agreed to in writing, software
    13   * distributed under the License is distributed on an "AS IS" BASIS,
    14   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    15   * See the License for the specific language governing permissions and
    16   * limitations under the License.
    17   */
    18  
    19  package atree
    20  
    21  import (
    22  	"flag"
    23  	"math/rand"
    24  	"testing"
    25  	"time"
    26  
    27  	"github.com/fxamacker/cbor/v2"
    28  	"github.com/stretchr/testify/require"
    29  )
    30  
    31  var (
    32  	runes = []rune("ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_")
    33  )
    34  
    35  var seed = flag.Int64("seed", 0, "seed for pseudo-random source")
    36  
    37  func newRand(tb testing.TB) *rand.Rand {
    38  	if *seed == 0 {
    39  		*seed = time.Now().UnixNano()
    40  	}
    41  
    42  	// Benchmarks always log, so only log for tests which
    43  	// will only log with -v flag or on error.
    44  	if t, ok := tb.(*testing.T); ok {
    45  		t.Logf("seed: %d\n", *seed)
    46  	}
    47  
    48  	return rand.New(rand.NewSource(*seed))
    49  }
    50  
    51  // randStr returns random UTF-8 string of given length.
    52  func randStr(r *rand.Rand, length int) string {
    53  	b := make([]rune, length)
    54  	for i := 0; i < length; i++ {
    55  		b[i] = runes[r.Intn(len(runes))]
    56  	}
    57  	return string(b)
    58  }
    59  
    60  func randomValue(r *rand.Rand, maxInlineSize int) Value {
    61  	switch r.Intn(6) {
    62  
    63  	case 0:
    64  		return Uint8Value(r.Intn(255))
    65  
    66  	case 1:
    67  		return Uint16Value(r.Intn(6535))
    68  
    69  	case 2:
    70  		return Uint32Value(r.Intn(4294967295))
    71  
    72  	case 3:
    73  		return Uint64Value(r.Intn(1844674407370955161))
    74  
    75  	case 4: // small string (inlinable)
    76  		slen := r.Intn(maxInlineSize)
    77  		return NewStringValue(randStr(r, slen))
    78  
    79  	case 5: // large string (external)
    80  		slen := r.Intn(1024) + maxInlineSize
    81  		return NewStringValue(randStr(r, slen))
    82  
    83  	default:
    84  		panic(NewUnreachableError())
    85  	}
    86  }
    87  
    88  type testTypeInfo struct {
    89  	value uint64
    90  }
    91  
    92  var _ TypeInfo = testTypeInfo{}
    93  
    94  func (i testTypeInfo) Encode(enc *cbor.StreamEncoder) error {
    95  	return enc.EncodeUint64(i.value)
    96  }
    97  
    98  func (i testTypeInfo) Equal(other TypeInfo) bool {
    99  	otherTestTypeInfo, ok := other.(testTypeInfo)
   100  	return ok && i.value == otherTestTypeInfo.value
   101  }
   102  
   103  func typeInfoComparator(a, b TypeInfo) bool {
   104  	x, ok := a.(testTypeInfo)
   105  	if !ok {
   106  		return false
   107  	}
   108  	y, ok := b.(testTypeInfo)
   109  	return ok && x.value == y.value
   110  }
   111  
   112  func newTestPersistentStorage(t testing.TB) *PersistentSlabStorage {
   113  	baseStorage := NewInMemBaseStorage()
   114  
   115  	encMode, err := cbor.EncOptions{}.EncMode()
   116  	require.NoError(t, err)
   117  
   118  	decMode, err := cbor.DecOptions{}.DecMode()
   119  	require.NoError(t, err)
   120  
   121  	return NewPersistentSlabStorage(
   122  		baseStorage,
   123  		encMode,
   124  		decMode,
   125  		decodeStorable,
   126  		decodeTypeInfo,
   127  	)
   128  }
   129  
   130  func newTestPersistentStorageWithData(t testing.TB, data map[StorageID][]byte) *PersistentSlabStorage {
   131  	baseStorage := NewInMemBaseStorage()
   132  	baseStorage.segments = data
   133  	return newTestPersistentStorageWithBaseStorage(t, baseStorage)
   134  }
   135  
   136  func newTestPersistentStorageWithBaseStorage(t testing.TB, baseStorage BaseStorage) *PersistentSlabStorage {
   137  
   138  	encMode, err := cbor.EncOptions{}.EncMode()
   139  	require.NoError(t, err)
   140  
   141  	decMode, err := cbor.DecOptions{}.DecMode()
   142  	require.NoError(t, err)
   143  
   144  	return NewPersistentSlabStorage(
   145  		baseStorage,
   146  		encMode,
   147  		decMode,
   148  		decodeStorable,
   149  		decodeTypeInfo,
   150  	)
   151  }
   152  
   153  func newTestBasicStorage(t testing.TB) *BasicSlabStorage {
   154  	encMode, err := cbor.EncOptions{}.EncMode()
   155  	require.NoError(t, err)
   156  
   157  	decMode, err := cbor.DecOptions{}.DecMode()
   158  	require.NoError(t, err)
   159  
   160  	return NewBasicSlabStorage(
   161  		encMode,
   162  		decMode,
   163  		decodeStorable,
   164  		decodeTypeInfo,
   165  	)
   166  }
   167  
   168  type InMemBaseStorage struct {
   169  	segments         map[StorageID][]byte
   170  	storageIndex     map[Address]StorageIndex
   171  	bytesRetrieved   int
   172  	bytesStored      int
   173  	segmentsReturned map[StorageID]struct{}
   174  	segmentsUpdated  map[StorageID]struct{}
   175  	segmentsTouched  map[StorageID]struct{}
   176  }
   177  
   178  var _ BaseStorage = &InMemBaseStorage{}
   179  
   180  func NewInMemBaseStorage() *InMemBaseStorage {
   181  	return NewInMemBaseStorageFromMap(
   182  		make(map[StorageID][]byte),
   183  	)
   184  }
   185  
   186  func NewInMemBaseStorageFromMap(segments map[StorageID][]byte) *InMemBaseStorage {
   187  	return &InMemBaseStorage{
   188  		segments:         segments,
   189  		storageIndex:     make(map[Address]StorageIndex),
   190  		segmentsReturned: make(map[StorageID]struct{}),
   191  		segmentsUpdated:  make(map[StorageID]struct{}),
   192  		segmentsTouched:  make(map[StorageID]struct{}),
   193  	}
   194  }
   195  
   196  func (s *InMemBaseStorage) Retrieve(id StorageID) ([]byte, bool, error) {
   197  	seg, ok := s.segments[id]
   198  	s.bytesRetrieved += len(seg)
   199  	s.segmentsReturned[id] = struct{}{}
   200  	s.segmentsTouched[id] = struct{}{}
   201  	return seg, ok, nil
   202  }
   203  
   204  func (s *InMemBaseStorage) Store(id StorageID, data []byte) error {
   205  	s.segments[id] = data
   206  	s.bytesStored += len(data)
   207  	s.segmentsUpdated[id] = struct{}{}
   208  	s.segmentsTouched[id] = struct{}{}
   209  	return nil
   210  }
   211  
   212  func (s *InMemBaseStorage) Remove(id StorageID) error {
   213  	s.segmentsUpdated[id] = struct{}{}
   214  	s.segmentsTouched[id] = struct{}{}
   215  	delete(s.segments, id)
   216  	return nil
   217  }
   218  
   219  func (s *InMemBaseStorage) GenerateStorageID(address Address) (StorageID, error) {
   220  	index := s.storageIndex[address]
   221  	nextIndex := index.Next()
   222  
   223  	s.storageIndex[address] = nextIndex
   224  	return NewStorageID(address, nextIndex), nil
   225  }
   226  
   227  func (s *InMemBaseStorage) SegmentCounts() int {
   228  	return len(s.segments)
   229  }
   230  
   231  func (s *InMemBaseStorage) Size() int {
   232  	total := 0
   233  	for _, seg := range s.segments {
   234  		total += len(seg)
   235  	}
   236  	return total
   237  }
   238  
   239  func (s *InMemBaseStorage) BytesRetrieved() int {
   240  	return s.bytesRetrieved
   241  }
   242  
   243  func (s *InMemBaseStorage) BytesStored() int {
   244  	return s.bytesStored
   245  }
   246  
   247  func (s *InMemBaseStorage) SegmentsReturned() int {
   248  	return len(s.segmentsReturned)
   249  }
   250  
   251  func (s *InMemBaseStorage) SegmentsUpdated() int {
   252  	return len(s.segmentsUpdated)
   253  }
   254  
   255  func (s *InMemBaseStorage) SegmentsTouched() int {
   256  	return len(s.segmentsTouched)
   257  }
   258  
   259  func (s *InMemBaseStorage) ResetReporter() {
   260  	s.bytesStored = 0
   261  	s.bytesRetrieved = 0
   262  	s.segmentsReturned = make(map[StorageID]struct{})
   263  	s.segmentsUpdated = make(map[StorageID]struct{})
   264  	s.segmentsTouched = make(map[StorageID]struct{})
   265  }
   266  
   267  func valueEqual(t *testing.T, tic TypeInfoComparator, a Value, b Value) {
   268  	switch a.(type) {
   269  	case *Array:
   270  		arrayEqual(t, tic, a, b)
   271  	case *OrderedMap:
   272  		mapEqual(t, tic, a, b)
   273  	default:
   274  		require.Equal(t, a, b)
   275  	}
   276  }
   277  
   278  func arrayEqual(t *testing.T, tic TypeInfoComparator, a Value, b Value) {
   279  	array1, ok := a.(*Array)
   280  	require.True(t, ok)
   281  
   282  	array2, ok := b.(*Array)
   283  	require.True(t, ok)
   284  
   285  	require.True(t, tic(array1.Type(), array2.Type()))
   286  	require.Equal(t, array1.Address(), array2.Address())
   287  	require.Equal(t, array1.Count(), array2.Count())
   288  	require.Equal(t, array1.StorageID(), array2.StorageID())
   289  
   290  	iterator1, err := array1.Iterator()
   291  	require.NoError(t, err)
   292  
   293  	iterator2, err := array2.Iterator()
   294  	require.NoError(t, err)
   295  
   296  	for {
   297  		value1, err := iterator1.Next()
   298  		require.NoError(t, err)
   299  
   300  		value2, err := iterator2.Next()
   301  		require.NoError(t, err)
   302  
   303  		valueEqual(t, tic, value1, value2)
   304  
   305  		if value1 == nil || value2 == nil {
   306  			break
   307  		}
   308  	}
   309  }
   310  
   311  func mapEqual(t *testing.T, tic TypeInfoComparator, a Value, b Value) {
   312  	m1, ok := a.(*OrderedMap)
   313  	require.True(t, ok)
   314  
   315  	m2, ok := b.(*OrderedMap)
   316  	require.True(t, ok)
   317  
   318  	require.True(t, tic(m1.Type(), m2.Type()))
   319  	require.Equal(t, m1.Address(), m2.Address())
   320  	require.Equal(t, m1.Count(), m2.Count())
   321  	require.Equal(t, m1.StorageID(), m2.StorageID())
   322  
   323  	iterator1, err := m1.Iterator()
   324  	require.NoError(t, err)
   325  
   326  	iterator2, err := m2.Iterator()
   327  	require.NoError(t, err)
   328  
   329  	for {
   330  		key1, value1, err := iterator1.Next()
   331  		require.NoError(t, err)
   332  
   333  		key2, value2, err := iterator2.Next()
   334  		require.NoError(t, err)
   335  
   336  		valueEqual(t, tic, key1, key2)
   337  		valueEqual(t, tic, value1, value2)
   338  
   339  		if key1 == nil || key2 == nil {
   340  			break
   341  		}
   342  	}
   343  }