github.com/onflow/atree@v0.6.0/storable_test.go (about)

     1  /*
     2   * Atree - Scalable Arrays and Ordered Maps
     3   *
     4   * Copyright 2021 Dapper Labs, Inc.
     5   *
     6   * Licensed under the Apache License, Version 2.0 (the "License");
     7   * you may not use this file except in compliance with the License.
     8   * You may obtain a copy of the License at
     9   *
    10   *   http://www.apache.org/licenses/LICENSE-2.0
    11   *
    12   * Unless required by applicable law or agreed to in writing, software
    13   * distributed under the License is distributed on an "AS IS" BASIS,
    14   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    15   * See the License for the specific language governing permissions and
    16   * limitations under the License.
    17   */
    18  
    19  package atree
    20  
    21  import (
    22  	"encoding/binary"
    23  	"fmt"
    24  	"math"
    25  
    26  	"github.com/fxamacker/cbor/v2"
    27  )
    28  
    29  // This file contains value implementations for testing purposes
    30  
    31  const (
    32  	cborTagUInt8Value  = 161
    33  	cborTagUInt16Value = 162
    34  	cborTagUInt32Value = 163
    35  	cborTagUInt64Value = 164
    36  	cborTagSomeValue   = 165
    37  )
    38  
// HashableValue is a Value that can additionally produce the byte sequence
// used as its hash input (see hashInputProvider).
type HashableValue interface {
	Value
	// HashInput returns the hash input bytes for the value. Implementations
	// may write the result into scratch when it is large enough.
	HashInput(scratch []byte) ([]byte, error)
}
    43  
    44  type Uint8Value uint8
    45  
    46  var _ Value = Uint8Value(0)
    47  var _ Storable = Uint8Value(0)
    48  var _ HashableValue = Uint8Value(0)
    49  
    50  func (v Uint8Value) ChildStorables() []Storable { return nil }
    51  
    52  func (v Uint8Value) StoredValue(_ SlabStorage) (Value, error) {
    53  	return v, nil
    54  }
    55  
    56  func (v Uint8Value) Storable(_ SlabStorage, _ Address, _ uint64) (Storable, error) {
    57  	return v, nil
    58  }
    59  
    60  // Encode encodes UInt8Value as
    61  //
    62  //	cbor.Tag{
    63  //			Number:  cborTagUInt8Value,
    64  //			Content: uint8(v),
    65  //	}
    66  func (v Uint8Value) Encode(enc *Encoder) error {
    67  	err := enc.CBOR.EncodeRawBytes([]byte{
    68  		// tag number
    69  		0xd8, cborTagUInt8Value,
    70  	})
    71  	if err != nil {
    72  		return err
    73  	}
    74  	return enc.CBOR.EncodeUint8(uint8(v))
    75  }
    76  
    77  func (v Uint8Value) HashInput(scratch []byte) ([]byte, error) {
    78  
    79  	const cborTypePositiveInt = 0x00
    80  
    81  	buf := scratch
    82  	if len(scratch) < 4 {
    83  		buf = make([]byte, 4)
    84  	}
    85  
    86  	buf[0], buf[1] = 0xd8, cborTagUInt8Value // Tag number
    87  
    88  	if v <= 23 {
    89  		buf[2] = cborTypePositiveInt | byte(v)
    90  		return buf[:3], nil
    91  	}
    92  
    93  	buf[2] = cborTypePositiveInt | byte(24)
    94  	buf[3] = byte(v)
    95  	return buf[:4], nil
    96  }
    97  
    98  // TODO: cache size
    99  func (v Uint8Value) ByteSize() uint32 {
   100  	// tag number (2 bytes) + encoded content
   101  	return 2 + GetUintCBORSize(uint64(v))
   102  }
   103  
   104  func (v Uint8Value) String() string {
   105  	return fmt.Sprintf("%d", uint8(v))
   106  }
   107  
   108  type Uint16Value uint16
   109  
   110  var _ Value = Uint16Value(0)
   111  var _ Storable = Uint16Value(0)
   112  var _ HashableValue = Uint16Value(0)
   113  
   114  func (v Uint16Value) ChildStorables() []Storable { return nil }
   115  
   116  func (v Uint16Value) StoredValue(_ SlabStorage) (Value, error) {
   117  	return v, nil
   118  }
   119  
   120  func (v Uint16Value) Storable(_ SlabStorage, _ Address, _ uint64) (Storable, error) {
   121  	return v, nil
   122  }
   123  
   124  func (v Uint16Value) Encode(enc *Encoder) error {
   125  	err := enc.CBOR.EncodeRawBytes([]byte{
   126  		// tag number
   127  		0xd8, cborTagUInt16Value,
   128  	})
   129  	if err != nil {
   130  		return err
   131  	}
   132  	return enc.CBOR.EncodeUint16(uint16(v))
   133  }
   134  
   135  func (v Uint16Value) HashInput(scratch []byte) ([]byte, error) {
   136  	const cborTypePositiveInt = 0x00
   137  
   138  	buf := scratch
   139  	if len(buf) < 8 {
   140  		buf = make([]byte, 8)
   141  	}
   142  
   143  	buf[0], buf[1] = 0xd8, cborTagUInt16Value // Tag number
   144  
   145  	if v <= 23 {
   146  		buf[2] = cborTypePositiveInt | byte(v)
   147  		return buf[:3], nil
   148  	}
   149  
   150  	if v <= math.MaxUint8 {
   151  		buf[2] = cborTypePositiveInt | byte(24)
   152  		buf[3] = byte(v)
   153  		return buf[:4], nil
   154  	}
   155  
   156  	buf[2] = cborTypePositiveInt | byte(25)
   157  	binary.BigEndian.PutUint16(buf[3:], uint16(v))
   158  	return buf[:5], nil
   159  }
   160  
   161  // TODO: cache size
   162  func (v Uint16Value) ByteSize() uint32 {
   163  	// tag number (2 bytes) + encoded content
   164  	return 2 + GetUintCBORSize(uint64(v))
   165  }
   166  
   167  func (v Uint16Value) String() string {
   168  	return fmt.Sprintf("%d", uint16(v))
   169  }
   170  
   171  type Uint32Value uint32
   172  
   173  var _ Value = Uint32Value(0)
   174  var _ Storable = Uint32Value(0)
   175  var _ HashableValue = Uint32Value(0)
   176  
   177  func (v Uint32Value) ChildStorables() []Storable { return nil }
   178  
   179  func (v Uint32Value) StoredValue(_ SlabStorage) (Value, error) {
   180  	return v, nil
   181  }
   182  
   183  func (v Uint32Value) Storable(_ SlabStorage, _ Address, _ uint64) (Storable, error) {
   184  	return v, nil
   185  }
   186  
   187  // Encode encodes UInt32Value as
   188  //
   189  //	cbor.Tag{
   190  //			Number:  cborTagUInt32Value,
   191  //			Content: uint32(v),
   192  //	}
   193  func (v Uint32Value) Encode(enc *Encoder) error {
   194  	err := enc.CBOR.EncodeRawBytes([]byte{
   195  		// tag number
   196  		0xd8, cborTagUInt32Value,
   197  	})
   198  	if err != nil {
   199  		return err
   200  	}
   201  	return enc.CBOR.EncodeUint32(uint32(v))
   202  }
   203  
   204  func (v Uint32Value) HashInput(scratch []byte) ([]byte, error) {
   205  
   206  	const cborTypePositiveInt = 0x00
   207  
   208  	buf := scratch
   209  	if len(buf) < 8 {
   210  		buf = make([]byte, 8)
   211  	}
   212  
   213  	buf[0], buf[1] = 0xd8, cborTagUInt32Value // Tag number
   214  
   215  	if v <= 23 {
   216  		buf[2] = cborTypePositiveInt | byte(v)
   217  		return buf[:3], nil
   218  	}
   219  
   220  	if v <= math.MaxUint8 {
   221  		buf[2] = cborTypePositiveInt | byte(24)
   222  		buf[3] = byte(v)
   223  		return buf[:4], nil
   224  	}
   225  
   226  	if v <= math.MaxUint16 {
   227  		buf[2] = cborTypePositiveInt | byte(25)
   228  		binary.BigEndian.PutUint16(buf[3:], uint16(v))
   229  		return buf[:5], nil
   230  	}
   231  
   232  	buf[2] = cborTypePositiveInt | byte(26)
   233  	binary.BigEndian.PutUint32(buf[3:], uint32(v))
   234  	return buf[:7], nil
   235  }
   236  
   237  // TODO: cache size
   238  func (v Uint32Value) ByteSize() uint32 {
   239  	// tag number (2 bytes) + encoded content
   240  	return 2 + GetUintCBORSize(uint64(v))
   241  }
   242  
   243  func (v Uint32Value) String() string {
   244  	return fmt.Sprintf("%d", uint32(v))
   245  }
   246  
   247  type Uint64Value uint64
   248  
   249  var _ Value = Uint64Value(0)
   250  var _ Storable = Uint64Value(0)
   251  var _ HashableValue = Uint64Value(0)
   252  
   253  func (v Uint64Value) ChildStorables() []Storable { return nil }
   254  
   255  func (v Uint64Value) StoredValue(_ SlabStorage) (Value, error) {
   256  	return v, nil
   257  }
   258  
   259  func (v Uint64Value) Storable(_ SlabStorage, _ Address, _ uint64) (Storable, error) {
   260  	return v, nil
   261  }
   262  
   263  // Encode encodes UInt64Value as
   264  //
   265  //	cbor.Tag{
   266  //			Number:  cborTagUInt64Value,
   267  //			Content: uint64(v),
   268  //	}
   269  func (v Uint64Value) Encode(enc *Encoder) error {
   270  	err := enc.CBOR.EncodeRawBytes([]byte{
   271  		// tag number
   272  		0xd8, cborTagUInt64Value,
   273  	})
   274  	if err != nil {
   275  		return err
   276  	}
   277  	return enc.CBOR.EncodeUint64(uint64(v))
   278  }
   279  
   280  func (v Uint64Value) HashInput(scratch []byte) ([]byte, error) {
   281  	const cborTypePositiveInt = 0x00
   282  
   283  	buf := scratch
   284  	if len(buf) < 16 {
   285  		buf = make([]byte, 16)
   286  	}
   287  
   288  	buf[0], buf[1] = 0xd8, cborTagUInt64Value // Tag number
   289  
   290  	if v <= 23 {
   291  		buf[2] = cborTypePositiveInt | byte(v)
   292  		return buf[:3], nil
   293  	}
   294  
   295  	if v <= math.MaxUint8 {
   296  		buf[2] = cborTypePositiveInt | byte(24)
   297  		buf[3] = byte(v)
   298  		return buf[:4], nil
   299  	}
   300  
   301  	if v <= math.MaxUint16 {
   302  		buf[2] = cborTypePositiveInt | byte(25)
   303  		binary.BigEndian.PutUint16(buf[3:], uint16(v))
   304  		return buf[:5], nil
   305  	}
   306  
   307  	if v <= math.MaxUint32 {
   308  		buf[2] = cborTypePositiveInt | byte(26)
   309  		binary.BigEndian.PutUint32(buf[3:], uint32(v))
   310  		return buf[:7], nil
   311  	}
   312  
   313  	buf[2] = cborTypePositiveInt | byte(27)
   314  	binary.BigEndian.PutUint64(buf[3:], uint64(v))
   315  	return buf[:11], nil
   316  }
   317  
   318  // TODO: cache size
   319  func (v Uint64Value) ByteSize() uint32 {
   320  	// tag number (2 bytes) + encoded content
   321  	return 2 + GetUintCBORSize(uint64(v))
   322  }
   323  
   324  func (v Uint64Value) String() string {
   325  	return fmt.Sprintf("%d", uint64(v))
   326  }
   327  
   328  type StringValue struct {
   329  	str  string
   330  	size uint32
   331  }
   332  
   333  var _ Value = StringValue{}
   334  var _ Storable = StringValue{}
   335  var _ HashableValue = StringValue{}
   336  
   337  func NewStringValue(s string) StringValue {
   338  	size := GetUintCBORSize(uint64(len(s))) + uint32(len(s))
   339  	return StringValue{str: s, size: size}
   340  }
   341  
   342  func (v StringValue) ChildStorables() []Storable { return nil }
   343  
   344  func (v StringValue) StoredValue(_ SlabStorage) (Value, error) {
   345  	return v, nil
   346  }
   347  
   348  func (v StringValue) Storable(storage SlabStorage, address Address, maxInlineSize uint64) (Storable, error) {
   349  	if uint64(v.ByteSize()) > maxInlineSize {
   350  
   351  		// Create StorableSlab
   352  		id, err := storage.GenerateStorageID(address)
   353  		if err != nil {
   354  			return nil, err
   355  		}
   356  
   357  		slab := &StorableSlab{
   358  			StorageID: id,
   359  			Storable:  v,
   360  		}
   361  
   362  		// Store StorableSlab in storage
   363  		err = storage.Store(id, slab)
   364  		if err != nil {
   365  			return nil, err
   366  		}
   367  
   368  		// Return storage id as storable
   369  		return StorageIDStorable(id), nil
   370  	}
   371  
   372  	return v, nil
   373  }
   374  
// Encode writes v.str to the encoder's CBOR stream via EncodeString and
// returns whatever error that call reports.
func (v StringValue) Encode(enc *Encoder) error {
	return enc.CBOR.EncodeString(v.str)
}
   378  
   379  func (v StringValue) HashInput(scratch []byte) ([]byte, error) {
   380  
   381  	const cborTypeTextString = 0x60
   382  
   383  	buf := scratch
   384  	if uint32(len(buf)) < v.size {
   385  		buf = make([]byte, v.size)
   386  	} else {
   387  		buf = buf[:v.size]
   388  	}
   389  
   390  	slen := len(v.str)
   391  
   392  	if slen <= 23 {
   393  		buf[0] = cborTypeTextString | byte(slen)
   394  		copy(buf[1:], v.str)
   395  		return buf, nil
   396  	}
   397  
   398  	if slen <= math.MaxUint8 {
   399  		buf[0] = cborTypeTextString | byte(24)
   400  		buf[1] = byte(slen)
   401  		copy(buf[2:], v.str)
   402  		return buf, nil
   403  	}
   404  
   405  	if slen <= math.MaxUint16 {
   406  		buf[0] = cborTypeTextString | byte(25)
   407  		binary.BigEndian.PutUint16(buf[1:], uint16(slen))
   408  		copy(buf[3:], v.str)
   409  		return buf, nil
   410  	}
   411  
   412  	if slen <= math.MaxUint32 {
   413  		buf[0] = cborTypeTextString | byte(26)
   414  		binary.BigEndian.PutUint32(buf[1:], uint32(slen))
   415  		copy(buf[5:], v.str)
   416  		return buf, nil
   417  	}
   418  
   419  	buf[0] = cborTypeTextString | byte(27)
   420  	binary.BigEndian.PutUint64(buf[1:], uint64(slen))
   421  	copy(buf[9:], v.str)
   422  	return buf, nil
   423  }
   424  
// ByteSize returns the size cached in v.size (set by NewStringValue).
func (v StringValue) ByteSize() uint32 {
	return v.size
}

// String returns the wrapped Go string.
func (v StringValue) String() string {
	return v.str
}
   432  
   433  func decodeStorable(dec *cbor.StreamDecoder, id StorageID) (Storable, error) {
   434  	t, err := dec.NextType()
   435  	if err != nil {
   436  		return nil, err
   437  	}
   438  
   439  	switch t {
   440  	case cbor.TextStringType:
   441  		s, err := dec.DecodeString()
   442  		if err != nil {
   443  			return nil, err
   444  		}
   445  		return NewStringValue(s), nil
   446  
   447  	case cbor.TagType:
   448  		tagNumber, err := dec.DecodeTagNumber()
   449  		if err != nil {
   450  			return nil, err
   451  		}
   452  
   453  		switch tagNumber {
   454  		case CBORTagStorageID:
   455  			return DecodeStorageIDStorable(dec)
   456  
   457  		case cborTagUInt8Value:
   458  			n, err := dec.DecodeUint64()
   459  			if err != nil {
   460  				return nil, err
   461  			}
   462  			if n > math.MaxUint8 {
   463  				return nil, fmt.Errorf("invalid data, got %d, expected max %d", n, math.MaxUint8)
   464  			}
   465  			return Uint8Value(n), nil
   466  
   467  		case cborTagUInt16Value:
   468  			n, err := dec.DecodeUint64()
   469  			if err != nil {
   470  				return nil, err
   471  			}
   472  			if n > math.MaxUint16 {
   473  				return nil, fmt.Errorf("invalid data, got %d, expected max %d", n, math.MaxUint16)
   474  			}
   475  			return Uint16Value(n), nil
   476  
   477  		case cborTagUInt32Value:
   478  			n, err := dec.DecodeUint64()
   479  			if err != nil {
   480  				return nil, err
   481  			}
   482  			if n > math.MaxUint32 {
   483  				return nil, fmt.Errorf("invalid data, got %d, expected max %d", n, math.MaxUint32)
   484  			}
   485  			return Uint32Value(n), nil
   486  
   487  		case cborTagUInt64Value:
   488  			n, err := dec.DecodeUint64()
   489  			if err != nil {
   490  				return nil, err
   491  			}
   492  			return Uint64Value(n), nil
   493  
   494  		case cborTagSomeValue:
   495  			storable, err := decodeStorable(dec, id)
   496  			if err != nil {
   497  				return nil, err
   498  			}
   499  			return SomeStorable{Storable: storable}, nil
   500  
   501  		default:
   502  			return nil, fmt.Errorf("invalid tag number %d", tagNumber)
   503  		}
   504  	default:
   505  		return nil, fmt.Errorf("invalid cbor type %s for storable", t)
   506  	}
   507  }
   508  
   509  func decodeTypeInfo(dec *cbor.StreamDecoder) (TypeInfo, error) {
   510  	value, err := dec.DecodeUint64()
   511  	if err != nil {
   512  		return nil, err
   513  	}
   514  
   515  	return testTypeInfo{value: value}, nil
   516  }
   517  
   518  func compare(storage SlabStorage, value Value, storable Storable) (bool, error) {
   519  	switch v := value.(type) {
   520  
   521  	case Uint8Value:
   522  		other, ok := storable.(Uint8Value)
   523  		if !ok {
   524  			return false, nil
   525  		}
   526  		return uint8(other) == uint8(v), nil
   527  
   528  	case Uint16Value:
   529  		other, ok := storable.(Uint16Value)
   530  		if !ok {
   531  			return false, nil
   532  		}
   533  		return uint16(other) == uint16(v), nil
   534  
   535  	case Uint32Value:
   536  		other, ok := storable.(Uint32Value)
   537  		if !ok {
   538  			return false, nil
   539  		}
   540  		return uint32(other) == uint32(v), nil
   541  
   542  	case Uint64Value:
   543  		other, ok := storable.(Uint64Value)
   544  		if !ok {
   545  			return false, nil
   546  		}
   547  		return uint64(other) == uint64(v), nil
   548  
   549  	case StringValue:
   550  		other, ok := storable.(StringValue)
   551  		if ok {
   552  			return other.str == v.str, nil
   553  		}
   554  
   555  		// Retrieve value from storage
   556  		otherValue, err := storable.StoredValue(storage)
   557  		if err != nil {
   558  			return false, err
   559  		}
   560  		other, ok = otherValue.(StringValue)
   561  		if ok {
   562  			return other.str == v.str, nil
   563  		}
   564  
   565  		return false, nil
   566  
   567  	case SomeValue:
   568  		other, ok := storable.(SomeStorable)
   569  		if !ok {
   570  			return false, nil
   571  		}
   572  
   573  		return compare(storage, v.Value, other.Storable)
   574  	}
   575  
   576  	return false, fmt.Errorf("value %T not supported for comparison", value)
   577  }
   578  
   579  func hashInputProvider(value Value, buffer []byte) ([]byte, error) {
   580  	if hashable, ok := value.(HashableValue); ok {
   581  		return hashable.HashInput(buffer)
   582  	}
   583  
   584  	return nil, fmt.Errorf("value %T doesn't implement HashableValue interface", value)
   585  }
   586  
   587  type SomeValue struct {
   588  	Value Value
   589  }
   590  
   591  var _ Value = SomeValue{}
   592  var _ HashableValue = SomeValue{}
   593  
   594  func (v SomeValue) Storable(storage SlabStorage, address Address, maxSize uint64) (Storable, error) {
   595  
   596  	valueStorable, err := v.Value.Storable(
   597  		storage,
   598  		address,
   599  		maxSize-2,
   600  	)
   601  	if err != nil {
   602  		return nil, err
   603  	}
   604  
   605  	return SomeStorable{
   606  		Storable: valueStorable,
   607  	}, nil
   608  }
   609  
   610  func (v SomeValue) HashInput(scratch []byte) ([]byte, error) {
   611  
   612  	wv, ok := v.Value.(HashableValue)
   613  	if !ok {
   614  		return nil, fmt.Errorf("failed to hash wrapped value: %s", v.Value)
   615  	}
   616  
   617  	b, err := wv.HashInput(scratch)
   618  	if err != nil {
   619  		return nil, err
   620  	}
   621  
   622  	hi := make([]byte, len(b)+2)
   623  	hi[0] = 0xd8
   624  	hi[1] = cborTagSomeValue
   625  	copy(hi[2:], b)
   626  
   627  	return hi, nil
   628  }
   629  
// String formats the wrapped value with the %s verb.
func (v SomeValue) String() string {
	return fmt.Sprintf("%s", v.Value)
}
   633  
   634  type SomeStorable struct {
   635  	Storable Storable
   636  }
   637  
   638  var _ Storable = SomeStorable{}
   639  
   640  func (v SomeStorable) ByteSize() uint32 {
   641  	// tag number (2 bytes) + encoded content
   642  	return 2 + v.Storable.ByteSize()
   643  }
   644  
   645  func (v SomeStorable) Encode(enc *Encoder) error {
   646  	err := enc.CBOR.EncodeRawBytes([]byte{
   647  		// tag number
   648  		0xd8, cborTagSomeValue,
   649  	})
   650  	if err != nil {
   651  		return err
   652  	}
   653  	return v.Storable.Encode(enc)
   654  }
   655  
   656  func (v SomeStorable) ChildStorables() []Storable {
   657  	return []Storable{v.Storable}
   658  }
   659  
   660  func (v SomeStorable) StoredValue(storage SlabStorage) (Value, error) {
   661  	wv, err := v.Storable.StoredValue(storage)
   662  	if err != nil {
   663  		return nil, err
   664  	}
   665  
   666  	return SomeValue{wv}, nil
   667  }
   668  
// String formats the wrapped storable with the %s verb.
func (v SomeStorable) String() string {
	return fmt.Sprintf("%s", v.Storable)
}