github.com/protolambda/zssz@v0.1.5/types/ssz_container.go (about)

     1  package types
     2  
     3  import (
     4  	"fmt"
     5  	. "github.com/protolambda/zssz/dec"
     6  	. "github.com/protolambda/zssz/enc"
     7  	. "github.com/protolambda/zssz/htr"
     8  	"github.com/protolambda/zssz/merkle"
     9  	. "github.com/protolambda/zssz/pretty"
    10  	"github.com/protolambda/zssz/util/tags"
    11  	"reflect"
    12  	"unsafe"
    13  )
    14  
    15  const SSZ_TAG = "ssz"
    16  const OMIT_FLAG = "omit"
    17  const SQUASH_FLAG = "squash"
    18  
    19  type FieldPtrFn func(p unsafe.Pointer) unsafe.Pointer
    20  
    21  func (fn FieldPtrFn) WrapOffset(memOffset uintptr) FieldPtrFn {
    22  	return func(p unsafe.Pointer) unsafe.Pointer {
    23  		return fn(unsafe.Pointer(uintptr(p) + memOffset))
    24  	}
    25  }
    26  
// ContainerField couples one struct field's SSZ type handler with the
// metadata needed to encode/decode/hash it as part of a container.
type ContainerField struct {
	ssz      SSZ        // SSZ type definition for this field
	name     string     // full name; squashed fields get "Outer>Inner" prefixes (see Wrap)
	pureName string     // field name without squash prefixing (used by Pretty)
	ptrFn    FieldPtrFn // resolves the field's memory location from the container pointer
	isFixed  bool       // cached ssz.IsFixed(), to avoid repeated interface calls
}
    34  
// Wrap adapts a field of an embedded (squashed) container so it can act as a
// direct field of the outer container: the name gets the outer field's name
// as a ">"-separated prefix, and the pointer function is shifted by the
// embedded struct's memory offset within the outer struct.
func (c *ContainerField) Wrap(name string, memOffset uintptr) ContainerField {
	return ContainerField{
		ssz:  c.ssz,
		name: name + ">" + c.name,
		// NOTE(review): this copies c.name, not c.pureName; for multi-level
		// squashes the "pure" name keeps inner ">" separators — confirm intended.
		pureName: c.name,
		ptrFn:    c.ptrFn.WrapOffset(memOffset),
		isFixed:  c.ssz.IsFixed(),
	}
}
    44  
// SquashableFields is implemented by SSZ types whose fields can be inlined
// ("squashed") into a parent container, as if they were declared directly on it.
type SquashableFields interface {
	// SquashFields returns the ContainerFields to inline into the parent.
	SquashFields() []ContainerField
}
    49  
    50  func GetOffsetPtrFn(memOffset uintptr) FieldPtrFn {
    51  	return func(p unsafe.Pointer) unsafe.Pointer {
    52  		return unsafe.Pointer(uintptr(p) + memOffset)
    53  	}
    54  }
    55  
// SSZContainer is the SSZ type definition for a Go struct: its (possibly
// squashed) fields plus precomputed length/offset bookkeeping.
type SSZContainer struct {
	Fields      []ContainerField // all fields, in encoding order (squashed fields flattened in)
	isFixedLen  bool             // true when no field is variable-size (offsetCount == 0)
	fixedLen    uint64           // byte length of the fixed part (fixed fields + 4-byte offsets)
	minLen      uint64           // minimum total encoded length
	maxLen      uint64           // maximum total encoded length
	offsetCount uint64           // includes offsets for fields that are squashed in
	fuzzMinLen  uint64           // minimum fuzzing-input bytes required by all fields
	fuzzMaxLen  uint64           // maximum fuzzing-input bytes consumed by all fields
}
    66  
// SquashFields implements SquashableFields: a container's own fields can be
// inlined into a parent container.
func (v *SSZContainer) SquashFields() []ContainerField {
	return v.Fields
}
    70  
    71  // Get the container fields for the given struct field
    72  // 0 fields (nil) if struct field is ignored
    73  // 1 field for normal struct fields
    74  // 0 or more fields when a struct field is squashed (recursively adding to the total field collection)
    75  func getFields(factory SSZFactoryFn, f *reflect.StructField) (out []ContainerField, err error) {
    76  	if tags.HasFlag(f, SSZ_TAG, OMIT_FLAG) {
    77  		return nil, nil
    78  	}
    79  	fieldSSZ, err := factory(f.Type)
    80  	if err != nil {
    81  		return nil, err
    82  	}
    83  
    84  	forceSquash := tags.HasFlag(f, SSZ_TAG, SQUASH_FLAG)
    85  
    86  	if f.Anonymous || forceSquash {
    87  		if squashable, ok := fieldSSZ.(SquashableFields); ok {
    88  			for _, sq := range squashable.SquashFields() {
    89  				out = append(out, sq.Wrap(f.Name, f.Offset))
    90  			}
    91  			return out, nil
    92  		}
    93  		// anonymous fields can be handled as normal fields. Only error when it was tagged to be squashed.
    94  		if forceSquash {
    95  			return nil, fmt.Errorf("could not squash field %s", f.Name)
    96  		}
    97  	}
    98  
    99  	out = append(out, ContainerField{
   100  		ssz: fieldSSZ, pureName: f.Name, name: f.Name,
   101  		ptrFn: GetOffsetPtrFn(f.Offset), isFixed: fieldSSZ.IsFixed()})
   102  	return
   103  }
   104  
   105  func NewSSZContainer(factory SSZFactoryFn, typ reflect.Type) (*SSZContainer, error) {
   106  	if typ.Kind() != reflect.Struct {
   107  		return nil, fmt.Errorf("typ is not a struct")
   108  	}
   109  	res := new(SSZContainer)
   110  	for i, c := 0, typ.NumField(); i < c; i++ {
   111  		// get the Go struct field
   112  		sField := typ.Field(i)
   113  		// For this field, get the SSZ field(s). There may be more if the Go field is squashed.
   114  		fields, err := getFields(factory, &sField)
   115  		if err != nil {
   116  			return nil, err
   117  		}
   118  		res.Fields = append(res.Fields, fields...)
   119  	}
   120  	for _, field := range res.Fields {
   121  		if field.ssz.IsFixed() {
   122  			fixed, min, max := field.ssz.FixedLen(), field.ssz.MinLen(), field.ssz.MaxLen()
   123  			if fixed != min || fixed != max {
   124  				return nil, fmt.Errorf("fixed-size field ('%s') in struct has invalid min/max length", field.name)
   125  			}
   126  			res.fixedLen += fixed
   127  			res.minLen += fixed
   128  			res.maxLen += fixed
   129  		} else {
   130  			res.fixedLen += BYTES_PER_LENGTH_OFFSET
   131  			res.minLen += BYTES_PER_LENGTH_OFFSET + field.ssz.MinLen()
   132  			res.maxLen += BYTES_PER_LENGTH_OFFSET + field.ssz.MaxLen()
   133  			res.offsetCount++
   134  		}
   135  		res.fuzzMinLen += field.ssz.FuzzMinLen()
   136  		res.fuzzMaxLen += field.ssz.FuzzMaxLen()
   137  	}
   138  	res.isFixedLen = res.offsetCount == 0
   139  	return res, nil
   140  }
   141  
// FuzzMinLen returns the minimum number of fuzzing-input bytes the container requires.
func (v *SSZContainer) FuzzMinLen() uint64 {
	return v.fuzzMinLen
}

// FuzzMaxLen returns the maximum number of fuzzing-input bytes the container may consume.
func (v *SSZContainer) FuzzMaxLen() uint64 {
	return v.fuzzMaxLen
}

// MinLen returns the minimum encoded byte length of the container.
func (v *SSZContainer) MinLen() uint64 {
	return v.minLen
}

// MaxLen returns the maximum encoded byte length of the container.
func (v *SSZContainer) MaxLen() uint64 {
	return v.maxLen
}

// FixedLen returns the byte length of the fixed part of the encoding
// (fixed-size fields plus one 4-byte offset per variable-size field).
func (v *SSZContainer) FixedLen() uint64 {
	return v.fixedLen
}

// IsFixed reports whether the container encodes to a fixed byte length
// (i.e. it has no variable-size fields).
func (v *SSZContainer) IsFixed() bool {
	return v.isFixedLen
}
   165  
   166  func (v *SSZContainer) SizeOf(p unsafe.Pointer) uint64 {
   167  	out := v.fixedLen
   168  	for _, f := range v.Fields {
   169  		if !f.ssz.IsFixed() {
   170  			out += f.ssz.SizeOf(f.ptrFn(p))
   171  		}
   172  	}
   173  	return out
   174  }
   175  
   176  func (v *SSZContainer) Encode(eb *EncodingWriter, p unsafe.Pointer) error {
   177  	// hot-path for common case of fixed-size container
   178  	if v.isFixedLen {
   179  		for i := range v.Fields {
   180  			f := &v.Fields[i]
   181  			if err := f.ssz.Encode(eb, f.ptrFn(p)); err != nil {
   182  				return err
   183  			}
   184  		}
   185  		return nil
   186  	}
   187  	// the previous offset, to calculate a new offset from, starting after the fixed data.
   188  	prevOffset := v.fixedLen
   189  	// span of the previous var-size element
   190  	prevSize := uint64(0)
   191  	for i := range v.Fields {
   192  		f := &v.Fields[i]
   193  		if f.isFixed {
   194  			if err := f.ssz.Encode(eb, f.ptrFn(p)); err != nil {
   195  				return err
   196  			}
   197  		} else {
   198  			if offset, err := eb.WriteOffset(prevOffset, prevSize); err != nil {
   199  				return err
   200  			} else {
   201  				prevOffset = offset
   202  			}
   203  			prevSize = f.ssz.SizeOf(f.ptrFn(p))
   204  		}
   205  	}
   206  	// Only iterate over and write dynamic parts if we need to.
   207  	if !v.isFixedLen {
   208  		for i := range v.Fields {
   209  			f := &v.Fields[i]
   210  			if !f.isFixed {
   211  				if err := f.ssz.Encode(eb, f.ptrFn(p)); err != nil {
   212  					return err
   213  				}
   214  			}
   215  		}
   216  	}
   217  	return nil
   218  }
   219  
// decodeVarSizeFuzzmode decodes the container from fuzzing input: instead of
// reading offsets, each field greedily receives the remaining data minus the
// minimum number of bytes the later fields still require.
func (v *SSZContainer) decodeVarSizeFuzzmode(dr *DecodingReader, p unsafe.Pointer) error {
	// minimum bytes still required by the fields not yet decoded
	lengthLeftOver := v.fuzzMinLen

	for _, f := range v.Fields {
		lengthLeftOver -= f.ssz.FuzzMinLen()
		span := dr.GetBytesSpan()
		if span < lengthLeftOver {
			return fmt.Errorf("under estimated length requirements for fuzzing input, not enough data available to fuzz")
		}
		// bytes this field may consume while leaving enough for the rest
		available := span - lengthLeftOver

		scoped, err := dr.Scope(available)
		if err != nil {
			return err
		}
		scoped.EnableFuzzMode()
		if err := f.ssz.Decode(scoped, f.ptrFn(p)); err != nil {
			return err
		}
		// advance the outer reader past whatever the field actually consumed
		dr.UpdateIndexFromScoped(scoped)
	}
	return nil
}
   243  
// decodeDynamicPart processes the variable-size fields of the container using
// the offsets collected from the fixed part (one offset per variable-size
// field, in order). Each field is handled in a reader scoped to
// [its offset, next offset); the last field runs to dr.Max(). Offsets must be
// non-decreasing and must match the reader's actual position exactly.
func (v *SSZContainer) decodeDynamicPart(dr *DecodingReader, offsets []uint64, fieldHandler func(dr *DecodingReader, f *ContainerField) error) error {
	// index into offsets; advances only on variable-size fields
	i := 0
	for fi := range v.Fields {
		f := &v.Fields[fi]
		// ignore fixed-size fields
		if f.ssz.IsFixed() {
			continue
		}
		// calculate the scope based on next offset, and max. value of this scope for the last value
		var scope uint64
		{
			currentOffset := offsets[i]
			if next := i + 1; next < len(offsets) {
				if nextOffset := offsets[next]; nextOffset >= currentOffset {
					scope = nextOffset - currentOffset
				} else {
					// offsets may never decrease
					return fmt.Errorf("offset %d for field %s is invalid", i, f.name)
				}
			} else {
				// last variable-size field consumes the remainder of the data
				scope = dr.Max() - currentOffset
			}
		}
		{
			// the reader must be exactly at the claimed offset: no gaps, no overlaps
			realOffset := dr.Index()
			if expectedOffset := offsets[i]; expectedOffset != realOffset {
				return fmt.Errorf("expected to be at %d bytes, but currently at %d", expectedOffset, realOffset)
			}
			scoped, err := dr.Scope(scope)
			if err != nil {
				return err
			}
			if err := fieldHandler(scoped, f); err != nil {
				return err
			}
			dr.UpdateIndexFromScoped(scoped)
		}
		// go to next offset
		i++
	}
	return nil
}
   285  
// processFixedPart walks the fixed part of the encoding: fixed-size fields are
// passed to fieldHandler in place, while for each variable-size field a 4-byte
// offset is read and collected. It verifies that every field consumes exactly
// its fixed length and that the whole fixed part has the expected size.
// Returns the collected offsets for use by decodeDynamicPart.
func (v *SSZContainer) processFixedPart(dr *DecodingReader, fieldHandler func(f *ContainerField) error) ([]uint64, error) {
	// technically we could also ignore offset correctness and skip ahead,
	//  but we may want to enforce proper offsets.
	offsets := make([]uint64, 0, v.offsetCount)
	startIndex := dr.Index()
	// expected reader position after each field, used to verify fixed sizes
	fixedI := dr.Index()
	for fi := range v.Fields {
		f := &v.Fields[fi]
		if f.ssz.IsFixed() {
			fixedI += f.ssz.FixedLen()
			// No need to redefine the scope for fixed-length SSZ objects.
			if err := fieldHandler(f); err != nil {
				return nil, err
			}
		} else {
			fixedI += BYTES_PER_LENGTH_OFFSET
			// read the offset from the fixed data, to later find the dynamic data with as a reader
			offset, err := dr.ReadOffset()
			if err != nil {
				return nil, err
			}
			offsets = append(offsets, offset)
		}
		// the handler (or offset read) must consume exactly the field's fixed length
		if i := dr.Index(); i != fixedI {
			return nil, fmt.Errorf("fixed part had different size than expected, now at %d, expected to be at %d", i, fixedI)
		}
	}
	pivotIndex := dr.Index()
	if expectedIndex := v.fixedLen + startIndex; pivotIndex != expectedIndex {
		return nil, fmt.Errorf("expected to read to %d bytes for fixed part of container, got to %d", expectedIndex, pivotIndex)
	}
	return offsets, nil
}
   319  
   320  func (v *SSZContainer) decodeVarSize(dr *DecodingReader, p unsafe.Pointer) error {
   321  	offsets, err := v.processFixedPart(dr, func(f *ContainerField) error {
   322  		return f.ssz.Decode(dr, f.ptrFn(p))
   323  	})
   324  	if err != nil {
   325  		return err
   326  	}
   327  	return v.decodeDynamicPart(dr, offsets, func(scopedDr *DecodingReader, f *ContainerField) error {
   328  		return f.ssz.Decode(scopedDr, f.ptrFn(p))
   329  	})
   330  }
   331  
   332  func (v *SSZContainer) Decode(dr *DecodingReader, p unsafe.Pointer) error {
   333  	if dr.IsFuzzMode() {
   334  		return v.decodeVarSizeFuzzmode(dr, p)
   335  	} else {
   336  		return v.decodeVarSize(dr, p)
   337  	}
   338  }
   339  
   340  func (v *SSZContainer) DryCheck(dr *DecodingReader) error {
   341  	offsets, err := v.processFixedPart(dr, func(f *ContainerField) error {
   342  		return f.ssz.DryCheck(dr)
   343  	})
   344  	if err != nil {
   345  		return err
   346  	}
   347  	return v.decodeDynamicPart(dr, offsets, func(scopedDr *DecodingReader, f *ContainerField) error {
   348  		return f.ssz.DryCheck(scopedDr)
   349  	})
   350  }
   351  
   352  func (v *SSZContainer) HashTreeRoot(h MerkleFn, p unsafe.Pointer) [32]byte {
   353  	leaf := func(i uint64) []byte {
   354  		f := v.Fields[i]
   355  		r := f.ssz.HashTreeRoot(h, f.ptrFn(p))
   356  		return r[:]
   357  	}
   358  	leafCount := uint64(len(v.Fields))
   359  	return merkle.Merkleize(h, leafCount, leafCount, leaf)
   360  }
   361  
   362  func (v *SSZContainer) SigningRoot(h MerkleFn, p unsafe.Pointer) [32]byte {
   363  	leaf := func(i uint64) []byte {
   364  		f := v.Fields[i]
   365  		r := f.ssz.HashTreeRoot(h, f.ptrFn(p))
   366  		return r[:]
   367  	}
   368  	// truncate last field
   369  	leafCount := uint64(len(v.Fields))
   370  	if leafCount != 0 {
   371  		leafCount--
   372  	}
   373  	return merkle.Merkleize(h, leafCount, leafCount, leaf)
   374  }
   375  
   376  func (v *SSZContainer) Pretty(indent uint32, w *PrettyWriter, p unsafe.Pointer) {
   377  	w.WriteIndent(indent)
   378  	w.Write("{\n")
   379  	for i, f := range v.Fields {
   380  		w.WriteIndent(indent + 1)
   381  		w.Write(f.pureName)
   382  		w.Write(":\n")
   383  		f.ssz.Pretty(indent+3, w, f.ptrFn(p))
   384  		if i == len(v.Fields)-1 {
   385  			w.Write("\n")
   386  		} else {
   387  			w.Write(",\n")
   388  		}
   389  	}
   390  	w.WriteIndent(indent)
   391  	w.Write("}")
   392  }